Commit 287ca7d5 (unverified)
Author: Zeng Jinle, committed via GitHub on November 15, 2021
Parent: 70cb0a54

MLPerf Optimization for Release/2.2 (#37109)

* add mlperf optimization PRs
* update

Showing 68 changed files with 6480 additions and 561 deletions (+6480 -561)
cmake/operators.cmake (+1 -1)
paddle/fluid/framework/details/build_strategy.h (+2 -0)
paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc (+5 -3)
paddle/fluid/framework/details/scale_loss_grad_op_handle.cc (+14 -5)
paddle/fluid/framework/details/scale_loss_grad_op_handle.h (+6 -0)
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc (+31 -22)
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h (+1 -1)
paddle/fluid/framework/distributed_strategy.proto (+1 -0)
paddle/fluid/framework/ir/memory_optimize_pass/inplace_addto_op_pass.cc (+6 -3)
paddle/fluid/framework/ir/multi_devices_graph_pass/CMakeLists.txt (+1 -1)
paddle/fluid/framework/ir/multi_devices_graph_pass/modify_op_lock_and_record_event_pass.cc (+12 -2)
paddle/fluid/framework/operator_kernel_configs.h (+2 -0)
paddle/fluid/framework/parallel_executor.cc (+143 -0)
paddle/fluid/framework/parallel_executor.h (+2 -0)
paddle/fluid/memory/allocation/CMakeLists.txt (+5 -1)
paddle/fluid/memory/allocation/allocator_facade.cc (+137 -10)
paddle/fluid/memory/allocation/allocator_facade.h (+8 -0)
paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.cc (+6 -2)
paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.h (+2 -1)
paddle/fluid/operators/conv_cudnn_helper.h (+3 -0)
paddle/fluid/operators/fused/CMakeLists.txt (+9 -1)
paddle/fluid/operators/fused/cudnn_bn_add_relu_test.cc (+784 -0)
paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h (+193 -0)
paddle/fluid/operators/fused/cudnn_fusion_helper.h (+167 -0)
paddle/fluid/operators/fused/cudnn_norm_conv.cu.h (+385 -0)
paddle/fluid/operators/fused/cudnn_norm_conv_test.cc (+458 -0)
paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h (+317 -0)
paddle/fluid/operators/fused/resnet_unit_op.cc (+411 -0)
paddle/fluid/operators/fused/resnet_unit_op.cu (+299 -0)
paddle/fluid/operators/math/pooling.cu (+333 -234)
paddle/fluid/operators/math/pooling.h (+6 -4)
paddle/fluid/operators/optimizers/lars_momentum_op.cc (+126 -18)
paddle/fluid/operators/optimizers/lars_momentum_op.cu (+462 -100)
paddle/fluid/operators/optimizers/lars_momentum_op.h (+35 -41)
paddle/fluid/operators/optimizers/merged_momentum_op.cc (+95 -0)
paddle/fluid/operators/optimizers/merged_momentum_op.cu (+24 -0)
paddle/fluid/operators/optimizers/merged_momentum_op.h (+197 -0)
paddle/fluid/operators/optimizers/momentum_op.h (+35 -32)
paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cc (+88 -0)
paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cu (+24 -0)
paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h (+115 -0)
paddle/fluid/platform/CMakeLists.txt (+5 -0)
paddle/fluid/platform/cuda_graph.cc (+104 -0)
paddle/fluid/platform/cuda_graph.h (+142 -0)
paddle/fluid/platform/cuda_graph_with_memory_pool.cc (+48 -0)
paddle/fluid/platform/cuda_graph_with_memory_pool.h (+64 -0)
paddle/fluid/platform/cudnn_desc.h (+24 -9)
paddle/fluid/platform/dynload/cudnn.h (+12 -1)
paddle/fluid/platform/gpu_info.cc (+2 -0)
paddle/fluid/platform/macros.h (+6 -0)
paddle/fluid/platform/type_defs.h (+1 -0)
paddle/fluid/pybind/CMakeLists.txt (+1 -1)
paddle/fluid/pybind/pybind.cc (+53 -1)
python/paddle/device/cuda/graphs.py (+57 -0)
python/paddle/fluid/contrib/layers/nn.py (+35 -0)
python/paddle/fluid/contrib/mixed_precision/fp16_utils.py (+6 -14)
python/paddle/fluid/executor.py (+9 -3)
python/paddle/fluid/memory_analysis.py (+77 -0)
python/paddle/fluid/optimizer.py (+3 -2)
python/paddle/fluid/tests/unittests/test_cuda_graph.py (+147 -0)
python/paddle/fluid/tests/unittests/test_fleet_lars_meta_optimizer.py (+1 -1)
python/paddle/fluid/tests/unittests/test_memory_analysis.py (+52 -0)
python/paddle/fluid/tests/unittests/test_merged_momentum_op.py (+197 -0)
python/paddle/fluid/tests/unittests/test_momentum_op.py (+86 -47)
python/paddle/fluid/tests/unittests/test_pow2_decay_with_linear_warmup_op.py (+88 -0)
python/paddle/fluid/tests/unittests/test_tensor_copy_from.py (+39 -0)
python/paddle/incubate/operators/__init__.py (+1 -0)
python/paddle/incubate/operators/resnet_unit.py (+269 -0)
cmake/operators.cmake

@@ -218,7 +218,7 @@ function(op_library TARGET)
     "fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op"
     "sparse_attention_op" "dgc_op" "fused_fc_elementwise_layernorm_op" "skip_layernorm_op"
     "multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op"
     "fused_embedding_eltwise_layernorm_op" "fusion_gru_op" "fusion_lstm_op"
-    "fused_bn_add_activation_op" "fused_attention_op" "fused_feedforward_op")
+    "fused_bn_add_activation_op" "fused_attention_op" "fused_feedforward_op" "resnet_unit_op")
     if ("${TARGET}" STREQUAL "${manual_pybind_op}")
       set(pybind_flag 1)
     endif()
paddle/fluid/framework/details/build_strategy.h

@@ -143,6 +143,8 @@ struct BuildStrategy {
   // Turn off inplace addto by default.
   bool enable_addto_{false};

+  bool allow_cuda_graph_capture_{false};
+
   // FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode,
   // num_trainers is 1, so the current fields of build_strategy doesn't tell if
   // it's distributed model.
paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc

@@ -130,10 +130,12 @@ FetchResultType FastThreadedSSAGraphExecutor::Run(
     }
   }
   // Wait FetchOps.
-  ClearFetchOp(graph_, &fetch_ops);
+  if (!fetch_ops.empty()) {
+    ClearFetchOp(graph_, &fetch_ops);

-  for (auto &place : places_) {
-    fetch_ctxs_.Get(place)->Wait();
+    for (auto &place : places_) {
+      fetch_ctxs_.Get(place)->Wait();
+    }
   }

   return fetches;
paddle/fluid/framework/details/scale_loss_grad_op_handle.cc

@@ -86,19 +86,28 @@ struct ScaleLossGradFunctor {
   }
 };

+std::string ScaleLossGradOpHandle::LossGradName() const {
+  return static_cast<VarHandle *>(this->outputs_[0])->name();
+}
+
 void ScaleLossGradOpHandle::RunImpl() {
   platform::RecordEvent record_event(Name());
-  // Doesn't wait any event
-  std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name();
+  RunOnVar(local_exec_scopes_[0]->FindVar(LossGradName()), true);
+}

-  auto *tensor =
-      local_exec_scopes_[0]->FindVar(var_name)->GetMutable<LoDTensor>();
+void ScaleLossGradOpHandle::RunOnVar(Variable *var, bool record_event) {
+  auto *tensor = var->GetMutable<LoDTensor>();
   tensor->Resize(make_ddim({1}));

 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   ScaleLossGradFunctor func(coeff_, tensor, place_, out_dtype_,
                             this->dev_ctxes_.at(place_));
-  this->RunAndRecordEvent([&] { framework::VisitDataType(out_dtype_, func); });
+  if (record_event) {
+    this->RunAndRecordEvent(
+        [&] { framework::VisitDataType(out_dtype_, func); });
+  } else {
+    framework::VisitDataType(out_dtype_, func);
+  }
 #else
   ScaleLossGradFunctor func(coeff_, tensor, place_, out_dtype_, nullptr);
   framework::VisitDataType(out_dtype_, func);
paddle/fluid/framework/details/scale_loss_grad_op_handle.h

@@ -46,6 +46,12 @@ struct ScaleLossGradOpHandle : public OpHandleBase {
   std::string Name() const override;

   platform::Place GetPlace() const { return place_; }

+  void RunOnVar(Variable *var, bool record_event = false);
+
+  std::string LossGradName() const;
+
  protected:
   void RunImpl() override;
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc

@@ -22,7 +22,9 @@
 #include "paddle/fluid/framework/details/multi_devices_helper.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/variable_helper.h"
+#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
 #include "paddle/fluid/platform/profiler.h"

 namespace paddle {
 namespace framework {
 namespace details {

@@ -49,8 +51,29 @@ ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
   PrepareLocalExeScopes();
 }

+static void RunProgramDescs(const ProgramDescs &programs,
+                            const std::vector<Scope *> &local_exec_scopes,
+                            const std::vector<platform::Place> &places) {
+  for (auto &program : programs) {
+    for (auto &op_desc : program.Block(0).AllOps()) {
+      for (size_t i = 0; i < local_exec_scopes.size(); ++i) {
+        auto op = OpRegistry::CreateOp(*op_desc);
+        op->Run(*local_exec_scopes[i], places[i]);
+      }
+    }
+  }
+}
+
 FetchResultType ScopeBufferedSSAGraphExecutor::Run(
     const std::vector<std::string> &fetch_tensors, bool return_merged) {
+#ifdef PADDLE_WITH_CUDA
+  if (platform::IsCUDAGraphCapturing()) {
+    strategy_.num_iteration_per_drop_scope_ =
+        std::numeric_limits<size_t>::max();
+    DropLocalExeScopes(/*need_wait=*/false);
+  }
+#endif
+
   if (drop_scope_counter_ == 0) {
     platform::RecordEvent e("InitLocalVars");
     InitVariables();

@@ -84,7 +107,7 @@ FetchResultType ScopeBufferedSSAGraphExecutor::Run(
   ++drop_scope_counter_;
   if (drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_ ||
       DropScopeOrNot()) {
-    DropLocalExeScopes();
+    DropLocalExeScopes(!platform::IsCUDAGraphCapturing());
   }
   if (VLOG_IS_ON(5)) {

@@ -128,15 +151,7 @@ void ScopeBufferedSSAGraphExecutor::InitVariables() {
   if (graph.Has(details::kStartupProgramDescs)) {
     auto &program_descs =
         graph.Get<details::ProgramDescs>(details::kStartupProgramDescs);
-    for (auto &program_desc : program_descs) {
-      for (auto &op_desc : program_desc.Block(0).AllOps()) {
-        for (size_t i = 0; i < local_exec_scopes_.size(); ++i) {
-          auto op = OpRegistry::CreateOp(*op_desc);
-          op->Run(*local_exec_scopes_[i], places_[i]);
-        }
-      }
-    }
+    RunProgramDescs(program_descs, local_exec_scopes_, places_);
   }
   is_initialized_ = true;
 }

@@ -144,23 +159,17 @@ void ScopeBufferedSSAGraphExecutor::InitVariables() {
   if (graph.Has(details::kProgramDescs)) {
     auto &program_descs =
         graph.Get<details::ProgramDescs>(details::kProgramDescs);
-    for (auto &program_desc : program_descs) {
-      for (auto &op_desc : program_desc.Block(0).AllOps()) {
-        for (size_t i = 0; i < local_exec_scopes_.size(); ++i) {
-          auto op = OpRegistry::CreateOp(*op_desc);
-          op->Run(*local_exec_scopes_[i], places_[i]);
-        }
-      }
-    }
+    RunProgramDescs(program_descs, local_exec_scopes_, places_);
   }
 }

-void ScopeBufferedSSAGraphExecutor::DropLocalExeScopes() {
+void ScopeBufferedSSAGraphExecutor::DropLocalExeScopes(bool need_wait) {
   platform::RecordEvent drop_scope_event("DropLocalExeScopes");
   drop_scope_counter_ = 0;
-  for (auto &p : places_) {
-    platform::DeviceContextPool::Instance().Get(p)->Wait();
+  if (need_wait) {
+    for (auto &p : places_) {
+      platform::DeviceContextPool::Instance().Get(p)->Wait();
+    }
   }
   scope_monitor_.ClearHistoryLocalExecScopes();
   for (size_t i = 0; i < local_exec_scopes_.size(); ++i) {
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h

@@ -53,7 +53,7 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
   FetchResultType Run(const std::vector<std::string>& fetch_tensors,
                       bool return_merged) override;

-  void DropLocalExeScopes();
+  void DropLocalExeScopes(bool need_wait = true);

   bool NeedCreateLocalExeScope();
paddle/fluid/framework/distributed_strategy.proto

@@ -115,6 +115,7 @@ message BuildStrategy {
   optional bool enable_auto_fusion = 11 [ default = false ];
   optional bool enable_addto = 12 [ default = false ];
   optional bool fix_op_run_order = 13 [ default = false ];
+  optional bool allow_cuda_graph_capture = 14 [ default = false ];
 }

 message ExecutionStrategy {
paddle/fluid/framework/ir/memory_optimize_pass/inplace_addto_op_pass.cc

@@ -179,7 +179,8 @@ void InplaceAddToOpPass::Run(Graph *graph) const {
         out_var_ptr->GeneratedOp());

     // NOTE(zhiqiu): currently, only conv2d_grad supports addto strategy
-    if (right_generated_op->Name() != "conv2d_grad") {
+    if (right_generated_op->Name() != "conv2d_grad" &&
+        right_generated_op->Name() != "resnet_unit_grad") {
       continue;
     }

@@ -224,11 +225,13 @@ static bool IsValidConv2DGradDataGradNode(const Node &node) {
   if (node.inputs.empty()) return false;
   auto *generated_op = node.inputs[0];
   auto *op_desc = generated_op->Op();
-  if (op_desc == nullptr || op_desc->Type() != "conv2d_grad") {
+  if (op_desc == nullptr || (op_desc->Type() != "conv2d_grad" &&
+                             op_desc->Type() != "resnet_unit_grad")) {
     return false;
   }
   const auto &outputs = op_desc->Outputs();
-  auto iter = outputs.find(GradVarName("Input"));
+  std::string grad_var_name = op_desc->Type() == "conv2d_grad" ? "Input" : "X";
+  auto iter = outputs.find(GradVarName(grad_var_name));
   return iter != outputs.end() && !iter->second.empty() &&
          iter->second[0] == node.Name() &&
          !op_desc->GetAttrIfExists<bool>("use_addto");
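For reference, Paddle's GradVarName simply appends the @GRAD suffix, so the change above amounts to looking up Input@GRAD for conv2d_grad and X@GRAD for resnet_unit_grad. Below is a minimal, self-contained sketch of that selection logic; the GradName helper is a local stand-in that is assumed to behave like framework::GradVarName.

#include <iostream>
#include <string>

// Local stand-in for framework::GradVarName(): appends the @GRAD suffix.
static std::string GradName(const std::string &var) { return var + "@GRAD"; }

// Mirrors the selection added in IsValidConv2DGradDataGradNode():
// conv2d_grad exposes its data gradient as Input@GRAD, resnet_unit_grad as X@GRAD.
static std::string DataGradOutputName(const std::string &op_type) {
  const std::string grad_var = (op_type == "conv2d_grad") ? "Input" : "X";
  return GradName(grad_var);
}

int main() {
  std::cout << DataGradOutputName("conv2d_grad") << "\n";       // Input@GRAD
  std::cout << DataGradOutputName("resnet_unit_grad") << "\n";  // X@GRAD
  return 0;
}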
paddle/fluid/framework/ir/multi_devices_graph_pass/CMakeLists.txt

-cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
+cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle op_graph_view multi_devices_helper)
 cc_library(multi_devices_graph_print_pass SRCS multi_devices_graph_print_pass.cc DEPS multi_devices_helper)
 cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc DEPS multi_devices_helper)
paddle/fluid/framework/ir/multi_devices_graph_pass/modify_op_lock_and_record_event_pass.cc

@@ -14,6 +14,7 @@
 #include "paddle/fluid/framework/details/computation_op_handle.h"
 #include "paddle/fluid/framework/details/multi_devices_helper.h"
+#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/framework/ir/memory_optimize_pass/op_graph_view.h"

@@ -21,14 +22,23 @@ namespace paddle {
 namespace framework {
 namespace ir {

+template <typename T>
+static bool IsMatchedPlaceSingleDeviceOp(details::OpHandleBase *op_base,
+                                         const platform::Place &place) {
+  auto *op = dynamic_cast<T *>(op_base);
+  return op && op->GetPlace() == place;
+}
+
 static bool IsLockAndRecordEventFreeComputationOpHandle(
     details::ComputationOpHandle *op, const OpGraphView &graph_view) {
   if (!platform::is_gpu_place(op->GetPlace()) &&
       !platform::is_xpu_place(op->GetPlace()))
     return false;
   for (auto &pending_op : graph_view.PendingOps(op)) {
-    auto *tmp = dynamic_cast<details::ComputationOpHandle *>(pending_op);
-    if (tmp == nullptr || !(tmp->GetPlace() == op->GetPlace())) {
+    if (!IsMatchedPlaceSingleDeviceOp<details::ComputationOpHandle>(
+            pending_op, op->GetPlace()) &&
+        !IsMatchedPlaceSingleDeviceOp<details::ScaleLossGradOpHandle>(
+            pending_op, op->GetPlace())) {
       return false;
     }
   }
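The new IsMatchedPlaceSingleDeviceOp helper is just a typed dynamic_cast plus a place comparison, applied to two unrelated handle types. Below is a self-contained sketch of the same pattern; OpBase, ComputeOp, ScaleLossOp, and the integer "place" are illustrative stand-ins, not Paddle types.

#include <iostream>
#include <memory>
#include <vector>

// Illustrative stand-ins for OpHandleBase and its derived handles.
struct OpBase {
  virtual ~OpBase() = default;
};
struct ComputeOp : OpBase {
  explicit ComputeOp(int p) : place(p) {}
  int GetPlace() const { return place; }
  int place;
};
struct ScaleLossOp : OpBase {
  explicit ScaleLossOp(int p) : place(p) {}
  int GetPlace() const { return place; }
  int place;
};

// Same shape as the helper added above: match only if the handle has the
// requested dynamic type AND runs on the given place.
template <typename T>
bool IsMatchedPlaceSingleDeviceOp(OpBase *op_base, int place) {
  auto *op = dynamic_cast<T *>(op_base);
  return op && op->GetPlace() == place;
}

int main() {
  std::vector<std::unique_ptr<OpBase>> pending;
  pending.emplace_back(new ComputeOp(0));
  pending.emplace_back(new ScaleLossOp(0));
  pending.emplace_back(new ComputeOp(1));

  for (auto &op : pending) {
    // The pass asks the negative question (neither matcher fits -> bail out);
    // here we print the positive form for clarity.
    bool same_place = IsMatchedPlaceSingleDeviceOp<ComputeOp>(op.get(), 0) ||
                      IsMatchedPlaceSingleDeviceOp<ScaleLossOp>(op.get(), 0);
    std::cout << (same_place ? "same-place op" : "other op") << "\n";
  }
  return 0;
}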
paddle/fluid/framework/operator_kernel_configs.h

@@ -15,8 +15,10 @@ limitations under the License. */
 #pragma once

 #include <algorithm>
+#include <mutex>
 #include <unordered_map>
 #include <vector>
+#include "glog/logging.h"

 namespace paddle {
 namespace framework {
paddle/fluid/framework/parallel_executor.cc

@@ -27,6 +27,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/details/multi_devices_helper.h"
 #include "paddle/fluid/framework/details/op_handle_base.h"
 #include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
+#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
 #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"

@@ -34,6 +35,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
 #include "paddle/fluid/framework/ir/multi_devices_graph_pass/set_reader_device_info_utils.h"
 #include "paddle/fluid/framework/variable_helper.h"
+#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
 #include "paddle/fluid/platform/event.h"
 #include "paddle/fluid/platform/profiler.h"

@@ -43,6 +45,10 @@ limitations under the License. */
 DECLARE_double(eager_delete_tensor_gb);

+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+DECLARE_bool(sync_nccl_allreduce);
+#endif
+
 #ifdef WITH_GPERFTOOLS
 #include "gperftools/profiler.h"
 #endif

@@ -669,6 +675,7 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
   // ncclOp
   std::vector<ir::Graph *> async_graphs =
       CompileGraphWithBuildStrategy(graph, &graphs, loss_var_name);
+  PrepareForCUDAGraphCapture(graph);
   graph = member_->ApplyMemoryOptimizePass(graph);
   async_graphs[0] = graph;

@@ -882,6 +889,23 @@ void ParallelExecutor::BCastParamsToDevices(
 FetchResultType ParallelExecutor::Run(
     const std::vector<std::string> &fetch_tensors, bool return_merged) {
   VLOG(3) << "enter ParallelExecutor Run";
+#ifdef PADDLE_WITH_CUDA
+  if (platform::IsCUDAGraphCapturing()) {
+    PADDLE_ENFORCE_EQ(fetch_tensors.empty(), true,
+                      platform::errors::InvalidArgument(
+                          "Cannot fetch data when using CUDA Graph."));
+    PADDLE_ENFORCE_EQ(
+        member_->build_strategy_.allow_cuda_graph_capture_, true,
+        platform::errors::InvalidArgument(
+            "You must turn on build_strategy.allow_cuda_graph_capture = True "
+            "to enable CUDA Graph capturing."));
+    PADDLE_ENFORCE_EQ(
+        member_->places_[0], platform::CUDAGraphCapturingPlace(),
+        platform::errors::InvalidArgument("The place to capture CUDAGraph is "
+                                          "not the same as the place to run."));
+  }
+#endif
+
 #ifdef WITH_GPERFTOOLS
   if (gProfileStarted) {
     ProfilerFlush();

@@ -932,6 +956,16 @@ void ParallelExecutor::SkipMemoryReuse(
 void ParallelExecutor::FeedTensorsIntoLocalScopes(
     const std::vector<std::unordered_map<std::string, LoDTensor>> &tensors) {
+  if (platform::IsCUDAGraphCapturing()) {
+    for (auto &tensor : tensors) {
+      PADDLE_ENFORCE_EQ(
+          tensor.empty(), true,
+          platform::errors::PermissionDenied(
+              "Feeding data is not permitted when capturing CUDA Graph."));
+    }
+    return;
+  }
+
   if (!member_->AllowPartialFeed()) {
     PADDLE_ENFORCE_EQ(tensors.size(), member_->local_scopes_.size(),
                       platform::errors::Unimplemented(

@@ -987,6 +1021,14 @@ void ParallelExecutor::FeedTensorsIntoLocalScopes(
 void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
     const std::unordered_map<std::string, LoDTensor> &tensors) {
+  if (platform::IsCUDAGraphCapturing()) {
+    PADDLE_ENFORCE_EQ(
+        tensors.empty(), true,
+        platform::errors::PermissionDenied(
+            "Feeding data is not permitted when capturing CUDA Graph."));
+    return;
+  }
+
   size_t num_places = member_->places_.size();
   bool allow_partial_feed = member_->AllowPartialFeed();

@@ -1568,6 +1610,107 @@ const ir::Graph &ParallelExecutor::Graph() const {
   return member_->executor_->Graph();
 }

+void ParallelExecutor::PrepareForCUDAGraphCapture(ir::Graph *graph) {
+  const auto &build_strategy = member_->build_strategy_;
+  if (!build_strategy.allow_cuda_graph_capture_) return;
+#ifdef PADDLE_WITH_CUDA
+  PADDLE_ENFORCE_EQ(
+      build_strategy.async_mode_, false,
+      platform::errors::InvalidArgument(
+          "Async Executor does not support CUDA Graph capturing."));
+  PADDLE_ENFORCE_EQ(platform::IsCUDAGraphCapturing(), false,
+                    platform::errors::PermissionDenied(
+                        "CUDA Graph is not allowed to capture "
+                        "when running the first batch."));
+  PADDLE_ENFORCE_EQ(
+      member_->places_.size(), 1,
+      platform::errors::InvalidArgument(
+          "CUDA Graph is only supported when one GPU device is running."));
+  PADDLE_ENFORCE_EQ(platform::is_gpu_place(member_->places_[0]), true,
+                    platform::errors::InvalidArgument(
+                        "CUDA Graph is only supported on NVIDIA GPU device."));
+  PADDLE_ENFORCE_EQ(FLAGS_sync_nccl_allreduce, false,
+                    platform::errors::InvalidArgument(
+                        "FLAGS_sync_nccl_allreduce must be False to support "
+                        "CUDA Graph capturing."));
+
+  std::unordered_map<std::string, std::vector<VarDesc *>> all_vars;
+  for (auto &node : graph->Nodes()) {
+    if (node->IsVar() && !node->IsCtrlVar() && node->Var()) {
+      auto *var_desc = node->Var();
+      all_vars[var_desc->Name()].emplace_back(var_desc);
+    }
+  }
+
+  auto mark_var_as_persistable = [&all_vars](const std::string &name) {
+    auto iter = all_vars.find(name);
+    if (iter != all_vars.end()) {
+      for (auto *var_desc : iter->second) {
+        var_desc->SetPersistable(true);
+      }
+    }
+  };
+
+  // Step 1: All fused vars must be persistable.
+  if (graph->Has(details::kFusedVars)) {
+    auto &fused_vars = graph->Get<details::FusedVars>(details::kFusedVars);
+    for (auto &fused_var : fused_vars) {
+      fused_var.second.persistable_ = true;
+      mark_var_as_persistable(fused_var.first);
+    }
+  }
+
+  // Step 2: All pinned vars must be persistable.
+  if (graph->Has(details::kPinnedVars)) {
+    auto &pinned_vars = graph->Get<details::PinnedVars>(details::kPinnedVars);
+    for (auto &pinned_var : pinned_vars) {
+      mark_var_as_persistable(pinned_var);
+    }
+  }
+
+  // Step 3: Move all main programs to startup programs to make sure that
+  // the main programs would only be run once.
+  if (graph->Has(details::kProgramDescs)) {
+    auto &startup_programs =
+        graph->GetOrInit<details::ProgramDescs>(details::kStartupProgramDescs);
+    auto &main_programs =
+        graph->Get<details::ProgramDescs>(details::kProgramDescs);
+    for (auto &main_program : main_programs) {
+      startup_programs.emplace_back(main_program);
+    }
+    graph->Erase(details::kProgramDescs);
+  }
+
+  // Step 4: Mark all vars in startup programs to be persistable.
+  if (graph->Has(details::kStartupProgramDescs)) {
+    auto &startup_programs =
+        graph->GetOrInit<details::ProgramDescs>(details::kStartupProgramDescs);
+    for (auto &startup_program : startup_programs) {
+      for (auto &op_desc : startup_program.Block(0).AllOps()) {
+        for (auto &output : op_desc->OutputArgumentNames()) {
+          mark_var_as_persistable(output);
+        }
+      }
+    }
+  }
+
+  // Step 5: ScaleLossGrad must be run beforehand to avoid H2D copy.
+  auto ops = ir::FilterByNodeWrapper<details::OpHandleBase>(*graph);
+  auto *scope = member_->local_scopes_[0];
+  for (auto *op : ops) {
+    auto *loss_grad_op = dynamic_cast<details::ScaleLossGradOpHandle *>(op);
+    if (loss_grad_op == nullptr) continue;
+    auto loss_grad_name = loss_grad_op->LossGradName();
+    mark_var_as_persistable(loss_grad_name);
+    loss_grad_op->RunOnVar(scope->Var(loss_grad_name));
+    loss_grad_op->SetSkipRunning(true);
+  }
+#else
+  PADDLE_THROW(platform::errors::Unimplemented(
+      "CUDA Graph is only supported on NVIDIA GPU device."));
+#endif
+}
+
 }  // namespace framework
 }  // namespace paddle
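PrepareForCUDAGraphCapture only restructures the program (persistable vars, one-shot startup programs, pre-run ScaleLossGrad) so that a later stream capture sees fixed addresses and no host-side work; the capture itself goes through the new platform/cuda_graph wrapper, which is built on the standard CUDA runtime stream-capture API. The following self-contained sketch shows that underlying mechanism, independent of Paddle, using the CUDA 10/11-era five-argument cudaGraphInstantiate signature; the buffer and the memset stand in for the captured training step.

#include <cuda_runtime.h>
#include <cstdio>

#define CHECK(expr)                                                    \
  do {                                                                 \
    cudaError_t err = (expr);                                          \
    if (err != cudaSuccess) {                                          \
      std::printf("CUDA error %s at line %d\n", cudaGetErrorString(err), \
                  __LINE__);                                           \
      return 1;                                                        \
    }                                                                  \
  } while (0)

int main() {
  float *buf = nullptr;
  // Allocate BEFORE capture: captured work must reuse fixed addresses,
  // which is why the executor marks variables persistable before capturing.
  CHECK(cudaMalloc(&buf, 1024 * sizeof(float)));

  cudaStream_t stream;
  CHECK(cudaStreamCreate(&stream));

  // Record the work once into a graph...
  cudaGraph_t graph;
  CHECK(cudaStreamBeginCapture(stream, cudaStreamCaptureModeThreadLocal));
  CHECK(cudaMemsetAsync(buf, 0, 1024 * sizeof(float), stream));
  CHECK(cudaStreamEndCapture(stream, &graph));

  // ...instantiate once, then replay it cheaply many times.
  cudaGraphExec_t graph_exec;
  CHECK(cudaGraphInstantiate(&graph_exec, graph, nullptr, nullptr, 0));
  for (int i = 0; i < 10; ++i) {
    CHECK(cudaGraphLaunch(graph_exec, stream));
  }
  CHECK(cudaStreamSynchronize(stream));

  CHECK(cudaGraphExecDestroy(graph_exec));
  CHECK(cudaGraphDestroy(graph));
  CHECK(cudaStreamDestroy(stream));
  CHECK(cudaFree(buf));
  return 0;
}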
paddle/fluid/framework/parallel_executor.h

@@ -144,6 +144,8 @@ class ParallelExecutor {
   void SetReaderOpDeviceInfoOfGraphs(
       const std::vector<ir::Graph *> &final_graphs);

+  void PrepareForCUDAGraphCapture(ir::Graph *graph);
+
   ParallelExecutorPrivate *member_;
   std::vector<std::unique_ptr<ir::Graph>> async_graphs_;
   std::vector<VariableInfo> var_infos_;
paddle/fluid/memory/allocation/CMakeLists.txt

@@ -82,7 +82,11 @@ endif()
 cc_library(aligned_allocator SRCS aligned_allocator.cc DEPS allocator)
 cc_test(test_aligned_allocator SRCS test_aligned_allocator.cc DEPS aligned_allocator)
 cc_library(allocator_strategy SRCS allocator_strategy.cc DEPS gflags ${AllocatorFacadeDeps})
-cc_library(allocator_facade SRCS allocator_facade.cc DEPS allocator_strategy)
+cc_library(allocator_facade SRCS allocator_facade.cc DEPS allocator_strategy)
+if (WITH_GPU)
+  target_link_libraries(allocator_facade cuda_graph)
+endif()

 cc_test(retry_allocator_test SRCS retry_allocator_test.cc DEPS retry_allocator locked_allocator cpu_allocator)
 if (WITH_TESTING)
paddle/fluid/memory/allocation/allocator_facade.cc

@@ -32,6 +32,9 @@
 #include "paddle/fluid/memory/allocation/thread_local_allocator.h"
 #include "paddle/fluid/platform/gpu_info.h"
 #endif
+#ifdef PADDLE_WITH_CUDA
+#include "paddle/fluid/platform/cuda_graph.h"
+#endif
 #ifdef PADDLE_WITH_XPU
 #include "paddle/fluid/platform/xpu/xpu_info.h"
 #endif

@@ -47,17 +50,64 @@ PADDLE_DEFINE_EXPORTED_bool(
     "Whether to use system allocator to allocate CPU and GPU memory. "
     "Only used for unittests.");

+DECLARE_string(allocator_strategy);
+
 namespace paddle {
 namespace memory {
 namespace allocation {

+#ifdef PADDLE_WITH_CUDA
+class CUDAGraphAllocator
+    : public Allocator,
+      public std::enable_shared_from_this<CUDAGraphAllocator> {
+ private:
+  class PrivateAllocation : public Allocation {
+   public:
+    PrivateAllocation(CUDAGraphAllocator* allocator,
+                      AllocationPtr underlying_allocation)
+        : Allocation(underlying_allocation->ptr(),
+                     underlying_allocation->size(),
+                     underlying_allocation->place()),
+          allocator_(allocator->shared_from_this()),
+          underlying_allocation_(std::move(underlying_allocation)) {}
+
+   private:
+    std::shared_ptr<Allocator> allocator_;
+    AllocationPtr underlying_allocation_;
+  };
+
+  explicit CUDAGraphAllocator(const std::shared_ptr<Allocator>& allocator)
+      : underlying_allocator_(allocator) {}
+
+ public:
+  static std::shared_ptr<Allocator> Create(
+      const std::shared_ptr<Allocator>& allocator) {
+    return std::shared_ptr<Allocator>(new CUDAGraphAllocator(allocator));
+  }
+
+ protected:
+  Allocation* AllocateImpl(size_t size) {
+    VLOG(10) << "Allocate " << size << " for CUDA Graph";
+    return new PrivateAllocation(this, underlying_allocator_->Allocate(size));
+  }
+
+  void FreeImpl(Allocation* allocation) {
+    VLOG(10) << "delete for CUDA Graph";
+    delete allocation;
+  }
+
+ private:
+  std::shared_ptr<Allocator> underlying_allocator_;
+};
+#endif
+
 class AllocatorFacadePrivate {
  public:
   using AllocatorMap = std::map<platform::Place, std::shared_ptr<Allocator>>;

-  AllocatorFacadePrivate() {
-    auto strategy = GetAllocatorStrategy();
-    switch (strategy) {
+  explicit AllocatorFacadePrivate(bool allow_free_idle_chunk = true) {
+    strategy_ = GetAllocatorStrategy();
+    switch (strategy_) {
       case AllocatorStrategy::kNaiveBestFit: {
         InitNaiveBestFitCPUAllocator();
 #ifdef PADDLE_WITH_XPU

@@ -91,7 +141,8 @@ class AllocatorFacadePrivate {
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
         for (int dev_id = 0; dev_id < platform::GetCUDADeviceCount();
              ++dev_id) {
-          InitAutoGrowthCUDAAllocator(platform::CUDAPlace(dev_id));
+          InitAutoGrowthCUDAAllocator(platform::CUDAPlace(dev_id),
+                                      allow_free_idle_chunk);
         }
         InitNaiveBestFitCUDAPinnedAllocator();
 #endif

@@ -117,7 +168,7 @@ class AllocatorFacadePrivate {
       default: {
         PADDLE_THROW(platform::errors::InvalidArgument(
-            "Unsupported allocator strategy: %d", static_cast<int>(strategy)));
+            "Unsupported allocator strategy: %d", static_cast<int>(strategy_)));
       }
     }
     InitZeroSizeAllocators();

@@ -130,11 +181,29 @@ class AllocatorFacadePrivate {
     CheckAllocThreadSafe();
   }

+  inline const AllocatorMap& GetAllocatorMap() {
+#ifdef PADDLE_WITH_CUDA
+    if (UNLIKELY(platform::CUDAGraph::IsCapturing())) {
+      auto id = platform::CUDAGraph::CapturingID();
+      auto iter = cuda_graph_allocator_map_.find(id);
+      PADDLE_ENFORCE_NE(
+          iter, cuda_graph_allocator_map_.end(),
+          platform::errors::PermissionDenied(
+              "No memory pool is prepared for CUDA Graph capturing."));
+      return iter->second->allocators_;
+    } else {
+      return allocators_;
+    }
+#else
+    return allocators_;
+#endif
+  }
+
   inline const std::shared_ptr<Allocator>& GetAllocator(
       const platform::Place& place, size_t size) {
     const auto& allocators =
         (size > 0 ? (UNLIKELY(FLAGS_use_system_allocator) ? system_allocators_
-                                                          : allocators_)
+                                                          : GetAllocatorMap())
                   : zero_size_allocators_);
     auto iter = allocators.find(place);
     PADDLE_ENFORCE_NE(iter, allocators.end(),

@@ -145,6 +214,7 @@ class AllocatorFacadePrivate {
  private:
   void InitSystemAllocators() {
+    if (!system_allocators_.empty()) return;
     system_allocators_[platform::CPUPlace()] = std::make_shared<CPUAllocator>();
 #ifdef PADDLE_WITH_XPU
     int device_count = platform::GetXPUDeviceCount();

@@ -183,10 +253,11 @@ class AllocatorFacadePrivate {
     allocators_[p] = std::make_shared<ThreadLocalCUDAAllocator>(p);
   }

-  void InitAutoGrowthCUDAAllocator(platform::CUDAPlace p) {
+  void InitAutoGrowthCUDAAllocator(platform::CUDAPlace p,
+                                   bool allow_free_idle_chunk) {
     auto cuda_allocator = std::make_shared<CUDAAllocator>(p);
     allocators_[p] = std::make_shared<AutoGrowthBestFitAllocator>(
-        cuda_allocator, platform::GpuMinChunkSize());
+        cuda_allocator, platform::GpuMinChunkSize(), allow_free_idle_chunk);
   }
 #endif

@@ -226,6 +297,7 @@ class AllocatorFacadePrivate {
   };

   void InitZeroSizeAllocators() {
+    if (!zero_size_allocators_.empty()) return;
     std::vector<platform::Place> places;
     places.emplace_back(platform::CPUPlace());
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)

@@ -279,12 +351,57 @@ class AllocatorFacadePrivate {
     }
   }

+#ifdef PADDLE_WITH_CUDA
+ public:
+  void PrepareMemoryPoolForCUDAGraph(CUDAGraphID id) {
+    PADDLE_ENFORCE_EQ(strategy_, AllocatorStrategy::kAutoGrowth,
+                      platform::errors::InvalidArgument(
+                          "CUDA Graph is only supported when the "
+                          "FLAGS_allocator_strategy=\"auto_growth\", but got "
+                          "FLAGS_allocator_strategy=\"%s\"",
+                          FLAGS_allocator_strategy));
+    auto& allocator = cuda_graph_allocator_map_[id];
+    PADDLE_ENFORCE_EQ(
+        allocator.get(), nullptr,
+        platform::errors::InvalidArgument(
+            "The memory pool of the CUDA Graph with ID %d have been prepared.",
+            id));
+    allocator.reset(
+        new AllocatorFacadePrivate(/*allow_free_idle_chunk=*/false));
+    for (auto& item : allocator->allocators_) {
+      auto& old_allocator = item.second;
+      old_allocator = CUDAGraphAllocator::Create(old_allocator);
+    }
+    VLOG(10) << "Prepare memory pool for CUDA Graph with ID " << id;
+  }
+
+  void RemoveMemoryPoolOfCUDAGraph(CUDAGraphID id) {
+    auto iter = cuda_graph_allocator_map_.find(id);
+    PADDLE_ENFORCE_NE(iter, cuda_graph_allocator_map_.end(),
+                      platform::errors::InvalidArgument(
+                          "Cannot find CUDA Graph with ID = %d", id));
+    cuda_graph_allocator_map_.erase(iter);
+    VLOG(10) << "Remove memory pool of CUDA Graph with ID " << id;
+  }
+#endif
+
  private:
   AllocatorMap allocators_;
-  AllocatorMap zero_size_allocators_;
-  AllocatorMap system_allocators_;
+#ifdef PADDLE_WITH_CUDA
+  std::unordered_map<CUDAGraphID, std::unique_ptr<AllocatorFacadePrivate>>
+      cuda_graph_allocator_map_;
+#endif
+  AllocatorStrategy strategy_;
+
+  static AllocatorMap zero_size_allocators_;
+  static AllocatorMap system_allocators_;
 };

+AllocatorFacadePrivate::AllocatorMap
+    AllocatorFacadePrivate::zero_size_allocators_;
+AllocatorFacadePrivate::AllocatorMap AllocatorFacadePrivate::system_allocators_;
+
 // Pimpl. Make interface clean.
 AllocatorFacade::AllocatorFacade() : m_(new AllocatorFacadePrivate()) {}
 // delete m_ may cause core dump when the destructor of python in conflict with

@@ -316,6 +433,16 @@ const std::shared_ptr<Allocator>& AllocatorFacade::GetAllocator(
   return m_->GetAllocator(place, /* A non-zero num to choose allocator_ */ 1);
 }

+#ifdef PADDLE_WITH_CUDA
+void AllocatorFacade::PrepareMemoryPoolForCUDAGraph(CUDAGraphID id) {
+  return m_->PrepareMemoryPoolForCUDAGraph(id);
+}
+
+void AllocatorFacade::RemoveMemoryPoolOfCUDAGraph(CUDAGraphID id) {
+  return m_->RemoveMemoryPoolOfCUDAGraph(id);
+}
+#endif
+
 }  // namespace allocation
 }  // namespace memory
 }  // namespace paddle
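The CUDAGraphAllocator added above is a decorator: every allocation handed out during capture keeps both the underlying allocation and the wrapping allocator alive, so a graph's private memory pool cannot disappear while a replayable graph still references its pointers. Below is a stripped-down, Paddle-free sketch of that ownership pattern; Block and MallocAllocator are illustrative stand-ins for Allocation and Allocator.

#include <cstdlib>
#include <iostream>
#include <memory>

// Illustrative stand-ins for Allocation / Allocator.
struct Block {
  void *ptr = nullptr;
  size_t size = 0;
  // Pins whatever owns the memory for the lifetime of this block.
  std::shared_ptr<void> owner;
};

class MallocAllocator : public std::enable_shared_from_this<MallocAllocator> {
 public:
  Block Allocate(size_t size) {
    Block b;
    b.ptr = std::malloc(size);
    b.size = size;
    // The block pins the allocator itself, mirroring how PrivateAllocation
    // stores allocator_->shared_from_this() in the diff above.
    b.owner = shared_from_this();
    return b;
  }
  ~MallocAllocator() { std::cout << "pool destroyed\n"; }
};

int main() {
  Block b;
  {
    auto pool = std::make_shared<MallocAllocator>();
    b = pool->Allocate(64);
  }  // 'pool' goes out of scope here, but the block still keeps it alive.
  std::cout << "block of " << b.size << " bytes still valid\n";
  std::free(b.ptr);
  b.owner.reset();  // only now is the pool destroyed
  return 0;
}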
paddle/fluid/memory/allocation/allocator_facade.h

@@ -18,6 +18,9 @@
 #ifdef PADDLE_WITH_ASCEND_CL
 #include "paddle/fluid/memory/allocation/npu_pinned_allocator.h"
 #endif
+#ifdef PADDLE_WITH_CUDA
+#include "paddle/fluid/platform/gpu_info.h"
+#endif
 #include "paddle/fluid/platform/place.h"

 namespace paddle {

@@ -54,6 +57,11 @@ class AllocatorFacade {
   uint64_t Release(const platform::Place& place);

   const std::shared_ptr<Allocator>& GetAllocator(const platform::Place& place);

+#ifdef PADDLE_WITH_CUDA
+  void PrepareMemoryPoolForCUDAGraph(CUDAGraphID id);
+  void RemoveMemoryPoolOfCUDAGraph(CUDAGraphID id);
+#endif
+
   // TODO(yy): Allocate a Copy-On-Write allocation?
  private:
   AllocatorFacade();
paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.cc

@@ -39,11 +39,12 @@ namespace allocation {
 AutoGrowthBestFitAllocator::AutoGrowthBestFitAllocator(
     const std::shared_ptr<Allocator> &underlying_allocator, size_t alignment,
-    size_t chunk_size)
+    size_t chunk_size, bool allow_free_idle_chunk)
     : underlying_allocator_(
           std::make_shared<AlignedAllocator>(underlying_allocator, alignment)),
       alignment_(alignment),
-      chunk_size_(std::max(AlignedSize(chunk_size, alignment), alignment)) {}
+      chunk_size_(std::max(AlignedSize(chunk_size, alignment), alignment)),
+      allow_free_idle_chunk_(allow_free_idle_chunk) {}

 Allocation *AutoGrowthBestFitAllocator::AllocateImpl(size_t size) {
   size = AlignedSize(size, alignment_);

@@ -139,6 +140,9 @@ void AutoGrowthBestFitAllocator::FreeImpl(Allocation *allocation) {
 }

 uint64_t AutoGrowthBestFitAllocator::FreeIdleChunks() {
+  if (!allow_free_idle_chunk_) {
+    return 0;
+  }
   uint64_t bytes = 0;
   for (auto chunk_it = chunks_.begin(); chunk_it != chunks_.end();) {
     auto &blocks = chunk_it->blocks_;
paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.h

@@ -31,7 +31,7 @@ class AutoGrowthBestFitAllocator : public Allocator {
  public:
   AutoGrowthBestFitAllocator(
       const std::shared_ptr<Allocator> &underlying_allocator, size_t alignment,
-      size_t chunk_size = 0);
+      size_t chunk_size = 0, bool allow_free_idle_chunk = true);

   bool IsAllocThreadSafe() const override { return true; }

@@ -86,6 +86,7 @@ class AutoGrowthBestFitAllocator : public Allocator {
   std::list<Chunk> chunks_;
   size_t alignment_;
   size_t chunk_size_;
+  bool allow_free_idle_chunk_;

   SpinLock spinlock_;
 };
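The new allow_free_idle_chunk flag exists because a memory pool that backs a captured CUDA graph must not return chunks to the driver: the captured kernels bake in the old device addresses. The change above makes FreeIdleChunks() a no-op when the flag is false, and the facade builds the per-graph pool with the flag off. A tiny standalone sketch of that guard follows; ChunkPool is an illustrative stand-in, not the Paddle class.

#include <cstdint>
#include <iostream>
#include <list>

// Illustrative stand-in for a growing chunk pool.
class ChunkPool {
 public:
  explicit ChunkPool(bool allow_free_idle_chunk)
      : allow_free_idle_chunk_(allow_free_idle_chunk) {}

  void AddIdleChunk(uint64_t bytes) { idle_chunks_.push_back(bytes); }

  // Mirrors AutoGrowthBestFitAllocator::FreeIdleChunks(): bail out early when
  // the pool is not allowed to shrink (e.g. while a CUDA graph that references
  // these addresses is still alive).
  uint64_t FreeIdleChunks() {
    if (!allow_free_idle_chunk_) {
      return 0;
    }
    uint64_t bytes = 0;
    for (auto it = idle_chunks_.begin(); it != idle_chunks_.end();) {
      bytes += *it;
      it = idle_chunks_.erase(it);
    }
    return bytes;
  }

 private:
  bool allow_free_idle_chunk_;
  std::list<uint64_t> idle_chunks_;
};

int main() {
  ChunkPool normal(/*allow_free_idle_chunk=*/true);
  ChunkPool capture(/*allow_free_idle_chunk=*/false);
  normal.AddIdleChunk(1 << 20);
  capture.AddIdleChunk(1 << 20);
  std::cout << "normal pool released " << normal.FreeIdleChunks() << " bytes\n";
  std::cout << "capture pool released " << capture.FreeIdleChunks() << " bytes\n";
  return 0;
}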
paddle/fluid/operators/conv_cudnn_helper.h

@@ -24,6 +24,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/operator_kernel_configs.h"
 #include "paddle/fluid/operators/conv_cudnn_op_cache.h"
 #include "paddle/fluid/operators/eigen/eigen_function.h"
+#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
 #include "paddle/fluid/platform/cudnn_desc.h"
 namespace paddle {
 namespace operators {

@@ -480,6 +481,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
   static algo_t Find(const ConvArgs& args, bool exhaustive_search,
                      bool deterministic,
                      const framework::ExecutionContext& ctx) {
+    platform::CUDAGraphCaptureModeGuard guard;
     auto dtype = platform::CudnnDataType<T>::type;
     size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024;
     size_t workspace_size = 0;

@@ -601,6 +603,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
   }

   static size_t GetWorkspaceSize(const ConvArgs& args, algo_t algo) {
+    platform::CUDAGraphCaptureModeGuard guard;
     size_t workspace_size = 0;
     PADDLE_ENFORCE_CUDA_SUCCESS(
         platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
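CUDAGraphCaptureModeGuard is placed around the cuDNN algorithm search because the search allocates and frees workspace memory, which is normally illegal inside a stream capture. The CUDA runtime exposes cudaThreadExchangeStreamCaptureMode for exactly this case: temporarily relaxing the current thread's capture mode and restoring it afterwards. Below is a plausible, self-contained RAII guard along those lines; it sketches the idea and is not Paddle's actual implementation.

#include <cuda_runtime.h>
#include <cstdio>

// RAII guard: switch this thread to relaxed stream-capture mode so that work
// such as cudaMalloc/cudaFree during cuDNN algorithm search does not
// invalidate an ongoing capture; restore the previous mode on destruction.
class CaptureModeGuard {
 public:
  explicit CaptureModeGuard(
      cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed)
      : mode_(mode) {
    // Exchanges the thread's capture mode; the old mode is returned in mode_.
    cudaThreadExchangeStreamCaptureMode(&mode_);
  }
  ~CaptureModeGuard() {
    // Exchange back to the saved mode.
    cudaThreadExchangeStreamCaptureMode(&mode_);
  }
  CaptureModeGuard(const CaptureModeGuard &) = delete;
  CaptureModeGuard &operator=(const CaptureModeGuard &) = delete;

 private:
  cudaStreamCaptureMode mode_;
};

int main() {
  {
    CaptureModeGuard guard;  // relaxed mode while, e.g., searching conv algos
    void *workspace = nullptr;
    if (cudaMalloc(&workspace, 1 << 20) == cudaSuccess) {
      cudaFree(workspace);
    }
  }  // previous capture mode restored here
  std::printf("done\n");
  return 0;
}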
paddle/fluid/operators/fused/CMakeLists.txt

@@ -18,7 +18,8 @@ register_operators(EXCLUDES
         fused_bn_add_activation_op
         fused_attention_op
         fused_feedforward_op
-        fused_transformer_op)
+        fused_transformer_op
+        resnet_unit_op)

 # fusion_gru_op does not have CUDA kernel
 op_library(fusion_gru_op)

@@ -86,4 +87,11 @@ if (WITH_GPU OR WITH_ROCM)
     op_library(fused_attention_op)
     file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_attention);\n")
   endif()
+  # resnet_unit needs cudnn 8.0 above
+  if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 8000))
+    op_library(resnet_unit_op)
+    file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(resnet_unit);\n")
+    cc_test(test_cudnn_norm_conv SRCS cudnn_norm_conv_test.cc DEPS conv_op blas im2col vol2col depthwise_conv eigen_function tensor op_registry device_context generator memory)
+    cc_test(test_cudnn_bn_add_relu SRCS cudnn_bn_add_relu_test.cc DEPS batch_norm_op fused_bn_add_activation_op tensor op_registry device_context generator memory)
+  endif()
 endif()
paddle/fluid/operators/fused/cudnn_bn_add_relu_test.cc
0 → 100644
浏览文件 @
287ca7d5
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <random>
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h"
#include "paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/float16.h"
DECLARE_bool
(
cudnn_batchnorm_spatial_persistent
);
namespace
framework
=
paddle
::
framework
;
namespace
platform
=
paddle
::
platform
;
namespace
op
=
paddle
::
operators
;
using
Tensor
=
paddle
::
framework
::
Tensor
;
USE_OP
(
batch_norm
);
USE_CUDA_ONLY_OP
(
fused_bn_add_activation
);
USE_CUDA_ONLY_OP
(
fused_bn_add_activation_grad
);
template
<
typename
T
>
void
InitRandomTensor
(
const
std
::
vector
<
int64_t
>
&
dims
,
framework
::
Tensor
*
cpu_out
)
{
T
*
cpu_out_ptr
=
cpu_out
->
mutable_data
<
T
>
(
framework
::
make_ddim
(
dims
),
platform
::
CPUPlace
());
std
::
default_random_engine
random
(
0
);
std
::
uniform_real_distribution
<
float
>
dis
(
-
1.0
,
1.0
);
for
(
int
i
=
0
;
i
<
cpu_out
->
numel
();
++
i
)
{
cpu_out_ptr
[
i
]
=
static_cast
<
T
>
(
dis
(
random
));
}
}
template
<
typename
T
>
void
InitConstantTensor
(
const
std
::
vector
<
int64_t
>
&
dims
,
T
value
,
framework
::
Tensor
*
cpu_out
)
{
T
*
cpu_out_ptr
=
cpu_out
->
mutable_data
<
T
>
(
framework
::
make_ddim
(
dims
),
platform
::
CPUPlace
());
for
(
int
i
=
0
;
i
<
cpu_out
->
numel
();
++
i
)
{
cpu_out_ptr
[
i
]
=
value
;
}
}
template
<
typename
T
>
void
CheckOutput
(
std
::
string
name
,
const
framework
::
Tensor
&
cpu_res
,
const
framework
::
Tensor
&
cpu_base
,
float
diff
,
bool
is_relative_atol
=
false
)
{
if
(
cpu_res
.
dims
().
size
()
==
cpu_base
.
dims
().
size
())
{
EXPECT_EQ
(
cpu_res
.
dims
(),
cpu_base
.
dims
());
}
else
{
EXPECT_EQ
(
cpu_res
.
numel
(),
cpu_base
.
numel
());
}
const
T
*
cpu_res_ptr
=
cpu_res
.
data
<
T
>
();
const
T
*
cpu_base_ptr
=
cpu_base
.
data
<
T
>
();
float
max_diff
=
0
;
int
index
=
0
;
for
(
int
i
=
0
;
i
<
cpu_res
.
numel
();
++
i
)
{
float
cur_diff
;
if
(
is_relative_atol
)
{
cur_diff
=
static_cast
<
float
>
(
std
::
abs
((
cpu_res_ptr
[
i
]
-
cpu_base_ptr
[
i
])
/
cpu_base_ptr
[
i
]));
EXPECT_LT
(
static_cast
<
float
>
(
std
::
abs
((
cpu_res_ptr
[
i
]
-
cpu_base_ptr
[
i
])
/
cpu_base_ptr
[
i
])),
diff
);
}
else
{
cur_diff
=
static_cast
<
float
>
(
std
::
abs
(
cpu_res_ptr
[
i
]
-
cpu_base_ptr
[
i
]));
EXPECT_LT
(
static_cast
<
float
>
(
std
::
abs
(
cpu_res_ptr
[
i
]
-
cpu_base_ptr
[
i
])),
diff
);
}
if
(
cur_diff
>
max_diff
)
{
max_diff
=
cur_diff
;
index
=
i
;
}
}
std
::
string
error_type
=
is_relative_atol
?
"relative"
:
"absolute"
;
LOG
(
INFO
)
<<
"["
<<
name
<<
"] The dims is ["
<<
cpu_res
.
dims
()
<<
"], maximum "
<<
error_type
<<
" error is "
<<
max_diff
<<
": "
<<
cpu_res_ptr
[
index
]
<<
" vs "
<<
cpu_base_ptr
[
index
];
}
template
<
typename
T
>
void
ComputeSumAndSquareSum
(
const
framework
::
Tensor
&
cpu_x
,
framework
::
Tensor
*
cpu_sum
,
framework
::
Tensor
*
cpu_sum_of_square
)
{
// x is in NHWC format.
auto
dims
=
cpu_x
.
dims
();
int64_t
c
=
dims
[
3
];
const
T
*
cpu_x_ptr
=
cpu_x
.
data
<
T
>
();
float
*
cpu_sum_ptr
=
cpu_sum
->
mutable_data
<
float
>
({
1
,
1
,
1
,
c
},
platform
::
CPUPlace
());
float
*
cpu_sum_square_ptr
=
cpu_sum_of_square
->
mutable_data
<
float
>
(
{
1
,
1
,
1
,
c
},
platform
::
CPUPlace
());
for
(
int
j
=
0
;
j
<
c
;
++
j
)
{
float
tmp_sum
=
0.0
f
;
float
tmp_sum_of_squares
=
0.0
f
;
for
(
int
i
=
0
;
i
<
cpu_x
.
numel
()
/
c
;
++
i
)
{
float
tmp_x
=
static_cast
<
float
>
(
cpu_x_ptr
[
i
*
c
+
j
]);
tmp_sum
+=
tmp_x
;
tmp_sum_of_squares
+=
tmp_x
*
tmp_x
;
}
cpu_sum_ptr
[
j
]
=
tmp_sum
;
cpu_sum_square_ptr
[
j
]
=
tmp_sum_of_squares
;
}
}
template
<
typename
T
>
void
ComputeInplaceAdd
(
const
framework
::
Tensor
&
cpu_x
,
framework
::
Tensor
*
cpu_y
)
{
EXPECT_EQ
(
cpu_x
.
dims
(),
cpu_y
->
dims
());
const
T
*
cpu_x_ptr
=
cpu_x
.
data
<
T
>
();
T
*
cpu_y_ptr
=
cpu_y
->
data
<
T
>
();
for
(
int64_t
i
=
0
;
i
<
cpu_x
.
numel
();
++
i
)
{
cpu_y_ptr
[
i
]
+=
cpu_x_ptr
[
i
];
}
}
template
<
typename
T
>
void
ComputeInplaceRelu
(
framework
::
Tensor
*
cpu_x
)
{
T
*
cpu_x_ptr
=
cpu_x
->
data
<
T
>
();
for
(
int64_t
i
=
0
;
i
<
cpu_x
->
numel
();
++
i
)
{
cpu_x_ptr
[
i
]
=
cpu_x_ptr
[
i
]
>
static_cast
<
T
>
(
0
)
?
cpu_x_ptr
[
i
]
:
static_cast
<
T
>
(
0
);
}
}
void
ComputeBatchNormForward
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
Tensor
&
cpu_x
,
const
Tensor
&
cpu_scale
,
const
Tensor
&
cpu_bias
,
Tensor
*
cpu_mean
,
Tensor
*
cpu_var
,
Tensor
*
cpu_saved_mean
,
Tensor
*
cpu_saved_var
,
Tensor
*
cpu_y
,
Tensor
*
saved_reserve_space
)
{
framework
::
Scope
scope
;
auto
*
x
=
scope
.
Var
(
"X"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
scale
=
scope
.
Var
(
"Scale"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
bias
=
scope
.
Var
(
"Bias"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
mean
=
scope
.
Var
(
"Mean"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
var
=
scope
.
Var
(
"Variance"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
y
=
scope
.
Var
(
"Y"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
saved_mean
=
scope
.
Var
(
"SavedMean"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
saved_var
=
scope
.
Var
(
"SavedVariance"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
reserve_space
=
scope
.
Var
(
"ReserveSpace"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
place
=
ctx
.
GetPlace
();
TensorCopySync
(
cpu_x
,
place
,
x
);
TensorCopySync
(
cpu_scale
,
place
,
scale
);
TensorCopySync
(
cpu_bias
,
place
,
bias
);
TensorCopySync
(
*
cpu_mean
,
place
,
mean
);
TensorCopySync
(
*
cpu_var
,
place
,
var
);
int64_t
channels
=
x
->
dims
()[
3
];
scale
->
Resize
({
channels
});
bias
->
Resize
({
channels
});
mean
->
Resize
({
channels
});
var
->
Resize
({
channels
});
framework
::
AttributeMap
attrs
;
std
::
string
data_layout
=
"NHWC"
;
attrs
.
insert
({
"data_layout"
,
data_layout
});
auto
op
=
framework
::
OpRegistry
::
CreateOp
(
"batch_norm"
,
{{
"X"
,
{
"X"
}},
{
"Scale"
,
{
"Scale"
}},
{
"Bias"
,
{
"Bias"
}},
{
"Mean"
,
{
"Mean"
}},
{
"Variance"
,
{
"Variance"
}}},
{{
"Y"
,
{
"Y"
}},
{
"MeanOut"
,
{
"Mean"
}},
{
"VarianceOut"
,
{
"Variance"
}},
{
"SavedMean"
,
{
"SavedMean"
}},
{
"SavedVariance"
,
{
"SavedVariance"
}},
{
"ReserveSpace"
,
{
"ReserveSpace"
}}},
attrs
);
op
->
Run
(
scope
,
ctx
.
GetPlace
());
TensorCopySync
(
*
y
,
platform
::
CPUPlace
(),
cpu_y
);
TensorCopySync
(
*
mean
,
platform
::
CPUPlace
(),
cpu_mean
);
TensorCopySync
(
*
var
,
platform
::
CPUPlace
(),
cpu_var
);
TensorCopySync
(
*
saved_mean
,
platform
::
CPUPlace
(),
cpu_saved_mean
);
TensorCopySync
(
*
saved_var
,
platform
::
CPUPlace
(),
cpu_saved_var
);
// reserved_space will stay on GPU and used in grad op.
saved_reserve_space
->
ShareDataWith
(
*
reserve_space
);
}
void
ComputeFusedBNAddReluForward
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
Tensor
&
cpu_x
,
const
Tensor
&
cpu_z
,
const
Tensor
&
cpu_scale
,
const
Tensor
&
cpu_bias
,
Tensor
*
cpu_mean
,
Tensor
*
cpu_var
,
Tensor
*
cpu_saved_mean
,
Tensor
*
cpu_saved_var
,
Tensor
*
cpu_y
,
Tensor
*
saved_reserve_space
)
{
framework
::
Scope
scope
;
auto
*
x
=
scope
.
Var
(
"X"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
z
=
scope
.
Var
(
"Z"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
scale
=
scope
.
Var
(
"Scale"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
bias
=
scope
.
Var
(
"Bias"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
mean
=
scope
.
Var
(
"Mean"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
var
=
scope
.
Var
(
"Variance"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
y
=
scope
.
Var
(
"Y"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
saved_mean
=
scope
.
Var
(
"SavedMean"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
saved_var
=
scope
.
Var
(
"SavedVariance"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
reserve_space
=
scope
.
Var
(
"ReserveSpace"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
place
=
ctx
.
GetPlace
();
TensorCopySync
(
cpu_x
,
place
,
x
);
TensorCopySync
(
cpu_z
,
place
,
z
);
TensorCopySync
(
cpu_scale
,
place
,
scale
);
TensorCopySync
(
cpu_bias
,
place
,
bias
);
TensorCopySync
(
*
cpu_mean
,
place
,
mean
);
TensorCopySync
(
*
cpu_var
,
place
,
var
);
int64_t
channels
=
x
->
dims
()[
3
];
scale
->
Resize
({
channels
});
bias
->
Resize
({
channels
});
mean
->
Resize
({
channels
});
var
->
Resize
({
channels
});
framework
::
AttributeMap
attrs
;
auto
op
=
framework
::
OpRegistry
::
CreateOp
(
"fused_bn_add_activation"
,
{{
"X"
,
{
"X"
}},
{
"Z"
,
{
"Z"
}},
{
"Scale"
,
{
"Scale"
}},
{
"Bias"
,
{
"Bias"
}}},
{{
"Y"
,
{
"Y"
}},
{
"MeanOut"
,
{
"Mean"
}},
{
"VarianceOut"
,
{
"Variance"
}},
{
"SavedMean"
,
{
"SavedMean"
}},
{
"SavedVariance"
,
{
"SavedVariance"
}},
{
"ReserveSpace"
,
{
"ReserveSpace"
}}},
attrs
);
op
->
Run
(
scope
,
ctx
.
GetPlace
());
TensorCopySync
(
*
y
,
platform
::
CPUPlace
(),
cpu_y
);
TensorCopySync
(
*
mean
,
platform
::
CPUPlace
(),
cpu_mean
);
TensorCopySync
(
*
var
,
platform
::
CPUPlace
(),
cpu_var
);
TensorCopySync
(
*
saved_mean
,
platform
::
CPUPlace
(),
cpu_saved_mean
);
TensorCopySync
(
*
saved_var
,
platform
::
CPUPlace
(),
cpu_saved_var
);
// reserved_space will stay on GPU and used in grad op.
saved_reserve_space
->
ShareDataWith
(
*
reserve_space
);
}
void
ComputeFusedBNAddReluBackward
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
Tensor
&
cpu_dy
,
const
Tensor
&
cpu_x
,
const
Tensor
&
cpu_scale
,
const
Tensor
&
cpu_bias
,
const
Tensor
&
cpu_saved_mean
,
const
Tensor
&
cpu_saved_var
,
const
Tensor
&
cpu_y
,
const
Tensor
&
saved_reserve_space
,
Tensor
*
cpu_dx
,
Tensor
*
cpu_dz
,
Tensor
*
cpu_dscale
,
Tensor
*
cpu_dbias
)
{
framework
::
Scope
scope
;
auto
*
x
=
scope
.
Var
(
"X"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
y
=
scope
.
Var
(
"Y"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
dy
=
scope
.
Var
(
"Y@GRAD"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
scale
=
scope
.
Var
(
"Scale"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
bias
=
scope
.
Var
(
"Bias"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
saved_mean
=
scope
.
Var
(
"SavedMean"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
saved_var
=
scope
.
Var
(
"SavedVariance"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
reserve_space
=
scope
.
Var
(
"ReserveSpace"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
dx
=
scope
.
Var
(
"X@GRAD"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
dz
=
scope
.
Var
(
"Z@GRAD"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
dscale
=
scope
.
Var
(
"Scale@GRAD"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
dbias
=
scope
.
Var
(
"Bias@GRAD"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
place
=
ctx
.
GetPlace
();
TensorCopySync
(
cpu_x
,
place
,
x
);
TensorCopySync
(
cpu_y
,
place
,
y
);
TensorCopySync
(
cpu_dy
,
place
,
dy
);
TensorCopySync
(
cpu_scale
,
place
,
scale
);
TensorCopySync
(
cpu_bias
,
place
,
bias
);
TensorCopySync
(
cpu_saved_mean
,
place
,
saved_mean
);
TensorCopySync
(
cpu_saved_var
,
place
,
saved_var
);
reserve_space
->
ShareDataWith
(
saved_reserve_space
);
int64_t
channels
=
x
->
dims
()[
3
];
scale
->
Resize
({
channels
});
bias
->
Resize
({
channels
});
saved_mean
->
Resize
({
channels
});
saved_var
->
Resize
({
channels
});
framework
::
AttributeMap
attrs
;
float
momentum
=
0.9
;
float
epsilon
=
1e-5
;
std
::
string
act_type
=
"relu"
;
attrs
.
insert
({
"momentum"
,
momentum
});
attrs
.
insert
({
"epsilon"
,
epsilon
});
attrs
.
insert
({
"act_type"
,
act_type
});
auto
op
=
framework
::
OpRegistry
::
CreateOp
(
"fused_bn_add_activation_grad"
,
{{
"X"
,
{
"X"
}},
{
"Y"
,
{
"Y"
}},
{
"Y@GRAD"
,
{
"Y@GRAD"
}},
{
"Scale"
,
{
"Scale"
}},
{
"Bias"
,
{
"Bias"
}},
{
"SavedMean"
,
{
"SavedMean"
}},
{
"SavedVariance"
,
{
"SavedVariance"
}},
{
"ReserveSpace"
,
{
"ReserveSpace"
}}},
{{
"X@GRAD"
,
{
"X@GRAD"
}},
{
"Z@GRAD"
,
{
"Z@GRAD"
}},
{
"Scale@GRAD"
,
{
"Scale@GRAD"
}},
{
"Bias@GRAD"
,
{
"Bias@GRAD"
}}},
attrs
);
op
->
Run
(
scope
,
ctx
.
GetPlace
());
TensorCopySync
(
*
dx
,
platform
::
CPUPlace
(),
cpu_dx
);
TensorCopySync
(
*
dz
,
platform
::
CPUPlace
(),
cpu_dz
);
TensorCopySync
(
*
dscale
,
platform
::
CPUPlace
(),
cpu_dscale
);
TensorCopySync
(
*
dbias
,
platform
::
CPUPlace
(),
cpu_dbias
);
}
template <typename T>
class CudnnBNAddReluTester {
 public:
  CudnnBNAddReluTester(int batch_size, int height, int width, int channels,
                       std::string act_type, bool fuse_add, bool has_shortcut) {
    batch_size_ = batch_size;
    height_ = height;
    width_ = width;
    channels_ = channels;
    ele_count_ = batch_size_ * height_ * width_;
    act_type_ = act_type;
    fuse_add_ = fuse_add;
    has_shortcut_ = has_shortcut;
    SetUp();
  }

  ~CudnnBNAddReluTester() {}

  void CheckForward(float diff, bool is_relative_atol = false) {
    LOG(INFO) << "[CheckForward, diff=" << diff
              << ", is_relative_atol=" << is_relative_atol
              << "] act_type=" << act_type_ << ", fuse_add=" << fuse_add_
              << ", has_shortcut=" << has_shortcut_;
    platform::CUDADeviceContext *ctx =
        static_cast<platform::CUDADeviceContext *>(
            platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(0)));

    auto select = [&](Tensor *in) { return has_shortcut_ ? in : nullptr; };

    framework::Tensor cpu_mean_base_x;
    framework::Tensor cpu_var_base_x;
    framework::Tensor cpu_mean_base_z;
    framework::Tensor cpu_var_base_z;
    if (!has_shortcut_ && fuse_add_ && (act_type_ == "relu")) {
      BaselineForwardFusedBNAddRelu(
          *ctx, &cpu_mean_base_x, &cpu_var_base_x, &cpu_saved_mean_base_x_,
          &cpu_saved_var_base_x_, &cpu_y_base_, &saved_reserve_space_x_);
    } else {
      BaselineForward(
          *ctx, &cpu_mean_base_x, &cpu_var_base_x, &cpu_saved_mean_base_x_,
          &cpu_saved_var_base_x_, &cpu_y_base_, &saved_reserve_space_x_,
          select(&cpu_mean_base_z), select(&cpu_var_base_z),
          select(&cpu_saved_mean_base_z_), select(&cpu_saved_var_base_z_),
          select(&saved_reserve_space_z_));
    }

    framework::Tensor cpu_mean_x;
    framework::Tensor cpu_var_x;
    framework::Tensor cpu_y;
    framework::Tensor cpu_mean_z;
    framework::Tensor cpu_var_z;
    FusedForward(*ctx, &cpu_mean_x, &cpu_var_x, &cpu_saved_mean_x_,
                 &cpu_saved_var_x_, &cpu_y, &cpu_bitmask_, select(&cpu_mean_z),
                 select(&cpu_var_z), select(&cpu_saved_mean_z_),
                 select(&cpu_saved_var_z_));

    CheckOutput<float>("Mean", cpu_mean_x, cpu_mean_base_x, diff,
                       is_relative_atol);
    CheckOutput<float>("Variance", cpu_var_x, cpu_var_base_x, diff,
                       is_relative_atol);
    CheckOutput<float>("SavedMean", cpu_saved_mean_x_, cpu_saved_mean_base_x_,
                       diff, is_relative_atol);
    CheckOutput<float>("SavedVariance", cpu_saved_var_x_,
                       cpu_saved_var_base_x_, diff, is_relative_atol);
    if (has_shortcut_) {
      CheckOutput<float>("MeanZ", cpu_mean_z, cpu_mean_base_z, diff,
                         is_relative_atol);
      CheckOutput<float>("VarianceZ", cpu_var_z, cpu_var_base_z, diff,
                         is_relative_atol);
      CheckOutput<float>("SavedMeanZ", cpu_saved_mean_z_,
                         cpu_saved_mean_base_z_, diff, is_relative_atol);
      CheckOutput<float>("SavedVarianceZ", cpu_saved_var_z_,
                         cpu_saved_var_base_z_, diff, is_relative_atol);
    }
    CheckOutput<T>("Y", cpu_y, cpu_y_base_, diff, is_relative_atol);
  }

  void CheckBackward(float diff, bool is_relative_atol = false) {
    platform::CUDADeviceContext *ctx =
        static_cast<platform::CUDADeviceContext *>(
            platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(0)));

    framework::Tensor cpu_dx_base;
    framework::Tensor cpu_dz_base;
    framework::Tensor cpu_dscale_base;
    framework::Tensor cpu_dbias_base;
    BaselineBackwardFusedBNAddRelu(*ctx, &cpu_dx_base, &cpu_dz_base,
                                   &cpu_dscale_base, &cpu_dbias_base);

    framework::Tensor cpu_dx;
    framework::Tensor cpu_dz;
    framework::Tensor cpu_dscale;
    framework::Tensor cpu_dbias;
    FusedBackward(*ctx, &cpu_dx, &cpu_dz, &cpu_dscale, &cpu_dbias);

    CheckOutput<T>("DX", cpu_dx, cpu_dx_base, diff, is_relative_atol);
    CheckOutput<T>("DZ", cpu_dz, cpu_dz_base, diff, is_relative_atol);
    CheckOutput<float>("DScale", cpu_dscale, cpu_dscale_base, diff,
                       is_relative_atol);
    CheckOutput<float>("DBias", cpu_dbias, cpu_dbias_base, diff,
                       is_relative_atol);
  }

 private:
  void SetUp() {
    InitRandomTensor<T>({batch_size_, height_, width_, channels_}, &cpu_x_);
    InitRandomTensor<float>({channels_}, &cpu_bn_scale_x_);
    InitRandomTensor<float>({channels_}, &cpu_bn_bias_x_);

    if (has_shortcut_) {
      InitRandomTensor<T>({batch_size_, height_, width_, channels_}, &cpu_z_);
      InitRandomTensor<float>({channels_}, &cpu_bn_scale_z_);
      InitRandomTensor<float>({channels_}, &cpu_bn_bias_z_);
    } else {
      if (fuse_add_) {
        InitRandomTensor<T>({batch_size_, height_, width_, channels_}, &cpu_z_);
      }
    }

    InitRandomTensor<T>({batch_size_, height_, width_, channels_}, &cpu_dy_);
  }

  void InitMeanVar(Tensor *cpu_mean, Tensor *cpu_var, Tensor *cpu_saved_mean,
                   Tensor *cpu_saved_var) {
    InitConstantTensor<float>({channels_}, static_cast<float>(0.0f), cpu_mean);
    InitConstantTensor<float>({channels_}, static_cast<float>(1.0f), cpu_var);
    InitConstantTensor<float>({channels_}, static_cast<float>(0.0f),
                              cpu_saved_mean);
    InitConstantTensor<float>({channels_}, static_cast<float>(0.0f),
                              cpu_saved_var);
  }

  void BaselineForward(const platform::CUDADeviceContext &ctx,
                       Tensor *cpu_mean_x, Tensor *cpu_var_x,
                       Tensor *cpu_saved_mean_x, Tensor *cpu_saved_var_x,
                       Tensor *cpu_y, Tensor *saved_reserve_space_x,
                       Tensor *cpu_mean_z = nullptr,
                       Tensor *cpu_var_z = nullptr,
                       Tensor *cpu_saved_mean_z = nullptr,
                       Tensor *cpu_saved_var_z = nullptr,
                       Tensor *saved_reserve_space_z = nullptr) {
    InitMeanVar(cpu_mean_x, cpu_var_x, cpu_saved_mean_x, cpu_saved_var_x);
    ComputeBatchNormForward(ctx, cpu_x_, cpu_bn_scale_x_, cpu_bn_bias_x_,
                            cpu_mean_x, cpu_var_x, cpu_saved_mean_x,
                            cpu_saved_var_x, cpu_y, saved_reserve_space_x);
    if (has_shortcut_) {
      framework::Tensor cpu_z_out;
      InitMeanVar(cpu_mean_z, cpu_var_z, cpu_saved_mean_z, cpu_saved_var_z);
      ComputeBatchNormForward(ctx, cpu_z_, cpu_bn_scale_z_, cpu_bn_bias_z_,
                              cpu_mean_z, cpu_var_z, cpu_saved_mean_z,
                              cpu_saved_var_z, &cpu_z_out,
                              saved_reserve_space_z);
      ComputeInplaceAdd<T>(cpu_z_out, cpu_y);
    } else {
      if (fuse_add_) {
        ComputeInplaceAdd<T>(cpu_z_, cpu_y);
      }
    }
    if (act_type_ == "relu") {
      ComputeInplaceRelu<T>(cpu_y);
    }
  }

  void BaselineForwardFusedBNAddRelu(const platform::CUDADeviceContext &ctx,
                                     Tensor *cpu_mean, Tensor *cpu_var,
                                     Tensor *cpu_saved_mean,
                                     Tensor *cpu_saved_var, Tensor *cpu_y,
                                     Tensor *saved_reserve_space) {
    InitMeanVar(cpu_mean, cpu_var, cpu_saved_mean, cpu_saved_var);
    ComputeFusedBNAddReluForward(ctx, cpu_x_, cpu_z_, cpu_bn_scale_x_,
                                 cpu_bn_bias_x_, cpu_mean, cpu_var,
                                 cpu_saved_mean, cpu_saved_var, cpu_y,
                                 saved_reserve_space);
  }

  void BaselineBackwardFusedBNAddRelu(const platform::CUDADeviceContext &ctx,
                                      Tensor *cpu_dx, Tensor *cpu_dz,
                                      Tensor *cpu_dscale, Tensor *cpu_dbias) {
    ComputeFusedBNAddReluBackward(
        ctx, cpu_dy_, cpu_x_, cpu_bn_scale_x_, cpu_bn_bias_x_,
        cpu_saved_mean_base_x_, cpu_saved_var_base_x_, cpu_y_base_,
        saved_reserve_space_x_, cpu_dx, cpu_dz, cpu_dscale, cpu_dbias);
  }

  void ComputeFusedBNStatsFinalize(const platform::CUDADeviceContext &ctx,
                                   const Tensor &cpu_x,
                                   const Tensor &cpu_bn_scale,
                                   const Tensor &cpu_bn_bias, Tensor *sum,
                                   Tensor *sum_of_square, Tensor *bn_scale,
                                   Tensor *bn_bias, Tensor *mean, Tensor *var,
                                   Tensor *saved_mean, Tensor *saved_var,
                                   Tensor *equiv_scale, Tensor *equiv_bias) {
    framework::Tensor cpu_sum;
    framework::Tensor cpu_sum_of_square;
    ComputeSumAndSquareSum<T>(cpu_x, &cpu_sum, &cpu_sum_of_square);

    auto place = ctx.GetPlace();
    TensorCopySync(cpu_sum, place, sum);
    TensorCopySync(cpu_sum_of_square, place, sum_of_square);
    TensorCopySync(cpu_bn_scale, place, bn_scale);
    TensorCopySync(cpu_bn_bias, place, bn_bias);

    bn_scale->Resize({1, 1, 1, channels_});
    bn_bias->Resize({1, 1, 1, channels_});

    // input
    mean->Resize({1, 1, 1, channels_});
    var->Resize({1, 1, 1, channels_});

    // output
    equiv_scale->Resize({1, 1, 1, channels_});
    equiv_bias->Resize({1, 1, 1, channels_});
    saved_mean->Resize({1, 1, 1, channels_});
    saved_var->Resize({1, 1, 1, channels_});

    auto param_shape = framework::vectorize<int>(bn_scale->dims());
    op::CudnnBNStatsFinalize<T> bn_op(ctx, param_shape);
    bn_op.Forward(ctx, *sum, *sum_of_square, *bn_scale, *bn_bias, saved_mean,
                  saved_var, mean, var, equiv_scale, equiv_bias, eps_,
                  momentum_, ele_count_, true);
  }

  // Get forward results of CudnnBNStatsFinalize + CudnnScaleBiasAddRelu
  void FusedForward(const platform::CUDADeviceContext &ctx, Tensor *cpu_mean_x,
                    Tensor *cpu_var_x, Tensor *cpu_saved_mean_x,
                    Tensor *cpu_saved_var_x, Tensor *cpu_y, Tensor *cpu_bitmask,
                    Tensor *cpu_mean_z = nullptr, Tensor *cpu_var_z = nullptr,
                    Tensor *cpu_saved_mean_z = nullptr,
                    Tensor *cpu_saved_var_z = nullptr) {
    framework::Tensor x;
    framework::Tensor sum_x;
    framework::Tensor sum_of_square_x;
    framework::Tensor bn_scale_x;
    framework::Tensor bn_bias_x;

    framework::Tensor z;
    framework::Tensor sum_z;
    framework::Tensor sum_of_square_z;
    framework::Tensor bn_scale_z;
    framework::Tensor bn_bias_z;

    auto place = ctx.GetPlace();
    TensorCopySync(cpu_x_, place, &x);
    if (fuse_add_ || has_shortcut_) {
      TensorCopySync(cpu_z_, place, &z);
    }

    framework::Tensor mean_x;
    framework::Tensor var_x;
    framework::Tensor saved_mean_x;
    framework::Tensor saved_var_x;
    framework::Tensor equiv_scale_x;
    framework::Tensor equiv_bias_x;

    framework::Tensor mean_z;
    framework::Tensor var_z;
    framework::Tensor saved_mean_z;
    framework::Tensor saved_var_z;
    framework::Tensor equiv_scale_z;
    framework::Tensor equiv_bias_z;

    framework::Tensor y;
    framework::Tensor bitmask;

    InitMeanVar(cpu_mean_x, cpu_var_x, cpu_saved_mean_x, cpu_saved_var_x);
    TensorCopySync(*cpu_mean_x, place, &mean_x);
    TensorCopySync(*cpu_var_x, place, &var_x);
    if (has_shortcut_) {
      InitMeanVar(cpu_mean_z, cpu_var_z, cpu_saved_mean_z, cpu_saved_var_z);
      TensorCopySync(*cpu_mean_z, place, &mean_z);
      TensorCopySync(*cpu_var_z, place, &var_z);
    }

    // 1. BN Stats Finalize
    ComputeFusedBNStatsFinalize(ctx, cpu_x_, cpu_bn_scale_x_, cpu_bn_bias_x_,
                                &sum_x, &sum_of_square_x, &bn_scale_x,
                                &bn_bias_x, &mean_x, &var_x, &saved_mean_x,
                                &saved_var_x, &equiv_scale_x, &equiv_bias_x);
    if (has_shortcut_) {
      ComputeFusedBNStatsFinalize(ctx, cpu_z_, cpu_bn_scale_z_, cpu_bn_bias_z_,
                                  &sum_z, &sum_of_square_z, &bn_scale_z,
                                  &bn_bias_z, &mean_z, &var_z, &saved_mean_z,
                                  &saved_var_z, &equiv_scale_z, &equiv_bias_z);
    }

    y.Resize(framework::make_ddim({batch_size_, height_, width_, channels_}));

    int c = channels_;
    int64_t nhw = ele_count_;
    int32_t c_int32_elems = ((c + 63) & ~63) / 32;
    int32_t nhw_int32_elems = (nhw + 31) & ~31;
    bitmask.Resize(framework::make_ddim({nhw_int32_elems, c_int32_elems, 1}));

    auto data_shape = framework::vectorize<int>(x.dims());
    auto param_shape = framework::vectorize<int>(bn_scale_x.dims());
    auto bitmask_shape = framework::vectorize<int>(bitmask.dims());

    // 2. Scale Bias + Relu
    op::CudnnScaleBiasAddRelu<T> sbar_op(ctx, act_type_, fuse_add_,
                                         has_shortcut_, data_shape,
                                         param_shape, bitmask_shape);
    sbar_op.Forward(ctx, x, equiv_scale_x, equiv_bias_x, &z, &equiv_scale_z,
                    &equiv_bias_z, &y, &bitmask);

    TensorCopySync(mean_x, platform::CPUPlace(), cpu_mean_x);
    TensorCopySync(var_x, platform::CPUPlace(), cpu_var_x);
    TensorCopySync(saved_mean_x, platform::CPUPlace(), cpu_saved_mean_x);
    TensorCopySync(saved_var_x, platform::CPUPlace(), cpu_saved_var_x);
    if (has_shortcut_) {
      TensorCopySync(mean_z, platform::CPUPlace(), cpu_mean_z);
      TensorCopySync(var_z, platform::CPUPlace(), cpu_var_z);
      TensorCopySync(saved_mean_z, platform::CPUPlace(), cpu_saved_mean_z);
      TensorCopySync(saved_var_z, platform::CPUPlace(), cpu_saved_var_z);
    }
    TensorCopySync(y, platform::CPUPlace(), cpu_y);
    TensorCopySync(bitmask, platform::CPUPlace(), cpu_bitmask);
  }

  // Get backward results of CudnnBNStatsFinalize + CudnnScaleBiasAddRelu
  void FusedBackward(const platform::CUDADeviceContext &ctx, Tensor *cpu_dx,
                     Tensor *cpu_dz, Tensor *cpu_dscale, Tensor *cpu_dbias) {
    framework::Tensor dy;
    framework::Tensor x;
    framework::Tensor bn_scale;
    framework::Tensor bn_bias;
    framework::Tensor saved_mean;
    framework::Tensor saved_var;
    framework::Tensor bitmask;
    framework::Tensor dx;
    framework::Tensor dz;
    framework::Tensor dscale;
    framework::Tensor dbias;

    auto place = ctx.GetPlace();
    TensorCopySync(cpu_dy_, place, &dy);
    TensorCopySync(cpu_x_, place, &x);
    TensorCopySync(cpu_bn_scale_x_, place, &bn_scale);
    TensorCopySync(cpu_bn_bias_x_, place, &bn_bias);
    TensorCopySync(cpu_saved_mean_x_, place, &saved_mean);
    TensorCopySync(cpu_saved_var_x_, place, &saved_var);
    TensorCopySync(cpu_bitmask_, place, &bitmask);

    bn_scale.Resize({1, 1, 1, channels_});
    bn_bias.Resize({1, 1, 1, channels_});
    saved_mean.Resize({1, 1, 1, channels_});
    saved_var.Resize({1, 1, 1, channels_});

    dx.Resize(framework::make_ddim({batch_size_, height_, width_, channels_}));
    dz.Resize(framework::make_ddim({batch_size_, height_, width_, channels_}));
    dscale.Resize(framework::make_ddim({1, 1, 1, channels_}));
    dbias.Resize(framework::make_ddim({1, 1, 1, channels_}));

    auto data_shape = framework::vectorize<int>(x.dims());
    auto param_shape = framework::vectorize<int>(bn_scale.dims());
    auto bitmask_shape = framework::vectorize<int>(bitmask.dims());

    std::string act_type = "relu";
    op::CudnnScaleBiasAddRelu<T> sbar_op(ctx, act_type, true, false,
                                         data_shape, param_shape,
                                         bitmask_shape);
    sbar_op.Backward(ctx, dy, x, bn_scale, bn_bias, saved_mean, saved_var,
                     &bitmask, &dx, &dz, &dscale, &dbias, eps_);

    TensorCopySync(dx, platform::CPUPlace(), cpu_dx);
    TensorCopySync(dz, platform::CPUPlace(), cpu_dz);
    TensorCopySync(dscale, platform::CPUPlace(), cpu_dscale);
    TensorCopySync(dbias, platform::CPUPlace(), cpu_dbias);
  }

 private:
  int batch_size_;
  int height_;
  int width_;
  int channels_;
  int ele_count_;

  std::string act_type_;
  bool fuse_add_;
  bool has_shortcut_;

  // Forward input
  framework::Tensor cpu_x_;
  framework::Tensor cpu_bn_scale_x_;
  framework::Tensor cpu_bn_bias_x_;
  framework::Tensor cpu_z_;
  framework::Tensor cpu_bn_scale_z_;
  framework::Tensor cpu_bn_bias_z_;

  // Backward input
  framework::Tensor cpu_dy_;
  framework::Tensor cpu_bitmask_;

  framework::Tensor cpu_saved_mean_x_;
  framework::Tensor cpu_saved_var_x_;
  framework::Tensor cpu_saved_mean_z_;
  framework::Tensor cpu_saved_var_z_;

  framework::Tensor cpu_saved_mean_base_x_;
  framework::Tensor cpu_saved_var_base_x_;
  framework::Tensor saved_reserve_space_x_;
  framework::Tensor cpu_saved_mean_base_z_;
  framework::Tensor cpu_saved_var_base_z_;
  framework::Tensor saved_reserve_space_z_;
  framework::Tensor cpu_y_base_;

  double eps_ = 1e-5;
  float momentum_ = 0.9;
};
TEST(CudnnBNAddReluFp16, BNAdd) {
  int batch_size = 4;
  int height = 8;
  int width = 8;
  int channels = 64;
  std::string act_type = "";
  bool has_shortcut = false;
  FLAGS_cudnn_batchnorm_spatial_persistent = true;
  for (auto fuse_add : {false, true}) {
    CudnnBNAddReluTester<paddle::platform::float16> test(
        batch_size, height, width, channels, act_type, fuse_add, has_shortcut);
    test.CheckForward(2e-3);
  }
}

TEST(CudnnBNAddReluFp16, BNAddRelu) {
  int batch_size = 4;
  int height = 8;
  int width = 8;
  int channels = 64;
  std::string act_type = "relu";
  bool has_shortcut = false;
  FLAGS_cudnn_batchnorm_spatial_persistent = true;
  for (auto fuse_add : {false, true}) {
    CudnnBNAddReluTester<paddle::platform::float16> test(
        batch_size, height, width, channels, act_type, fuse_add, has_shortcut);
    test.CheckForward(2e-3);
    if (fuse_add) {
      test.CheckBackward(2e-4);
    }
  }
}

TEST(CudnnBNAddReluFp16, HasShortcut) {
  int batch_size = 4;
  int height = 8;
  int width = 8;
  int channels = 64;
  std::string act_type = "";
  bool fuse_add = false;
  bool has_shortcut = true;
  FLAGS_cudnn_batchnorm_spatial_persistent = true;
  CudnnBNAddReluTester<paddle::platform::float16> test(
      batch_size, height, width, channels, act_type, fuse_add, has_shortcut);
  test.CheckForward(5e-3);
}
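The bitmask sizing used in FusedForward above is plain rounding arithmetic: the channel count is padded to a multiple of 64 and packed 32 flags per int32, while the NHW extent is padded to a multiple of 32. A standalone sketch (not part of the committed test; dimensions borrowed from the BNAddRelu case, N=4, H=8, W=8, C=64) shows the values it produces:

#include <cassert>
#include <cstdint>

int main() {
  int c = 64;
  int64_t nhw = 4 * 8 * 8;                        // ele_count_ = N * H * W = 256
  int32_t c_int32_elems = ((c + 63) & ~63) / 32;  // 64 padded channels / 32 flags = 2
  int32_t nhw_int32_elems = (nhw + 31) & ~31;     // 256 is already a multiple of 32
  assert(c_int32_elems == 2);
  assert(nhw_int32_elems == 256);
  // The fused op therefore works with a bitmask of shape {256, 2, 1}.
  return 0;
}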
paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h (new file, mode 100644)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/operators/fused/cudnn_fusion_helper.h"
#include "paddle/fluid/platform/cudnn_desc.h"
#include "paddle/fluid/platform/cudnn_helper.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
namespace dynload = platform::dynload;
template <typename T>
using BatchNormParamType =
    typename platform::CudnnDataType<T>::BatchNormParamType;

#if CUDNN_VERSION >= 8000

template <typename T>
struct BNStatsFinalizeArgs {
  BNStatsFinalizeArgs() {
    dtype = platform::CudnnDataType<T>::type;
    param_dtype = platform::CudnnDataType<BatchNormParamType<T>>::type;
    format = CUDNN_TENSOR_NHWC;
  }

  void Set(const std::vector<int> &param_shape) {
    PADDLE_ENFORCE_EQ(
        param_shape.size(), 4U,
        platform::errors::InvalidArgument(
            "The size of param_shape is expected to be 4. But received "
            "param_shape's size is %d, param_shape is [%s].",
            param_shape.size(), framework::make_ddim(param_shape)));

    in_desc.set(param_shape, format, param_dtype);
    out_desc.set(param_shape, format, dtype);
  }

  cudnnDataType_t dtype;
  cudnnDataType_t param_dtype;
  cudnnTensorFormat_t format;

  platform::TensorDescriptor in_desc;
  platform::TensorDescriptor out_desc;
};

template <typename T>
class CudnnBNStatsFinalize {
 public:
  CudnnBNStatsFinalize(const platform::CUDADeviceContext &ctx,
                       const std::vector<int> &param_shape)
      : train_op_(CUDNN_FUSED_BN_FINALIZE_STATISTICS_TRAINING),
        inference_op_(CUDNN_FUSED_BN_FINALIZE_STATISTICS_INFERENCE) {
    args_.Set(param_shape);
  }
  ~CudnnBNStatsFinalize() {}

  void Forward(const platform::CUDADeviceContext &ctx, const Tensor &sum,
               const Tensor &sum_of_squares, const Tensor &scale,
               const Tensor &bias, Tensor *saved_mean, Tensor *saved_invstd,
               Tensor *running_mean, Tensor *running_var, Tensor *equiv_scale,
               Tensor *equiv_bias, double eps, float momentum,
               int64_t ele_count, bool is_train) {
    auto place = ctx.GetPlace();
    if (is_train) {
      TrainInit(ctx);
    } else {
      InferenceInit(ctx);
    }
    auto &op = is_train ? train_op_ : inference_op_;

    // Set variant_param for both inference_op_ and train_op_
    float *sum_ptr = const_cast<float *>(sum.data<float>());
    float *sum_of_squares_ptr =
        const_cast<float *>(sum_of_squares.data<float>());
    float *scale_ptr = const_cast<float *>(scale.data<float>());
    float *bias_ptr = const_cast<float *>(bias.data<float>());
    float *saved_mean_ptr = saved_mean->mutable_data<float>(place);
    float *saved_invstd_ptr = saved_invstd->mutable_data<float>(place);
    float *running_mean_ptr = running_mean->mutable_data<float>(place);
    float *running_var_ptr = running_var->mutable_data<float>(place);
    T *equiv_scale_ptr = equiv_scale->mutable_data<T>(place);
    T *equiv_bias_ptr = equiv_bias->mutable_data<T>(place);

    op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SCALE, scale_ptr);
    op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_BIAS, bias_ptr);
    op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_RUNNING_MEAN, running_mean_ptr);
    op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_RUNNING_VAR, running_var_ptr);
    op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_EQSCALE, equiv_scale_ptr);
    op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_EQBIAS, equiv_bias_ptr);
    op.SetOpVariantParamAttrPtr<double>(CUDNN_SCALAR_DOUBLE_BN_EPSILON, &eps);

    // Set extra variant_param only for train_op_:
    if (is_train) {
      op.SetOpVariantParamAttrPtr(CUDNN_PTR_YSUM, sum_ptr);
      op.SetOpVariantParamAttrPtr(CUDNN_PTR_YSQSUM, sum_of_squares_ptr);
      op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SAVED_MEAN, saved_mean_ptr);
      op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SAVED_INVSTD, saved_invstd_ptr);
      double avg_factor = 1.0 - momentum;
      op.SetOpVariantParamAttrPtr(CUDNN_SCALAR_INT64_T_BN_ACCUMULATION_COUNT,
                                  &ele_count);
      op.SetOpVariantParamAttrPtr(CUDNN_SCALAR_DOUBLE_BN_EXP_AVG_FACTOR,
                                  &avg_factor);
    }
    // fused op execute
    auto handle = ctx.cudnn_handle();
    op.Execute(handle);
  }

 private:
  void TrainInit(const platform::CUDADeviceContext &ctx) {
    // Set constant_param for train op
    train_op_.SetOpConstParamAttr(
        {CUDNN_PARAM_YSUM_PLACEHOLDER, CUDNN_PARAM_YSQSUM_PLACEHOLDER,
         CUDNN_PARAM_BN_SCALE_PLACEHOLDER, CUDNN_PARAM_BN_BIAS_PLACEHOLDER,
         CUDNN_PARAM_BN_SAVED_MEAN_PLACEHOLDER,
         CUDNN_PARAM_BN_SAVED_INVSTD_PLACEHOLDER,
         CUDNN_PARAM_BN_RUNNING_MEAN_PLACEHOLDER,
         CUDNN_PARAM_BN_RUNNING_VAR_PLACEHOLDER,
         CUDNN_PARAM_BN_EQSCALE_PLACEHOLDER, CUDNN_PARAM_BN_EQBIAS_PLACEHOLDER},
        CUDNN_PTR_16B_ALIGNED);
    // Set input and output desc for train op
    train_op_.SetOpConstParamDesc(
        {CUDNN_PARAM_YSTATS_DESC, CUDNN_PARAM_BN_SCALEBIAS_MEANVAR_DESC},
        args_.in_desc.desc());
    train_op_.SetOpConstParamDesc(CUDNN_PARAM_BN_EQSCALEBIAS_DESC,
                                  args_.out_desc.desc());

    // Get workspace
    auto handle = ctx.cudnn_handle();
    train_op_.SetOpConstParamAttr(CUDNN_PARAM_BN_MODE,
                                  CUDNN_BATCHNORM_SPATIAL_PERSISTENT);
    // Check workspace size, also creates plan.
    size_t workspace_size_bytes = train_op_.GetWorkspaceSizeInBytes(handle);
    PADDLE_ENFORCE_EQ(workspace_size_bytes, 0U,
                      platform::errors::InvalidArgument(
                          "Unexpected non-zero workspace size for "
                          "CudnnBNStatsFinalize."));
    train_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE,
                                       static_cast<void *>(nullptr));
    train_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE,
                                       &workspace_size_bytes);
  }

  void InferenceInit(const platform::CUDADeviceContext &ctx) {
    // Set constant_param for inference op
    inference_op_.SetOpConstParamAttr(
        {CUDNN_PARAM_BN_SCALE_PLACEHOLDER, CUDNN_PARAM_BN_BIAS_PLACEHOLDER,
         CUDNN_PARAM_BN_RUNNING_MEAN_PLACEHOLDER,
         CUDNN_PARAM_BN_RUNNING_VAR_PLACEHOLDER,
         CUDNN_PARAM_BN_EQSCALE_PLACEHOLDER, CUDNN_PARAM_BN_EQBIAS_PLACEHOLDER},
        CUDNN_PTR_16B_ALIGNED);
    // Set input and output desc for inference op
    inference_op_.SetOpConstParamDesc(CUDNN_PARAM_BN_SCALEBIAS_MEANVAR_DESC,
                                      args_.in_desc.desc());
    inference_op_.SetOpConstParamDesc(CUDNN_PARAM_BN_EQSCALEBIAS_DESC,
                                      args_.out_desc.desc());

    // Get workspace
    auto handle = ctx.cudnn_handle();
    inference_op_.SetOpConstParamAttr(CUDNN_PARAM_BN_MODE,
                                      CUDNN_BATCHNORM_SPATIAL_PERSISTENT);
    // Check workspace size, also creates plan.
    size_t workspace_size_bytes = inference_op_.GetWorkspaceSizeInBytes(handle);
    PADDLE_ENFORCE_EQ(workspace_size_bytes, 0U,
                      platform::errors::InvalidArgument(
                          "Unexpected non-zero workspace size for "
                          "CudnnBNStatsFinalize."));
    inference_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE,
                                           static_cast<void *>(nullptr));
    inference_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE,
                                           &workspace_size_bytes);
  }

  BNStatsFinalizeArgs<T> args_;
  CudnnFusionOp train_op_;
  CudnnFusionOp inference_op_;
};
#endif
}  // namespace operators
}  // namespace paddle
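A minimal usage sketch of the helper above, mirroring ComputeFusedBNStatsFinalize in the test earlier in this commit. The wrapper function name is illustrative only, and the tensors are assumed to already live on the GPU with shape {1, 1, 1, C} for the statistics, scale, and bias, exactly as the test prepares them.

#include "paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h"

namespace paddle {
namespace operators {
#if CUDNN_VERSION >= 8000

// Illustrative wrapper only: turns per-channel sum / sum-of-squares into the
// equivalent scale/bias pair consumed by the scale-bias-add-relu stage.
template <typename T>
void RunBNStatsFinalizeSketch(const platform::CUDADeviceContext &ctx,
                              const Tensor &sum, const Tensor &sum_of_squares,
                              const Tensor &scale, const Tensor &bias,
                              Tensor *saved_mean, Tensor *saved_invstd,
                              Tensor *running_mean, Tensor *running_var,
                              Tensor *equiv_scale, Tensor *equiv_bias,
                              int64_t ele_count) {
  auto param_shape = framework::vectorize<int>(scale.dims());  // {1, 1, 1, C}
  CudnnBNStatsFinalize<T> bn_op(ctx, param_shape);
  // eps and momentum follow the values used by the test; is_train = true
  // selects the training plan, so running and saved statistics are updated.
  bn_op.Forward(ctx, sum, sum_of_squares, scale, bias, saved_mean,
                saved_invstd, running_mean, running_var, equiv_scale,
                equiv_bias, /*eps=*/1e-5, /*momentum=*/0.9f, ele_count,
                /*is_train=*/true);
}

#endif
}  // namespace operators
}  // namespace paddle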
paddle/fluid/operators/fused/cudnn_fusion_helper.h (new file, mode 100644)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <vector>

#include "paddle/fluid/framework/operator_kernel_configs.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace operators {

namespace dynload = platform::dynload;

#if CUDNN_VERSION >= 8000

// A wrapper for the cuDNN fused_op API.
class CudnnFusionOp {
 public:
  explicit CudnnFusionOp(cudnnFusedOps_t op_id) : plan_created_(false) {
    // New 'fused op' descriptor creation
    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnCreateFusedOpsPlan(&op_, op_id));
    PADDLE_ENFORCE_CUDA_SUCCESS(
        dynload::cudnnCreateFusedOpsConstParamPack(&op_const_params_, op_id));
    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnCreateFusedOpsVariantParamPack(
        &op_variant_params_, op_id));
  }

  ~CudnnFusionOp() PADDLE_MAY_THROW {
    PADDLE_ENFORCE_CUDA_SUCCESS(
        dynload::cudnnDestroyFusedOpsVariantParamPack(op_variant_params_));
    PADDLE_ENFORCE_CUDA_SUCCESS(
        dynload::cudnnDestroyFusedOpsConstParamPack(op_const_params_));
    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnDestroyFusedOpsPlan(op_));
  }

  // Execute the fused op
  void Execute(cudnnHandle_t cudnn_handle) {
    PADDLE_ENFORCE_EQ(
        plan_created_, true,
        platform::errors::Fatal(
            "CudnnFusionOp exec requested without a valid 'plan', need: "
            "<set const params>, GetWorkspaceSizeInBytes(), Execute()."));
    PADDLE_ENFORCE_CUDA_SUCCESS(
        dynload::cudnnFusedOpsExecute(cudnn_handle, op_, op_variant_params_));
  }

  // Set a const param pack attribute given a descriptor.
  template <typename T>
  void SetOpConstParamDesc(cudnnFusedOpsConstParamLabel_t param_label,
                           T *param_ptr) {
    PADDLE_ENFORCE_CUDA_SUCCESS(
        dynload::cudnnSetFusedOpsConstParamPackAttribute(
            op_const_params_, param_label, param_ptr));
    plan_created_ = false;
  }

  // Set multiple const param pack attributes given a descriptor.
  template <typename T>
  void SetOpConstParamDesc(
      const std::vector<cudnnFusedOpsConstParamLabel_t> &param_labels,
      T *param_ptr) {
    for (auto param_label : param_labels) {
      SetOpConstParamDesc(param_label, param_ptr);
    }
  }

  // Set a const param pack attribute given a value of the param.
  template <typename T>
  void SetOpConstParamAttr(cudnnFusedOpsConstParamLabel_t param_label,
                           T param) {
    PADDLE_ENFORCE_CUDA_SUCCESS(
        dynload::cudnnSetFusedOpsConstParamPackAttribute(op_const_params_,
                                                         param_label, &param));
    plan_created_ = false;
  }

  // Set multiple const param pack attributes given a value of the param.
  template <typename T>
  void SetOpConstParamAttr(
      const std::vector<cudnnFusedOpsConstParamLabel_t> &param_labels,
      T param) {
    for (auto param_label : param_labels) {
      SetOpConstParamAttr(param_label, param);
    }
  }

  // Set a variant param pack attribute given a pointer to the param.
  template <typename T>
  void SetOpVariantParamAttrPtr(cudnnFusedOpsVariantParamLabel_t param_label,
                                T *param_ptr) {
    PADDLE_ENFORCE_CUDA_SUCCESS(
        dynload::cudnnSetFusedOpsVariantParamPackAttribute(
            op_variant_params_, param_label, param_ptr));
  }

  // Set multiple variant param pack attributes given a pointer to the param.
  template <typename T>
  void SetOpVariantParamAttrPtr(
      const std::vector<cudnnFusedOpsVariantParamLabel_t> &param_labels,
      const T *param_ptr) {
    for (auto param_label : param_labels) {
      SetOpVariantParamAttrPtr(param_label, param_ptr);
    }
  }

  // Get the workspace size, which is required before Execute().
  size_t GetWorkspaceSizeInBytes(cudnnHandle_t cudnn_handle) {
    if (!plan_created_) {
      workspace_bytes_ = 0U;
      PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnMakeFusedOpsPlan(
          cudnn_handle, op_, op_const_params_, &workspace_bytes_));
      plan_created_ = true;
    }
    return workspace_bytes_;
  }

 private:
  bool plan_created_;
  size_t workspace_bytes_;

  cudnnFusedOpsPlan_t op_;
  cudnnFusedOpsConstParamPack_t op_const_params_;
  cudnnFusedOpsVariantParamPack_t op_variant_params_;
};

class CudnnFusionOpCache {
 public:
  static CudnnFusionOpCache &Instance() {
    static CudnnFusionOpCache instance;
    return instance;
  }

  framework::AlgorithmsCache<CudnnFusionOp *> *GetForward() {
    return &forward_cache_;
  }
  framework::AlgorithmsCache<CudnnFusionOp *> *GetBackward() {
    return &backward_cache_;
  }

 private:
  CudnnFusionOpCache() {}
  ~CudnnFusionOpCache() {
    // Need to delete the memory of the cache.
  }
  CudnnFusionOpCache(const CudnnFusionOpCache &) {}

 private:
  framework::AlgorithmsCache<CudnnFusionOp *> forward_cache_;
  framework::AlgorithmsCache<CudnnFusionOp *> backward_cache_;
};

#endif  // CUDNN_VERSION >= 8000
}  // namespace operators
}  // namespace paddle
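The intended call order for the wrapper above, sketched as a compilable skeleton. It is documentation of the sequence rather than a working configuration: a real run must also set every const-param descriptor required by the chosen op id before building the plan, as the norm-conv and BN helpers later in this commit do.

#include "paddle/fluid/operators/fused/cudnn_fusion_helper.h"

namespace paddle {
namespace operators {
#if CUDNN_VERSION >= 8000

// Skeleton of the CudnnFusionOp lifecycle; descriptor setup is elided, so
// this only succeeds once the remaining const params have been filled in.
inline void CudnnFusionOpLifecycleSketch(cudnnHandle_t handle) {
  CudnnFusionOp op(CUDNN_FUSED_SCALE_BIAS_ACTIVATION_CONV_BNSTATS);
  // 1. Const params: descriptors, pointer-alignment placeholders, BN mode.
  op.SetOpConstParamAttr(CUDNN_PARAM_BN_MODE,
                         CUDNN_BATCHNORM_SPATIAL_PERSISTENT);
  // 2. Build the plan; any later const-param change invalidates it.
  size_t workspace_bytes = op.GetWorkspaceSizeInBytes(handle);
  // 3. Variant params: data pointers and the workspace, set before each call.
  op.SetOpVariantParamAttrPtr(CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES,
                              &workspace_bytes);
  // 4. Launch.
  op.Execute(handle);
}

#endif
}  // namespace operators
}  // namespace paddle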
paddle/fluid/operators/fused/cudnn_norm_conv.cu.h (new file, mode 100644)
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/operators/fused/cudnn_fusion_helper.h"
#include "paddle/fluid/platform/cudnn_desc.h"
#include "paddle/fluid/platform/cudnn_helper.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
namespace dynload = platform::dynload;
template <typename T>
using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;

#if CUDNN_VERSION >= 8000

static size_t RoundUp(int64_t a, int64_t b) { return (a + b - 1) / b * b; }

template <typename T>
struct NormConvolutionArgs {
  NormConvolutionArgs() {
    dtype = platform::CudnnDataType<T>::type;
    format = CUDNN_TENSOR_NHWC;
    compute_type = platform::CudnnDataType<float>::type;
  }

  void Set(const platform::CUDADeviceContext &ctx,
           const std::vector<int> &input_shape,
           const std::vector<int> &filter_shape,
           const std::vector<int> &output_shape, int padding, int stride,
           int dilation, int group) {
    PADDLE_ENFORCE_EQ(
        input_shape.size(), 4U,
        platform::errors::InvalidArgument(
            "The size of input_shape is expected to be 4. But received "
            "input_shape's size is %d, input_shape is [%s].",
            input_shape.size(), framework::make_ddim(input_shape)));
    PADDLE_ENFORCE_EQ(
        filter_shape.size(), 4U,
        platform::errors::InvalidArgument(
            "The size of filter_shape is expected to be 4. But received "
            "filter_shape's size is %d, filter_shape is [%s].",
            filter_shape.size(), framework::make_ddim(filter_shape)));
    PADDLE_ENFORCE_EQ(filter_shape[1] == filter_shape[2] &&
                          (filter_shape[1] == 1 || filter_shape[1] == 3),
                      true,
                      platform::errors::InvalidArgument(
                          "The filter_shape is expected to be stored as nhwc, "
                          "and h = w = 1 or 3. But received filter_shape is "
                          "[%s].",
                          framework::make_ddim(filter_shape)));
    PADDLE_ENFORCE_EQ(
        (filter_shape[0] % 32 == 0 && filter_shape[3] % 8 == 0), true,
        platform::errors::InvalidArgument(
            "The input channel is expected to be a multiple of 8, "
            "and the output channel is expected to be a multiple "
            "of 32. But received input channel is %d, output "
            "channel is %d.",
            filter_shape[3], filter_shape[0]));
    PADDLE_ENFORCE_EQ(
        output_shape.size(), 4U,
        platform::errors::InvalidArgument(
            "The size of output_shape is expected to be 4. But received "
            "output_shape's size is %d, output_shape is [%s].",
            output_shape.size(), framework::make_ddim(output_shape)));
    is_support = IsSupport(ctx, filter_shape, stride, dilation, group);
    PADDLE_ENFORCE_EQ(
        is_support, true,
        platform::errors::InvalidArgument(
            "Current test is only supported on platforms with "
            "compute capability greater than or equal to 70, and the kernel "
            "size must be equal to 1 or 3. When the kernel size is 1, "
            "the stride must be 1 if the compute capability is equal to 70. "
            "Besides, the dilation and group must be equal to 1. But received "
            "compute capability is %d, kernel size is %d, stride is %d, "
            "dilation is %d, group is %d",
            ctx.GetComputeCapability(), filter_shape[1], stride, dilation,
            group));

    for (size_t i = 0; i < input_shape.size(); ++i) {
      in_dims.push_back(input_shape[i]);
    }
    for (size_t i = 0; i < filter_shape.size(); ++i) {
      filter_dims.push_back(filter_shape[i]);
    }
    paddings = {padding, padding};
    strides = {stride, stride};
    dilations = {dilation, dilation};

    in_desc.set(input_shape, format, dtype);
    filter_desc.set(filter_shape, format, dtype, group);
    out_desc.set(output_shape, format, dtype);

    int output_channel = filter_shape[0];
    std::vector<int> stats_shape = {1, 1, 1, output_channel};
    out_stats_desc.set(stats_shape, format, compute_type);

    conv_desc.set(dtype, paddings, strides, dilations, false, group);
  }

  bool IsSupport(const platform::CUDADeviceContext &ctx,
                 const std::vector<int> &filter_shape, int stride,
                 int dilation, int group) {
    int kernel_size = filter_shape[1];
    if (dilation != 1 || group != 1) {
      return false;
    }
    if (ctx.GetComputeCapability() == 70) {
      if ((kernel_size == 3) || ((kernel_size == 1) && (stride == 1))) {
        return true;
      }
    } else if (ctx.GetComputeCapability() > 70) {
      if ((kernel_size == 3) || (kernel_size == 1)) {
        return true;
      }
    }
    return false;
  }

  cudnnDataType_t dtype;
  cudnnTensorFormat_t format;
  cudnnDataType_t compute_type;

  std::vector<int64_t> in_dims;
  std::vector<int64_t> filter_dims;
  std::vector<int> strides;
  std::vector<int> paddings;
  std::vector<int> dilations;

  platform::TensorDescriptor in_desc;
  platform::FilterDescriptor filter_desc;
  platform::TensorDescriptor out_desc;
  platform::TensorDescriptor out_stats_desc;
  platform::ConvolutionDescriptor conv_desc;

  bool is_support;
};

template <typename T>
class CudnnNormConvolution {
 public:
  CudnnNormConvolution(const platform::CUDADeviceContext &ctx,
                       const std::vector<int> &input_shape,
                       const std::vector<int> &filter_shape,
                       const std::vector<int> &output_shape,
                       const int &padding, const int &stride,
                       const int &dilation, const int &group) {
    args_.Set(ctx, input_shape, filter_shape, output_shape, padding, stride,
              dilation, group);
  }
  ~CudnnNormConvolution() {}

  void Forward(const platform::CUDADeviceContext &ctx, const Tensor &input,
               const Tensor &filter, Tensor *output, Tensor *sum,
               Tensor *sum_of_squares) {
    auto cudnn_handle = ctx.cudnn_handle();
    auto place = ctx.GetPlace();

    CudnnFusionOp *fwd_op = GetForwardOp(ctx);
    size_t workspace_size = RoundUp(
        static_cast<int64_t>(fwd_op->GetWorkspaceSizeInBytes(cudnn_handle)),
        512);

    // Set variant_param
    // input ptr
    T *input_ptr = const_cast<T *>(input.data<T>());
    T *filter_ptr = const_cast<T *>(filter.data<T>());
    fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, input_ptr);
    fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_WDATA, filter_ptr);
    fwd_op->SetOpVariantParamAttrPtr(
        CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES, &workspace_size);

    // output ptr
    T *output_ptr = output->mutable_data<T>(place);
    float *sum_ptr = sum->mutable_data<float>(place);
    float *sum_of_squares_ptr = sum_of_squares->mutable_data<float>(place);
    fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_YDATA, output_ptr);
    fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_YSUM, sum_ptr);
    fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_YSQSUM, sum_of_squares_ptr);

    ctx.cudnn_workspace_handle().RunFunc(
        [&](void *workspace_ptr) {
          // workspace ptr
          fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE, workspace_ptr);
          // fused op execute
          fwd_op->Execute(cudnn_handle);
        },
        workspace_size);
  }

 private:
  CudnnFusionOp *GetForwardOp(const platform::CUDADeviceContext &ctx) {
    framework::AlgorithmsCache<CudnnFusionOp *> &cache =
        *(CudnnFusionOpCache::Instance().GetForward());

    CudnnFusionOp *fwd_op = cache.GetAlgorithm(
        args_.in_dims, args_.filter_dims, args_.strides, args_.paddings,
        args_.dilations, 0, static_cast<int64_t>(args_.dtype), [&]() {
          CudnnFusionOp *fwd_op =
              new CudnnFusionOp(CUDNN_FUSED_SCALE_BIAS_ACTIVATION_CONV_BNSTATS);

          // Set constant_param
          fwd_op->SetOpConstParamAttr(
              {CUDNN_PARAM_XDATA_PLACEHOLDER, CUDNN_PARAM_WDATA_PLACEHOLDER,
               CUDNN_PARAM_YDATA_PLACEHOLDER},
              CUDNN_PTR_16B_ALIGNED);
          fwd_op->SetOpConstParamAttr(
              {CUDNN_PARAM_YSUM_PLACEHOLDER, CUDNN_PARAM_YSQSUM_PLACEHOLDER},
              CUDNN_PTR_16B_ALIGNED);

          // conv desc
          fwd_op->SetOpConstParamDesc(CUDNN_PARAM_CONV_DESC,
                                      args_.conv_desc.desc());
          // input desc
          fwd_op->SetOpConstParamDesc(CUDNN_PARAM_XDESC, args_.in_desc.desc());
          // filter desc
          fwd_op->SetOpConstParamDesc(CUDNN_PARAM_WDESC,
                                      args_.filter_desc.desc());
          // output desc
          fwd_op->SetOpConstParamDesc(CUDNN_PARAM_YDESC,
                                      args_.out_desc.desc());
          // output_stats desc
          fwd_op->SetOpConstParamDesc(CUDNN_PARAM_YSTATS_DESC,
                                      args_.out_stats_desc.desc());
          // batch_norm mode
          fwd_op->SetOpConstParamAttr(CUDNN_PARAM_BN_MODE,
                                      CUDNN_BATCHNORM_SPATIAL_PERSISTENT);

          // Make cudnn fused ops plan
          fwd_op->GetWorkspaceSizeInBytes(ctx.cudnn_handle());
          return fwd_op;
        });
    return fwd_op;
  }

 private:
  NormConvolutionArgs<T> args_;
};

template <typename T>
class CudnnNormConvolutionGrad {
 public:
  CudnnNormConvolutionGrad(const platform::CUDADeviceContext &ctx,
                           const std::vector<int> &input_shape,
                           const std::vector<int> &filter_shape,
                           const std::vector<int> &output_shape,
                           const int &padding, const int &stride,
                           const int &dilation, const int &group) {
    args_.Set(ctx, input_shape, filter_shape, output_shape, padding, stride,
              dilation, group);
    dgrad_algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
  }
  ~CudnnNormConvolutionGrad() {}

  void Backward(const platform::CUDADeviceContext &ctx, const Tensor &input,
                const Tensor &filter, const Tensor &output_grad,
                Tensor *input_grad, Tensor *filter_grad,
                bool use_addto = false) {
    auto place = ctx.GetPlace();
    T *input_ptr = const_cast<T *>(input.data<T>());
    T *filter_ptr = const_cast<T *>(filter.data<T>());
    T *output_grad_ptr = const_cast<T *>(output_grad.data<T>());

    if (filter_grad) {
      T *filter_grad_ptr = filter_grad->mutable_data<T>(place);
      BackwardFilter(ctx, output_grad_ptr, input_ptr, filter_grad_ptr);
    }
    if (input_grad) {
      T *input_grad_ptr = input_grad->mutable_data<T>(place);
      BackwardData(ctx, output_grad_ptr, filter_ptr, input_grad_ptr,
                   use_addto);
    }
  }

 private:
  void BackwardFilter(const platform::CUDADeviceContext &ctx,
                      T *output_grad_ptr, T *input_ptr, T *filter_grad_ptr) {
    auto cudnn_handle = ctx.cudnn_handle();

    CudnnFusionOp *wgrad_op = GetBackwardFilterOp(ctx);
    size_t workspace_size = RoundUp(
        static_cast<int64_t>(wgrad_op->GetWorkspaceSizeInBytes(cudnn_handle)),
        512);

    wgrad_op->SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, input_ptr);
    wgrad_op->SetOpVariantParamAttrPtr(CUDNN_PTR_DYDATA, output_grad_ptr);
    wgrad_op->SetOpVariantParamAttrPtr(CUDNN_PTR_DWDATA, filter_grad_ptr);
    wgrad_op->SetOpVariantParamAttrPtr(
        CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES, &workspace_size);

    ctx.cudnn_workspace_handle().RunFunc(
        [&](void *workspace_ptr) {
          // workspace ptr
          wgrad_op->SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE,
                                             workspace_ptr);
          // fused op execute
          wgrad_op->Execute(cudnn_handle);
        },
        workspace_size);
  }

  void BackwardData(const platform::CUDADeviceContext &ctx,
                    T *output_grad_ptr, T *filter_ptr, T *input_grad_ptr,
                    bool use_addto = false) {
    auto cudnn_handle = ctx.cudnn_handle();
    size_t workspace_size = GetWorkspaceSizeBwdData(ctx);

    // Convolution dgrad followed optionally by batchnorm dgrad
    ScalingParamType<T> alpha = 1.0f;
    ScalingParamType<T> beta = use_addto ? 1.0f : 0.0f;
    ctx.cudnn_workspace_handle().RunFunc(
        [&](void *cudnn_workspace_ptr) {
          PADDLE_ENFORCE_CUDA_SUCCESS(
              platform::dynload::cudnnConvolutionBackwardData(
                  cudnn_handle, &alpha, args_.filter_desc.desc(), filter_ptr,
                  args_.out_desc.desc(), output_grad_ptr,
                  args_.conv_desc.desc(), dgrad_algo_, cudnn_workspace_ptr,
                  workspace_size, &beta, args_.in_desc.desc(),
                  input_grad_ptr));
        },
        workspace_size);
  }

  CudnnFusionOp *GetBackwardFilterOp(const platform::CUDADeviceContext &ctx) {
    framework::AlgorithmsCache<CudnnFusionOp *> &cache =
        *(CudnnFusionOpCache::Instance().GetBackward());

    CudnnFusionOp *wgrad_op = cache.GetAlgorithm(
        args_.in_dims, args_.filter_dims, args_.strides, args_.paddings,
        args_.dilations, 0, static_cast<int64_t>(args_.dtype), [&]() {
          CudnnFusionOp *wgrad_op =
              new CudnnFusionOp(CUDNN_FUSED_SCALE_BIAS_ACTIVATION_WGRAD);

          wgrad_op->SetOpConstParamAttr(
              {CUDNN_PARAM_DYDATA_PLACEHOLDER, CUDNN_PARAM_XDATA_PLACEHOLDER,
               CUDNN_PARAM_DWDATA_PLACEHOLDER},
              CUDNN_PTR_16B_ALIGNED);

          // conv desc
          wgrad_op->SetOpConstParamDesc(CUDNN_PARAM_CONV_DESC,
                                        args_.conv_desc.desc());
          // input desc
          wgrad_op->SetOpConstParamDesc(CUDNN_PARAM_XDESC,
                                        args_.in_desc.desc());
          // filter desc
          wgrad_op->SetOpConstParamDesc(CUDNN_PARAM_DWDESC,
                                        args_.filter_desc.desc());
          // output desc
          wgrad_op->SetOpConstParamDesc(CUDNN_PARAM_DYDESC,
                                        args_.out_desc.desc());
          wgrad_op->SetOpConstParamAttr(CUDNN_PARAM_BN_MODE,
                                        CUDNN_BATCHNORM_SPATIAL_PERSISTENT);

          // Make cudnn fused ops plan
          wgrad_op->GetWorkspaceSizeInBytes(ctx.cudnn_handle());
          return wgrad_op;
        });
    return wgrad_op;
  }

  size_t GetWorkspaceSizeBwdData(const platform::CUDADeviceContext &ctx) {
    size_t workspace_size = 0U;
    auto handle = ctx.cudnn_handle();
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
            handle, args_.filter_desc.desc(), args_.out_desc.desc(),
            args_.conv_desc.desc(), args_.in_desc.desc(), dgrad_algo_,
            &workspace_size));
    return RoundUp(workspace_size, 512);
  }

 private:
  NormConvolutionArgs<T> args_;
  cudnnConvolutionBwdDataAlgo_t dgrad_algo_;
};

#endif
}  // namespace operators
}  // namespace paddle
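Every workspace request above is padded to a 512-byte multiple through RoundUp. A quick standalone check of that rule, using sample values only:

#include <cassert>
#include <cstdint>

// Same rounding rule as RoundUp in cudnn_norm_conv.cu.h above.
static size_t RoundUpTo512(int64_t a) { return (a + 511) / 512 * 512; }

int main() {
  assert(RoundUpTo512(0) == 0);      // an empty request stays empty
  assert(RoundUpTo512(1) == 512);    // any non-zero request is padded up
  assert(RoundUpTo512(512) == 512);  // exact multiples are unchanged
  assert(RoundUpTo512(513) == 1024);
  return 0;
}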
paddle/fluid/operators/fused/cudnn_norm_conv_test.cc (new file, mode 100644)
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <random>
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/fused/cudnn_norm_conv.cu.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/float16.h"
namespace
framework
=
paddle
::
framework
;
namespace
platform
=
paddle
::
platform
;
namespace
op
=
paddle
::
operators
;
using
Tensor
=
paddle
::
framework
::
Tensor
;
USE_OP
(
conv2d
);
USE_OP
(
conv2d_grad
);
USE_OP_DEVICE_KERNEL
(
conv2d
,
CUDNN
);
USE_OP_DEVICE_KERNEL
(
conv2d_grad
,
CUDNN
);
template
<
typename
T
>
void
InitRandomTensor
(
const
std
::
vector
<
int64_t
>
&
dims
,
framework
::
Tensor
*
cpu_out
)
{
T
*
cpu_out_ptr
=
cpu_out
->
mutable_data
<
T
>
(
framework
::
make_ddim
(
dims
),
platform
::
CPUPlace
());
std
::
default_random_engine
random
(
0
);
std
::
uniform_real_distribution
<
float
>
dis
(
0.0
,
1.0
);
for
(
int
i
=
0
;
i
<
cpu_out
->
numel
();
++
i
)
{
cpu_out_ptr
[
i
]
=
static_cast
<
T
>
(
dis
(
random
));
}
}
template
<
typename
T
>
void
TransposeNchwToNhwc
(
const
framework
::
Tensor
&
cpu_in
,
framework
::
Tensor
*
cpu_out
)
{
auto
in_dims
=
cpu_in
.
dims
();
EXPECT_EQ
(
cpu_in
.
dims
().
size
(),
4
);
const
T
*
cpu_in_ptr
=
cpu_in
.
data
<
T
>
();
T
*
cpu_out_ptr
=
cpu_out
->
mutable_data
<
T
>
(
{
in_dims
[
0
],
in_dims
[
2
],
in_dims
[
3
],
in_dims
[
1
]},
platform
::
CPUPlace
());
int64_t
n
=
in_dims
[
0
];
int64_t
c
=
in_dims
[
1
];
int64_t
hw
=
in_dims
[
2
]
*
in_dims
[
3
];
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
for
(
int
j
=
0
;
j
<
hw
;
++
j
)
{
for
(
int
k
=
0
;
k
<
c
;
++
k
)
{
int
dst_idx
=
i
*
hw
*
c
+
j
*
c
+
k
;
int
src_idx
=
i
*
c
*
hw
+
k
*
hw
+
j
;
cpu_out_ptr
[
dst_idx
]
=
cpu_in_ptr
[
src_idx
];
}
}
}
}
template
<
typename
T
>
void
CheckOutput
(
const
framework
::
Tensor
&
cpu_res
,
const
framework
::
Tensor
&
cpu_base
,
float
diff
,
bool
is_relative_atol
=
false
)
{
EXPECT_EQ
(
cpu_res
.
dims
(),
cpu_base
.
dims
());
const
T
*
cpu_res_ptr
=
cpu_res
.
data
<
T
>
();
const
T
*
cpu_base_ptr
=
cpu_base
.
data
<
T
>
();
for
(
int
i
=
0
;
i
<
cpu_res
.
numel
();
++
i
)
{
if
(
is_relative_atol
)
{
EXPECT_LT
(
static_cast
<
float
>
(
std
::
abs
((
cpu_res_ptr
[
i
]
-
cpu_base_ptr
[
i
])
/
cpu_base_ptr
[
i
])),
diff
);
}
else
{
EXPECT_LT
(
static_cast
<
float
>
(
std
::
abs
(
cpu_res_ptr
[
i
]
-
cpu_base_ptr
[
i
])),
diff
);
}
}
}
// Use Paddle conv2d op results as baseline
void
ComputeConv2DForward
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
Tensor
&
cpu_input
,
const
Tensor
&
cpu_filter
,
Tensor
*
cpu_output
,
int
stride
,
int
padding
)
{
framework
::
Scope
scope
;
auto
*
input
=
scope
.
Var
(
"Input"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
filter
=
scope
.
Var
(
"Filter"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
output
=
scope
.
Var
(
"Output"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
place
=
ctx
.
GetPlace
();
TensorCopySync
(
cpu_input
,
place
,
input
);
TensorCopySync
(
cpu_filter
,
place
,
filter
);
framework
::
AttributeMap
attrs
;
bool
use_cudnn
=
true
;
std
::
string
data_format
=
"NHWC"
;
std
::
vector
<
int
>
strides
=
{
stride
,
stride
};
std
::
vector
<
int
>
paddings
=
{
padding
,
padding
};
attrs
.
insert
({
"strides"
,
strides
});
attrs
.
insert
({
"paddings"
,
paddings
});
attrs
.
insert
({
"use_cudnn"
,
use_cudnn
});
attrs
.
insert
({
"data_format"
,
data_format
});
auto
op
=
framework
::
OpRegistry
::
CreateOp
(
"conv2d"
,
{{
"Input"
,
{
"Input"
}},
{
"Filter"
,
{
"Filter"
}}},
{{
"Output"
,
{
"Output"
}}},
attrs
);
op
->
Run
(
scope
,
ctx
.
GetPlace
());
TensorCopySync
(
*
output
,
platform
::
CPUPlace
(),
cpu_output
);
}
// Use Paddle conv2d_grad op results as baseline
void
ComputeConv2DBackward
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
Tensor
&
cpu_input
,
const
Tensor
&
cpu_filter
,
const
Tensor
&
cpu_output_grad
,
framework
::
Tensor
*
cpu_input_grad
,
framework
::
Tensor
*
cpu_filter_grad
,
int
stride
,
int
padding
,
int
dilation
)
{
framework
::
Scope
scope
;
auto
*
input
=
scope
.
Var
(
"Input"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
filter
=
scope
.
Var
(
"Filter"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
output_grad
=
scope
.
Var
(
"Output@GRAD"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
input_grad
=
scope
.
Var
(
"Input@GRAD"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
filter_grad
=
scope
.
Var
(
"Filter@GRAD"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
place
=
ctx
.
GetPlace
();
TensorCopySync
(
cpu_input
,
place
,
input
);
TensorCopySync
(
cpu_filter
,
place
,
filter
);
TensorCopySync
(
cpu_output_grad
,
place
,
output_grad
);
framework
::
AttributeMap
attrs
;
bool
use_cudnn
=
true
;
std
::
string
data_format
=
"NHWC"
;
std
::
string
padding_algorithm
=
"EXPLICIT"
;
std
::
vector
<
int
>
strides
=
{
stride
,
stride
};
std
::
vector
<
int
>
paddings
=
{
padding
,
padding
};
std
::
vector
<
int
>
dilations
=
{
dilation
,
dilation
};
int
groups
=
1
;
bool
exhaustive_search
=
false
;
bool
use_addto
=
false
;
attrs
.
insert
({
"use_cudnn"
,
use_cudnn
});
attrs
.
insert
({
"data_format"
,
data_format
});
attrs
.
insert
({
"padding_algorithm"
,
padding_algorithm
});
attrs
.
insert
({
"strides"
,
strides
});
attrs
.
insert
({
"paddings"
,
paddings
});
attrs
.
insert
({
"dilations"
,
dilations
});
attrs
.
insert
({
"groups"
,
groups
});
attrs
.
insert
({
"exhaustive_search"
,
exhaustive_search
});
attrs
.
insert
({
"use_addto"
,
use_addto
});
auto
op
=
framework
::
OpRegistry
::
CreateOp
(
"conv2d_grad"
,
{{
"Input"
,
{
"Input"
}},
{
"Filter"
,
{
"Filter"
}},
{
"Output@GRAD"
,
{
"Output@GRAD"
}}},
{{
"Input@GRAD"
,
{
"Input@GRAD"
}},
{
"Filter@GRAD"
,
{
"Filter@GRAD"
}}},
attrs
);
op
->
Run
(
scope
,
ctx
.
GetPlace
());
TensorCopySync
(
*
input_grad
,
platform
::
CPUPlace
(),
cpu_input_grad
);
TensorCopySync
(
*
filter_grad
,
platform
::
CPUPlace
(),
cpu_filter_grad
);
}
template
<
typename
T
>
void
ComputeSumAndSquareSum
(
const
framework
::
Tensor
&
cpu_out
,
framework
::
Tensor
*
cpu_sum
,
framework
::
Tensor
*
cpu_sum_of_square
)
{
auto
dims
=
cpu_out
.
dims
();
int64_t
c
=
dims
[
3
];
const
T
*
cpu_out_ptr
=
cpu_out
.
data
<
T
>
();
float
*
cpu_sum_ptr
=
cpu_sum
->
mutable_data
<
float
>
({
1
,
1
,
1
,
c
},
platform
::
CPUPlace
());
float
*
cpu_sum_square_ptr
=
cpu_sum_of_square
->
mutable_data
<
float
>
(
{
1
,
1
,
1
,
c
},
platform
::
CPUPlace
());
for
(
int
j
=
0
;
j
<
c
;
++
j
)
{
float
tmp_sum
=
0.0
f
;
float
tmp_sum_of_squares
=
0.0
f
;
for
(
int
i
=
0
;
i
<
cpu_out
.
numel
()
/
c
;
++
i
)
{
float
tmp_out
=
static_cast
<
float
>
(
cpu_out_ptr
[
i
*
c
+
j
]);
tmp_sum
+=
tmp_out
;
tmp_sum_of_squares
+=
tmp_out
*
tmp_out
;
}
cpu_sum_ptr
[
j
]
=
tmp_sum
;
cpu_sum_square_ptr
[
j
]
=
tmp_sum_of_squares
;
}
}
template
<
typename
T
>
class
CudnnNormConvolutionTester
{
public:
CudnnNormConvolutionTester
(
int
batch_size
,
int
height
,
int
width
,
int
input_channels
,
int
output_channels
,
int
kernel_size
,
int
stride
)
{
batch_size_
=
batch_size
;
height_
=
height
;
width_
=
width
;
input_channels_
=
input_channels
;
output_channels_
=
output_channels
;
kernel_size_
=
kernel_size
;
stride_
=
stride
;
padding_
=
(
kernel_size_
-
1
)
/
2
;
out_height_
=
(
height_
+
2
*
padding_
-
kernel_size_
)
/
stride_
+
1
;
out_width_
=
(
width_
+
2
*
padding_
-
kernel_size_
)
/
stride_
+
1
;
SetUp
();
}
~
CudnnNormConvolutionTester
()
{}
void
CheckForward
(
float
diff
,
bool
is_relative_atol
=
false
)
{
platform
::
CUDADeviceContext
*
ctx
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
platform
::
CUDAPlace
(
0
)));
framework
::
Tensor
cpu_output_base
;
framework
::
Tensor
cpu_sum_base
;
framework
::
Tensor
cpu_sum_of_square_base
;
BaselineForward
(
*
ctx
,
&
cpu_output_base
,
&
cpu_sum_base
,
&
cpu_sum_of_square_base
);
framework
::
Tensor
cpu_output
;
framework
::
Tensor
cpu_sum
;
framework
::
Tensor
cpu_sum_of_square
;
FusedForward
(
*
ctx
,
&
cpu_output
,
&
cpu_sum
,
&
cpu_sum_of_square
);
// Check forward correctness between baseline and results of normconv.
CheckOutput
<
T
>
(
cpu_output
,
cpu_output_base
,
diff
,
is_relative_atol
);
CheckOutput
<
float
>
(
cpu_sum
,
cpu_sum_base
,
diff
,
is_relative_atol
);
CheckOutput
<
float
>
(
cpu_sum_of_square
,
cpu_sum_of_square_base
,
diff
,
is_relative_atol
);
}
void
CheckBackward
(
float
diff
,
bool
is_relative_atol
=
false
)
{
platform
::
CUDADeviceContext
*
ctx
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
platform
::
CUDAPlace
(
0
)));
framework
::
Tensor
cpu_input_grad_base
;
framework
::
Tensor
cpu_filter_nchw_grad_base
;
framework
::
Tensor
cpu_filter_nhwc_grad_base
;
BaselineBackward
(
*
ctx
,
&
cpu_input_grad_base
,
&
cpu_filter_nchw_grad_base
);
TransposeNchwToNhwc
<
T
>
(
cpu_filter_nchw_grad_base
,
&
cpu_filter_nhwc_grad_base
);
framework
::
Tensor
cpu_input_grad
;
framework
::
Tensor
cpu_filter_nhwc_grad
;
FusedBackward
(
*
ctx
,
&
cpu_input_grad
,
&
cpu_filter_nhwc_grad
);
// Check backward correctness between baseline and results of normconv.
CheckOutput
<
T
>
(
cpu_input_grad
,
cpu_input_grad_base
,
diff
,
is_relative_atol
);
CheckOutput
<
T
>
(
cpu_filter_nhwc_grad
,
cpu_filter_nhwc_grad_base
,
diff
,
is_relative_atol
);
}
private:
void
SetUp
()
{
InitRandomTensor
<
T
>
({
batch_size_
,
height_
,
width_
,
input_channels_
},
&
cpu_input_
);
InitRandomTensor
<
T
>
(
{
output_channels_
,
input_channels_
,
kernel_size_
,
kernel_size_
},
&
cpu_filter_nchw_
);
// transpoes for filter, NCHW -> NHWC
TransposeNchwToNhwc
<
T
>
(
cpu_filter_nchw_
,
&
cpu_filter_nhwc_
);
InitRandomTensor
<
T
>
(
{
batch_size_
,
out_height_
,
out_width_
,
output_channels_
},
&
cpu_output_grad_
);
}
void
BaselineForward
(
const
platform
::
CUDADeviceContext
&
ctx
,
framework
::
Tensor
*
cpu_output_base
,
framework
::
Tensor
*
cpu_sum_base
,
framework
::
Tensor
*
cpu_sum_of_square_base
)
{
ComputeConv2DForward
(
ctx
,
cpu_input_
,
cpu_filter_nchw_
,
cpu_output_base
,
stride_
,
padding_
);
ComputeSumAndSquareSum
<
T
>
(
*
cpu_output_base
,
cpu_sum_base
,
cpu_sum_of_square_base
);
}
void
BaselineBackward
(
const
platform
::
CUDADeviceContext
&
ctx
,
framework
::
Tensor
*
cpu_input_grad_base
,
framework
::
Tensor
*
cpu_filter_grad_base
)
{
ComputeConv2DBackward
(
ctx
,
cpu_input_
,
cpu_filter_nchw_
,
cpu_output_grad_
,
cpu_input_grad_base
,
cpu_filter_grad_base
,
stride_
,
padding_
,
dilation_
);
}
// get forward results of cudnn_norm_conv
void
FusedForward
(
const
platform
::
CUDADeviceContext
&
ctx
,
framework
::
Tensor
*
cpu_output
,
framework
::
Tensor
*
cpu_sum
,
framework
::
Tensor
*
cpu_sum_of_square
)
{
framework
::
Tensor
input
;
framework
::
Tensor
filter_nhwc
;
framework
::
Tensor
output
;
framework
::
Tensor
sum
;
framework
::
Tensor
sum_of_square
;
auto
place
=
ctx
.
GetPlace
();
TensorCopySync
(
cpu_input_
,
place
,
&
input
);
TensorCopySync
(
cpu_filter_nhwc_
,
place
,
&
filter_nhwc
);
output
.
Resize
(
framework
::
make_ddim
(
{
batch_size_
,
out_height_
,
out_width_
,
output_channels_
}));
sum
.
Resize
(
framework
::
make_ddim
({
1
,
1
,
1
,
output_channels_
}));
sum_of_square
.
Resize
(
framework
::
make_ddim
({
1
,
1
,
1
,
output_channels_
}));
auto
input_shape
=
framework
::
vectorize
<
int
>
(
input
.
dims
());
auto
filter_shape
=
framework
::
vectorize
<
int
>
(
filter_nhwc
.
dims
());
auto
output_shape
=
framework
::
vectorize
<
int
>
(
output
.
dims
());
op
::
CudnnNormConvolution
<
T
>
conv_op
(
ctx
,
input_shape
,
filter_shape
,
output_shape
,
padding_
,
stride_
,
dilation_
,
group_
);
conv_op
.
Forward
(
ctx
,
input
,
filter_nhwc
,
&
output
,
&
sum
,
&
sum_of_square
);
TensorCopySync
(
output
,
platform
::
CPUPlace
(),
cpu_output
);
TensorCopySync
(
sum
,
platform
::
CPUPlace
(),
cpu_sum
);
TensorCopySync
(
sum_of_square
,
platform
::
CPUPlace
(),
cpu_sum_of_square
);
}
void
FusedBackward
(
const
platform
::
CUDADeviceContext
&
ctx
,
framework
::
Tensor
*
cpu_input_grad
,
framework
::
Tensor
*
cpu_filter_grad
)
{
framework
::
Tensor
input
;
framework
::
Tensor
filter_nhwc
;
framework
::
Tensor
output_grad
;
framework
::
Tensor
input_grad
;
framework
::
Tensor
filter_grad
;
auto
place
=
ctx
.
GetPlace
();
TensorCopySync
(
cpu_input_
,
place
,
&
input
);
TensorCopySync
(
cpu_filter_nhwc_
,
place
,
&
filter_nhwc
);
TensorCopySync
(
cpu_output_grad_
,
place
,
&
output_grad
);
input_grad
.
Resize
(
input
.
dims
());
filter_grad
.
Resize
(
filter_nhwc
.
dims
());
auto
input_shape
=
framework
::
vectorize
<
int
>
(
input
.
dims
());
auto
filter_shape
=
framework
::
vectorize
<
int
>
(
filter_nhwc
.
dims
());
auto
output_shape
=
framework
::
vectorize
<
int
>
(
output_grad
.
dims
());
op
::
CudnnNormConvolutionGrad
<
T
>
conv_grad_op
(
ctx
,
input_shape
,
filter_shape
,
output_shape
,
padding_
,
stride_
,
dilation_
,
group_
);
conv_grad_op
.
Backward
(
ctx
,
input
,
filter_nhwc
,
output_grad
,
&
input_grad
,
&
filter_grad
);
TensorCopySync
(
input_grad
,
platform
::
CPUPlace
(),
cpu_input_grad
);
TensorCopySync
(
filter_grad
,
platform
::
CPUPlace
(),
cpu_filter_grad
);
}
private:
int
batch_size_
;
int
height_
;
int
width_
;
int
out_height_
;
int
out_width_
;
int
input_channels_
;
int
output_channels_
;
int
kernel_size_
;
int
stride_
;
int
padding_
;
const
int
dilation_
=
1
;
const
int
group_
=
1
;
// Forward input
framework
::
Tensor
cpu_input_
;
framework
::
Tensor
cpu_filter_nchw_
;
framework
::
Tensor
cpu_filter_nhwc_
;
// Backward input
framework
::
Tensor
cpu_output_grad_
;
};
// test for fp16, kernel = 1, output_channels = input_channels
TEST
(
CudnnNormConvFp16
,
K1S1
)
{
int
batch_size
=
4
;
int
height
=
56
;
int
width
=
56
;
int
input_channels
=
32
;
int
output_channels
=
32
;
int
kernel_size
=
1
;
int
stride
=
1
;
CudnnNormConvolutionTester
<
paddle
::
platform
::
float16
>
test
(
batch_size
,
height
,
width
,
input_channels
,
output_channels
,
kernel_size
,
stride
);
test
.
CheckForward
(
1e-3
,
true
);
test
.
CheckBackward
(
1e-3
,
true
);
}
// test for fp16, kernel = 3, output_channels = input_channels
TEST
(
CudnnNormConvFp16
,
K3S1
)
{
int
batch_size
=
4
;
int
height
=
56
;
int
width
=
56
;
int
input_channels
=
32
;
int
output_channels
=
32
;
int
kernel_size
=
3
;
int
stride
=
1
;
CudnnNormConvolutionTester
<
paddle
::
platform
::
float16
>
test
(
batch_size
,
height
,
width
,
input_channels
,
output_channels
,
kernel_size
,
stride
);
test
.
CheckForward
(
1e-3
,
true
);
test
.
CheckBackward
(
1e-3
,
true
);
}
// test for fp16, kernel = 1, output_channels = input_channels * 4
TEST
(
CudnnNormConvFp16
,
K1S1O4
)
{
int
batch_size
=
4
;
int
height
=
56
;
int
width
=
56
;
int
input_channels
=
32
;
int
output_channels
=
128
;
int
kernel_size
=
1
;
int
stride
=
1
;
CudnnNormConvolutionTester
<
paddle
::
platform
::
float16
>
test
(
batch_size
,
height
,
width
,
input_channels
,
output_channels
,
kernel_size
,
stride
);
test
.
CheckForward
(
1e-3
,
true
);
test
.
CheckBackward
(
1e-3
,
true
);
}
// test for fp16, kernel = 1, stride = 2, output_channels = input_channels * 4
TEST(CudnnNormConvFp16, K1S2O4) {
  int batch_size = 4;
  int height = 8;
  int width = 8;
  int input_channels = 32;
  int output_channels = 128;
  int kernel_size = 1;
  int stride = 2;
  CudnnNormConvolutionTester<paddle::platform::float16> test(
      batch_size, height, width, input_channels, output_channels, kernel_size,
      stride);
  platform::CUDADeviceContext *ctx =
      static_cast<platform::CUDADeviceContext *>(
          platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(0)));
  if (ctx->GetComputeCapability() <= 70) {
    ASSERT_THROW(test.CheckForward(1e-3, true),
                 paddle::platform::EnforceNotMet);
    ASSERT_THROW(test.CheckBackward(1e-3), paddle::platform::EnforceNotMet);
  } else {
    ASSERT_NO_THROW(test.CheckForward(1e-3, true));
    ASSERT_NO_THROW(test.CheckBackward(1e-3));
  }
}
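The K1S2O4 case above differs from the earlier tests in that it only asserts success on GPUs whose compute capability exceeds 70 and expects an EnforceNotMet otherwise. Below is a minimal sketch of how that guard could be shared across tests; the helper name IsPostVoltaDevice is hypothetical and not part of this commit, and it only reuses the DeviceContextPool query that the test itself performs.

// Sketch only: query the compute capability of GPU 0 the same way K1S2O4 does,
// so several tests could share the guard. The helper name is an assumption.
static bool IsPostVoltaDevice() {
  auto *ctx = static_cast<paddle::platform::CUDADeviceContext *>(
      paddle::platform::DeviceContextPool::Instance().Get(
          paddle::platform::CUDAPlace(0)));
  return ctx->GetComputeCapability() > 70;
}

// Possible usage inside a test body:
//   if (!IsPostVoltaDevice()) {
//     ASSERT_THROW(test.CheckForward(1e-3, true),
//                  paddle::platform::EnforceNotMet);
//     return;
//   }
//   ASSERT_NO_THROW(test.CheckForward(1e-3, true));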
paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h
0 → 100644
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/operators/fused/cudnn_fusion_helper.h"
#include "paddle/fluid/platform/cudnn_desc.h"
#include "paddle/fluid/platform/cudnn_helper.h"
namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
namespace dynload = platform::dynload;
template <typename T>
using BatchNormParamType =
    typename platform::CudnnDataType<T>::BatchNormParamType;

#if CUDNN_VERSION >= 8000
template
<
typename
T
>
struct
ScaleBiasAddReluArgs
{
ScaleBiasAddReluArgs
()
{
dtype
=
platform
::
CudnnDataType
<
T
>::
type
;
param_dtype
=
platform
::
CudnnDataType
<
BatchNormParamType
<
T
>>::
type
;
format
=
CUDNN_TENSOR_NHWC
;
}
void
Set
(
const
std
::
string
&
act_type
,
const
std
::
vector
<
int
>
&
data_shape
,
const
std
::
vector
<
int
>
&
param_shape
,
const
std
::
vector
<
int
>
&
bitmask_shape
)
{
PADDLE_ENFORCE_EQ
(
data_shape
.
size
(),
4U
,
platform
::
errors
::
InvalidArgument
(
"The size of data_shape is expected to 4. But recieved "
"data_shape's size is %d, data_shape is [%s]."
,
data_shape
.
size
(),
framework
::
make_ddim
(
data_shape
)));
PADDLE_ENFORCE_EQ
(
param_shape
.
size
(),
4U
,
platform
::
errors
::
InvalidArgument
(
"The size of param_shape is expected to 4. But recieved "
"param_shape's size is %d, param_shape is [%s]."
,
param_shape
.
size
(),
framework
::
make_ddim
(
param_shape
)));
PADDLE_ENFORCE_EQ
(
bitmask_shape
.
size
(),
3U
,
platform
::
errors
::
InvalidArgument
(
"The size of bitmask_shape is expected to 3. But recieved "
"bitmask_shape's size is %d, bitmask_shape is [%s]."
,
bitmask_shape
.
size
(),
framework
::
make_ddim
(
bitmask_shape
)));
in_desc
.
set
(
data_shape
,
format
,
dtype
);
out_desc
.
set
(
data_shape
,
format
,
dtype
);
equiv_scale_bias_desc
.
set
(
param_shape
,
format
,
dtype
);
scale_bias_mean_var_desc
.
set
(
param_shape
,
format
,
param_dtype
);
bitmask_desc
.
set
(
bitmask_shape
,
format
,
CUDNN_DATA_INT32
);
// set activation desc
cudnnActivationMode_t
mode
=
CUDNN_ACTIVATION_IDENTITY
;
if
(
act_type
!=
""
)
{
PADDLE_ENFORCE_EQ
(
act_type
,
"relu"
,
platform
::
errors
::
InvalidArgument
(
"Only relu activation supported in normalized convolution."
));
mode
=
CUDNN_ACTIVATION_RELU
;
}
double
dummy_clip
=
0.0
;
activation_desc
.
set
(
mode
,
dummy_clip
);
}
cudnnDataType_t
dtype
;
cudnnDataType_t
param_dtype
;
cudnnTensorFormat_t
format
;
platform
::
TensorDescriptor
in_desc
;
platform
::
TensorDescriptor
out_desc
;
platform
::
TensorDescriptor
equiv_scale_bias_desc
;
platform
::
TensorDescriptor
scale_bias_mean_var_desc
;
platform
::
TensorDescriptor
bitmask_desc
;
platform
::
ActivationDescriptor
activation_desc
;
};
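// Illustrative note (not in the original header): Set() expects NHWC layouts,
// and only the 4-D/4-D/3-D rank requirements come from the checks above; the
// concrete shapes below are just an example for a 4 x 56 x 56 x 32 activation.
//
//   ScaleBiasAddReluArgs<platform::float16> args;
//   args.Set(/*act_type=*/"relu",
//            /*data_shape=*/{4, 56, 56, 32},    // N, H, W, C
//            /*param_shape=*/{1, 1, 1, 32},     // per-channel scale/bias (assumed layout)
//            /*bitmask_shape=*/{12544, 2, 1});  // see GetBitmaskDims in resnet_unit_op.cc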
template
<
typename
T
>
class
CudnnScaleBiasAddRelu
{
public:
CudnnScaleBiasAddRelu
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
std
::
string
&
act_type
,
bool
fuse_add
,
bool
has_shortcut
,
const
std
::
vector
<
int
>
&
data_shape
,
const
std
::
vector
<
int
>
&
param_shape
,
const
std
::
vector
<
int
>
&
bitmask_shape
)
:
fwd_op_
(
CUDNN_FUSED_SCALE_BIAS_ADD_ACTIVATION_GEN_BITMASK
),
bwd_op_
(
CUDNN_FUSED_DACTIVATION_FORK_DBATCHNORM
)
{
fuse_add_
=
fuse_add
;
has_shortcut_
=
has_shortcut
;
args_
.
Set
(
act_type
,
data_shape
,
param_shape
,
bitmask_shape
);
}
~
CudnnScaleBiasAddRelu
()
{}
void
Forward
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
Tensor
&
x
,
const
Tensor
&
x_scale
,
const
Tensor
&
x_bias
,
const
Tensor
*
z
,
const
Tensor
*
z_scale
,
const
Tensor
*
z_bias
,
Tensor
*
out
,
Tensor
*
bitmask
)
{
ForwardInit
(
ctx
);
auto
handle
=
ctx
.
cudnn_handle
();
auto
place
=
ctx
.
GetPlace
();
auto
workspace_handle
=
ctx
.
cudnn_workspace_handle
();
fwd_workspace_byte_
=
fwd_op_
.
GetWorkspaceSizeInBytes
(
handle
);
// Set variant_param
// input ptr
T
*
x_ptr
=
const_cast
<
T
*>
(
x
.
data
<
T
>
());
T
*
x_scale_ptr
=
const_cast
<
T
*>
(
x_scale
.
data
<
T
>
());
T
*
x_bias_ptr
=
const_cast
<
T
*>
(
x_bias
.
data
<
T
>
());
fwd_op_
.
SetOpVariantParamAttrPtr
(
CUDNN_PTR_XDATA
,
x_ptr
);
fwd_op_
.
SetOpVariantParamAttrPtr
(
CUDNN_PTR_BN_EQSCALE
,
x_scale_ptr
);
fwd_op_
.
SetOpVariantParamAttrPtr
(
CUDNN_PTR_BN_EQBIAS
,
x_bias_ptr
);
if
(
has_shortcut_
)
{
T
*
z_ptr
=
const_cast
<
T
*>
(
z
->
data
<
T
>
());
T
*
z_scale_ptr
=
const_cast
<
T
*>
(
z_scale
->
data
<
T
>
());
T
*
z_bias_ptr
=
const_cast
<
T
*>
(
z_bias
->
data
<
T
>
());
fwd_op_
.
SetOpVariantParamAttrPtr
(
CUDNN_PTR_ZDATA
,
z_ptr
);
fwd_op_
.
SetOpVariantParamAttrPtr
(
CUDNN_PTR_BN_Z_EQSCALE
,
z_scale_ptr
);
fwd_op_
.
SetOpVariantParamAttrPtr
(
CUDNN_PTR_BN_Z_EQBIAS
,
z_bias_ptr
);
}
else
{
if
(
fuse_add_
)
{
T
*
z_ptr
=
const_cast
<
T
*>
(
z
->
data
<
T
>
());
fwd_op_
.
SetOpVariantParamAttrPtr
(
CUDNN_PTR_ZDATA
,
z_ptr
);
}
}
fwd_op_
.
SetOpVariantParamAttrPtr
(
CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES
,
&
fwd_workspace_byte_
);
// output ptr
T
*
out_ptr
=
out
->
mutable_data
<
T
>
(
place
);
int32_t
*
bitmask_ptr
=
bitmask
->
mutable_data
<
int32_t
>
(
place
);
fwd_op_
.
SetOpVariantParamAttrPtr
(
CUDNN_PTR_YDATA
,
out_ptr
);
fwd_op_
.
SetOpVariantParamAttrPtr
(
CUDNN_PTR_ACTIVATION_BITMASK
,
bitmask_ptr
);
workspace_handle
.
RunFunc
(
[
&
](
void
*
workspace_ptr
)
{
// workspace ptr
fwd_op_
.
SetOpVariantParamAttrPtr
(
CUDNN_PTR_WORKSPACE
,
workspace_ptr
);
// workspace ptr
fwd_op_
.
Execute
(
handle
);
},
fwd_workspace_byte_
);
}
void
Backward
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
Tensor
&
dy
,
const
Tensor
&
x
,
const
Tensor
&
scale
,
const
Tensor
&
bias
,
const
Tensor
&
saved_mean
,
const
Tensor
&
saved_invstd
,
const
Tensor
*
bitmask
,
Tensor
*
dx
,
Tensor
*
dz
,
Tensor
*
dscale
,
Tensor
*
dbias
,
double
eps
)
{
BackwardInit
(
ctx
);
auto
handle
=
ctx
.
cudnn_handle
();
auto
place
=
ctx
.
GetPlace
();
auto
workspace_handle
=
ctx
.
cudnn_workspace_handle
();
bwd_workspace_byte_
=
bwd_op_
.
GetWorkspaceSizeInBytes
(
handle
);
// Set variant_param
// input ptr
T
*
dy_ptr
=
const_cast
<
T
*>
(
dy
.
data
<
T
>
());
T
*
x_ptr
=
const_cast
<
T
*>
(
x
.
data
<
T
>
());
float
*
scale_ptr
=
const_cast
<
float
*>
(
scale
.
data
<
float
>
());
float
*
bias_ptr
=
const_cast
<
float
*>
(
bias
.
data
<
float
>
());
float
*
saved_mean_ptr
=
const_cast
<
float
*>
(
saved_mean
.
data
<
float
>
());
float
*
saved_invstd_ptr
=
const_cast
<
float
*>
(
saved_invstd
.
data
<
float
>
());
int32_t
*
bitmask_ptr
=
bitmask
?
const_cast
<
int32_t
*>
(
bitmask
->
data
<
int32_t
>
())
:
nullptr
;
T
*
dx_ptr
=
dx
->
mutable_data
<
T
>
(
place
);
T
*
dz_ptr
=
dz
?
dz
->
mutable_data
<
T
>
(
place
)
:
nullptr
;
float
*
dscale_ptr
=
dscale
?
dscale
->
mutable_data
<
float
>
(
place
)
:
nullptr
;
float
*
dbias_ptr
=
dbias
?
dbias
->
mutable_data
<
float
>
(
place
)
:
nullptr
;
bwd_op_
.
SetOpVariantParamAttrPtr
(
CUDNN_PTR_XDATA
,
x_ptr
);
bwd_op_
.
SetOpVariantParamAttrPtr
(
CUDNN_PTR_DYDATA
,
dy_ptr
);
bwd_op_
.
SetOpVariantParamAttrPtr
(
CUDNN_PTR_BN_SCALE
,
scale_ptr
);
bwd_op_
.
SetOpVariantParamAttrPtr
(
CUDNN_PTR_BN_BIAS
,
bias_ptr
);
bwd_op_
.
SetOpVariantParamAttrPtr
(
CUDNN_PTR_BN_SAVED_MEAN
,
saved_mean_ptr
);
bwd_op_
.
SetOpVariantParamAttrPtr
(
CUDNN_PTR_BN_SAVED_INVSTD
,
saved_invstd_ptr
);
bwd_op_
.
SetOpVariantParamAttrPtr
(
CUDNN_PTR_ACTIVATION_BITMASK
,
bitmask_ptr
);
bwd_op_
.
SetOpVariantParamAttrPtr
(
CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES
,
&
bwd_workspace_byte_
);
// output ptr
bwd_op_
.
SetOpVariantParamAttrPtr
(
CUDNN_PTR_DXDATA
,
dx_ptr
);
bwd_op_
.
SetOpVariantParamAttrPtr
(
CUDNN_PTR_BN_DSCALE
,
dscale_ptr
);
bwd_op_
.
SetOpVariantParamAttrPtr
(
CUDNN_PTR_BN_DBIAS
,
dbias_ptr
);
bwd_op_
.
SetOpVariantParamAttrPtr
<
double
>
(
CUDNN_SCALAR_DOUBLE_BN_EPSILON
,
&
eps
);
if
(
has_shortcut_
||
fuse_add_
)
{
bwd_op_
.
SetOpVariantParamAttrPtr
(
CUDNN_PTR_DZDATA
,
dz_ptr
);
}
workspace_handle
.
RunFunc
(
[
&
](
void
*
workspace_ptr
)
{
// workspace ptr
bwd_op_
.
SetOpVariantParamAttrPtr
(
CUDNN_PTR_WORKSPACE
,
workspace_ptr
);
// workspace ptr
bwd_op_
.
Execute
(
handle
);
},
bwd_workspace_byte_
);
}
private:
void
ForwardInit
(
const
platform
::
CUDADeviceContext
&
ctx
)
{
// Set constant_param
fwd_op_
.
SetOpConstParamAttr
(
{
CUDNN_PARAM_XDATA_PLACEHOLDER
,
CUDNN_PARAM_BN_EQSCALE_PLACEHOLDER
,
CUDNN_PARAM_BN_EQBIAS_PLACEHOLDER
,
CUDNN_PARAM_YDATA_PLACEHOLDER
,
CUDNN_PARAM_ACTIVATION_BITMASK_PLACEHOLDER
},
CUDNN_PTR_16B_ALIGNED
);
if
(
has_shortcut_
)
{
fwd_op_
.
SetOpConstParamAttr
(
{
CUDNN_PARAM_ZDATA_PLACEHOLDER
,
CUDNN_PARAM_BN_Z_EQSCALE_PLACEHOLDER
,
CUDNN_PARAM_BN_Z_EQBIAS_PLACEHOLDER
},
CUDNN_PTR_16B_ALIGNED
);
}
else
if
(
fuse_add_
)
{
fwd_op_
.
SetOpConstParamAttr
(
CUDNN_PARAM_ZDATA_PLACEHOLDER
,
CUDNN_PTR_16B_ALIGNED
);
}
// input desc
fwd_op_
.
SetOpConstParamDesc
(
CUDNN_PARAM_XDESC
,
args_
.
in_desc
.
desc
());
if
(
has_shortcut_
||
fuse_add_
)
{
fwd_op_
.
SetOpConstParamDesc
(
CUDNN_PARAM_ZDESC
,
args_
.
in_desc
.
desc
());
}
// equiv scale/bias desc
fwd_op_
.
SetOpConstParamDesc
(
CUDNN_PARAM_BN_EQSCALEBIAS_DESC
,
args_
.
equiv_scale_bias_desc
.
desc
());
if
(
has_shortcut_
)
{
fwd_op_
.
SetOpConstParamDesc
(
CUDNN_PARAM_BN_Z_EQSCALEBIAS_DESC
,
args_
.
equiv_scale_bias_desc
.
desc
());
}
// output desc
fwd_op_
.
SetOpConstParamDesc
(
CUDNN_PARAM_YDESC
,
args_
.
out_desc
.
desc
());
// bitmask desc
fwd_op_
.
SetOpConstParamDesc
(
CUDNN_PARAM_ACTIVATION_BITMASK_DESC
,
args_
.
bitmask_desc
.
desc
());
// activation desc
fwd_op_
.
SetOpConstParamDesc
(
CUDNN_PARAM_ACTIVATION_DESC
,
args_
.
activation_desc
.
desc
());
// others
fwd_op_
.
SetOpConstParamAttr
(
CUDNN_PARAM_BN_MODE
,
CUDNN_BATCHNORM_SPATIAL_PERSISTENT
);
}
void
BackwardInit
(
const
platform
::
CUDADeviceContext
&
ctx
)
{
// Set constant_param
bwd_op_
.
SetOpConstParamAttr
(
{
CUDNN_PARAM_XDATA_PLACEHOLDER
,
CUDNN_PARAM_DYDATA_PLACEHOLDER
,
CUDNN_PARAM_DXDATA_PLACEHOLDER
,
CUDNN_PARAM_BN_SCALE_PLACEHOLDER
,
CUDNN_PARAM_BN_BIAS_PLACEHOLDER
,
CUDNN_PARAM_BN_SAVED_MEAN_PLACEHOLDER
,
CUDNN_PARAM_BN_SAVED_INVSTD_PLACEHOLDER
,
CUDNN_PARAM_BN_DSCALE_PLACEHOLDER
,
CUDNN_PARAM_BN_DBIAS_PLACEHOLDER
,
CUDNN_PARAM_ACTIVATION_BITMASK_PLACEHOLDER
},
CUDNN_PTR_16B_ALIGNED
);
if
(
has_shortcut_
||
fuse_add_
)
{
bwd_op_
.
SetOpConstParamAttr
(
CUDNN_PARAM_DZDATA_PLACEHOLDER
,
CUDNN_PTR_16B_ALIGNED
);
}
// input desc
bwd_op_
.
SetOpConstParamDesc
(
CUDNN_PARAM_XDESC
,
args_
.
in_desc
.
desc
());
bwd_op_
.
SetOpConstParamDesc
(
CUDNN_PARAM_DXDESC
,
args_
.
in_desc
.
desc
());
if
(
has_shortcut_
||
fuse_add_
)
{
bwd_op_
.
SetOpConstParamDesc
(
CUDNN_PARAM_DZDESC
,
args_
.
in_desc
.
desc
());
}
// scale/bias/mean/var desc for backward
bwd_op_
.
SetOpConstParamDesc
(
CUDNN_PARAM_BN_SCALEBIAS_MEANVAR_DESC
,
args_
.
scale_bias_mean_var_desc
.
desc
());
// output desc
bwd_op_
.
SetOpConstParamDesc
(
CUDNN_PARAM_DYDESC
,
args_
.
out_desc
.
desc
());
// bitmask desc
bwd_op_
.
SetOpConstParamDesc
(
CUDNN_PARAM_ACTIVATION_BITMASK_DESC
,
args_
.
bitmask_desc
.
desc
());
// activation desc
bwd_op_
.
SetOpConstParamDesc
(
CUDNN_PARAM_ACTIVATION_DESC
,
args_
.
activation_desc
.
desc
());
// others
bwd_op_
.
SetOpConstParamAttr
(
CUDNN_PARAM_BN_MODE
,
CUDNN_BATCHNORM_SPATIAL_PERSISTENT
);
}
bool
fuse_add_
=
false
;
bool
has_shortcut_
=
false
;
size_t
fwd_workspace_byte_
;
size_t
bwd_workspace_byte_
;
ScaleBiasAddReluArgs
<
T
>
args_
;
CudnnFusionOp
fwd_op_
;
CudnnFusionOp
bwd_op_
;
};
#endif
}  // namespace operators
}  // namespace paddle
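In practice the class above is driven in two steps: construct it with the NHWC data/param/bitmask shapes, then call Forward with the convolution output and the equivalent scale/bias produced by CudnnBNStatsFinalize. The following minimal sketch mirrors the call pattern used in resnet_unit_op.cu further below; tensor allocation is elided, the shapes are illustrative, and the variable names (dev_ctx, conv_out, equiv_scale, equiv_bias, out, bitmask) are assumptions.

// Minimal usage sketch (assumes the named tensors are already allocated NHWC
// tensors whose shapes match data_shape/param_shape/bitmask_shape).
std::vector<int> data_shape = {4, 56, 56, 32};   // N, H, W, C
std::vector<int> param_shape = {1, 1, 1, 32};
std::vector<int> bitmask_shape = {12544, 2, 1};

CudnnScaleBiasAddRelu<paddle::platform::float16> sbar_op(
    dev_ctx, /*act_type=*/"relu", /*fuse_add=*/false, /*has_shortcut=*/false,
    data_shape, param_shape, bitmask_shape);

// z, z_scale and z_bias may be nullptr when neither fuse_add nor has_shortcut
// is set, matching the Forward() implementation above.
sbar_op.Forward(dev_ctx, conv_out, equiv_scale, equiv_bias,
                /*z=*/nullptr, /*z_scale=*/nullptr, /*z_bias=*/nullptr,
                &out, &bitmask);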
paddle/fluid/operators/fused/resnet_unit_op.cc
0 → 100644
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
// Shape of bitmask
static
framework
::
DDim
GetBitmaskDims
(
std
::
vector
<
int
>
out_shape
)
{
int
c
=
out_shape
.
back
();
int64_t
nhw
=
std
::
accumulate
(
out_shape
.
begin
(),
out_shape
.
end
(),
1
,
std
::
multiplies
<
int
>
())
/
c
;
int32_t
c_int32_elems
=
((
c
+
63
)
&
~
63
)
/
32
;
int32_t
nhw_int32_elems
=
((
nhw
+
31
)
&
~
31
);
std
::
vector
<
int
>
bitmask_shape
=
{
nhw_int32_elems
,
c_int32_elems
,
1
};
return
framework
::
make_ddim
(
bitmask_shape
);
}
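// Added worked example (not in the original source): for an NHWC output of
// shape {4, 56, 56, 32}, c = 32 and nhw = 4 * 56 * 56 = 12544, so
//   c_int32_elems   = ((32 + 63) & ~63) / 32 = 64 / 32 = 2
//   nhw_int32_elems = (12544 + 31) & ~31     = 12544
// and the bitmask tensor gets shape {12544, 2, 1}: one 32-bit word per group
// of 32 channels (padded up to a 64-channel boundary) for every output pixel.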
class
ResNetUnitOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
{
// Check input
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"ResNetUnitOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"FilterX"
),
"Input"
,
"FilterX"
,
"ResNetUnitOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"ScaleX"
),
"Input"
,
"ScaleX"
,
"ResNetUnitOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"BiasX"
),
"Input"
,
"BiasX"
,
"ResNetUnitOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"MeanX"
),
"Input"
,
"MeanX"
,
"ResNetUnitOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"VarX"
),
"Input"
,
"VarX"
,
"ResNetUnitOp"
);
bool
fuse_add
=
ctx
->
Attrs
().
Get
<
bool
>
(
"fuse_add"
);
bool
has_shortcut
=
ctx
->
Attrs
().
Get
<
bool
>
(
"has_shortcut"
);
if
(
fuse_add
||
has_shortcut
)
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Z"
),
"Input"
,
"Z"
,
"ResNetUnitOp"
);
}
if
(
has_shortcut
)
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"FilterZ"
),
"Input"
,
"FilterZ"
,
"ResNetUnitOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"ScaleZ"
),
"Input"
,
"ScaleZ"
,
"ResNetUnitOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"BiasZ"
),
"Input"
,
"BiasZ"
,
"ResNetUnitOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"MeanZ"
),
"Input"
,
"MeanZ"
,
"ResNetUnitOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"VarZ"
),
"Input"
,
"VarZ"
,
"ResNetUnitOp"
);
}
// Check output
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Y"
),
"Output"
,
"Y"
,
"ResNetUnitOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"BitMask"
),
"Output"
,
"BitMask"
,
"ResNetUnitOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"ConvX"
),
"Output"
,
"ConvX"
,
"ResNetUnitOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"SavedMeanX"
),
"Output"
,
"SavedMeanX"
,
"ResNetUnitOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"SavedInvstdX"
),
"Output"
,
"SavedInvstdX"
,
"ResNetUnitOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"RunningMeanX"
),
"Output"
,
"RunningMeanX"
,
"ResNetUnitOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"RunningVarX"
),
"Output"
,
"RunningVarX"
,
"ResNetUnitOp"
);
if
(
has_shortcut
)
{
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"ConvZ"
),
"Output"
,
"ConvZ"
,
"ResNetUnitOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"SavedMeanZ"
),
"Output"
,
"SavedMeanZ"
,
"ResNetUnitOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"SavedInvstdZ"
),
"Output"
,
"SavedInvstdZ"
,
"ResNetUnitOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"RunningMeanZ"
),
"Output"
,
"RunningMeanZ"
,
"ResNetUnitOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"RunningVarZ"
),
"Output"
,
"RunningVarZ"
,
"ResNetUnitOp"
);
}
// make sure Mean/RunningMean and Var/RunningVar share memory
PADDLE_ENFORCE_EQ
(
ctx
->
Inputs
(
"MeanX"
)[
0
],
ctx
->
Outputs
(
"RunningMeanX"
)[
0
],
platform
::
errors
::
InvalidArgument
(
"MeanX and RunningMeanX should share the same memory"
));
PADDLE_ENFORCE_EQ
(
ctx
->
Inputs
(
"VarX"
)[
0
],
ctx
->
Outputs
(
"RunningVarX"
)[
0
],
platform
::
errors
::
InvalidArgument
(
"VarX and RunningVarX should share the same memory"
));
if
(
has_shortcut
)
{
PADDLE_ENFORCE_EQ
(
ctx
->
Inputs
(
"MeanZ"
)[
0
],
ctx
->
Outputs
(
"RunningMeanZ"
)[
0
],
platform
::
errors
::
InvalidArgument
(
"MeanZ and RunningMeanZ should share the same memory"
));
PADDLE_ENFORCE_EQ
(
ctx
->
Inputs
(
"VarZ"
)[
0
],
ctx
->
Outputs
(
"RunningVarZ"
)[
0
],
platform
::
errors
::
InvalidArgument
(
"VarZ and RunningVarZ should share the same memory"
));
}
// Check dims of inputs
const
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
const
auto
w_dims
=
ctx
->
GetInputDim
(
"FilterX"
);
const
auto
bn_param_dims
=
ctx
->
GetInputDim
(
"ScaleX"
);
PADDLE_ENFORCE_EQ
(
x_dims
.
size
(),
4
,
platform
::
errors
::
InvalidArgument
(
"The dimensions of input "
"must equal to 4."
"But received: the shape of input "
"= [%s], the dimension of input = "
"[%d]"
,
x_dims
,
x_dims
.
size
()));
PADDLE_ENFORCE_EQ
(
w_dims
.
size
(),
4
,
platform
::
errors
::
InvalidArgument
(
"The dimensions of filter "
"must equal to 4."
"But received: the shape of filter "
"= [%s], the dimension of filter = [%d] "
,
w_dims
,
w_dims
.
size
()));
PADDLE_ENFORCE_EQ
(
bn_param_dims
.
size
(),
4
,
platform
::
errors
::
InvalidArgument
(
"The dimensions of bn param "
"must equal to 4."
"But received: the shape of bn param "
"= [%s], the dimension of bn param = [%d] "
,
bn_param_dims
,
bn_param_dims
.
size
()));
auto
data_format
=
ctx
->
Attrs
().
Get
<
std
::
string
>
(
"data_format"
);
PADDLE_ENFORCE_EQ
(
data_format
,
"NHWC"
,
platform
::
errors
::
InvalidArgument
(
"The data format must equal to NHWC. "
"But received: the data format "
"= [%s]"
,
data_format
));
// Calculate the dims of outputs
int
batch
=
x_dims
[
0
];
int
output_channel
=
w_dims
[
0
];
int
filter_size
=
w_dims
[
2
];
int
stride
=
ctx
->
Attrs
().
Get
<
int
>
(
"stride"
);
int
padding
=
ctx
->
Attrs
().
Get
<
int
>
(
"padding"
);
int
out_h
=
(
x_dims
[
1
]
+
padding
*
2
-
filter_size
)
/
stride
+
1
;
int
out_w
=
(
x_dims
[
2
]
+
padding
*
2
-
filter_size
)
/
stride
+
1
;
std
::
vector
<
int
>
out_shape
=
{
batch
,
out_h
,
out_w
,
output_channel
};
auto
y_dims
=
framework
::
make_ddim
(
out_shape
);
auto
bitmask_dims
=
GetBitmaskDims
(
out_shape
);
// Set dims of outputs
ctx
->
SetOutputDim
(
"Y"
,
y_dims
);
ctx
->
SetOutputDim
(
"BitMask"
,
bitmask_dims
);
ctx
->
SetOutputDim
(
"ConvX"
,
y_dims
);
ctx
->
SetOutputDim
(
"SavedMeanX"
,
bn_param_dims
);
ctx
->
SetOutputDim
(
"SavedInvstdX"
,
bn_param_dims
);
ctx
->
SetOutputDim
(
"RunningMeanX"
,
bn_param_dims
);
ctx
->
SetOutputDim
(
"RunningVarX"
,
bn_param_dims
);
if
(
has_shortcut
)
{
ctx
->
SetOutputDim
(
"ConvZ"
,
y_dims
);
ctx
->
SetOutputDim
(
"SavedMeanZ"
,
bn_param_dims
);
ctx
->
SetOutputDim
(
"SavedInvstdZ"
,
bn_param_dims
);
ctx
->
SetOutputDim
(
"RunningMeanZ"
,
bn_param_dims
);
ctx
->
SetOutputDim
(
"RunningVarZ"
,
bn_param_dims
);
}
}
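  // Added note: the output spatial size computed above follows the usual
  // convolution formula out = (in + 2 * padding - filter_size) / stride + 1.
  // For example, a 56 x 56 NHWC input with a 3 x 3 filter, padding = 1 and
  // stride = 1 keeps out_h = out_w = (56 + 2 - 3) / 1 + 1 = 56, while
  // stride = 2 halves it to (56 + 2 - 3) / 2 + 1 = 28.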
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
auto
input_data_type
=
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
);
// By default, the type of the scale, bias, mean,
// and var tensors should be float when input tensor's dtype is float16.
auto
bn_param_type
=
framework
::
proto
::
VarType
::
FP32
;
PADDLE_ENFORCE_EQ
(
bn_param_type
,
ctx
.
Input
<
Tensor
>
(
"ScaleX"
)
->
type
(),
platform
::
errors
::
InvalidArgument
(
"Scale input should be of float type"
));
PADDLE_ENFORCE_EQ
(
bn_param_type
,
ctx
.
Input
<
Tensor
>
(
"BiasX"
)
->
type
(),
platform
::
errors
::
InvalidArgument
(
"Bias input should be of float type"
));
framework
::
LibraryType
library
=
framework
::
LibraryType
::
kPlain
;
framework
::
DataLayout
layout
=
framework
::
DataLayout
::
kAnyLayout
;
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
(),
layout
,
library
);
}
};
class
ResNetUnitOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
{
AddInput
(
"X"
,
"The input 1 tensor"
);
AddInput
(
"FilterX"
,
"Filter tensor of input 1"
);
AddInput
(
"ScaleX"
,
"Scale tensor of input 1 used in batchnorm"
);
AddInput
(
"BiasX"
,
"Bias tensor of input 1 used in batchnorm"
);
AddInput
(
"MeanX"
,
"Mean tensor of input 1 used in batchnorm"
);
AddInput
(
"VarX"
,
"Variance tensor of input 1 used in batchnorm"
);
AddInput
(
"Z"
,
"The input 2 tensor"
).
AsDispensable
();
AddInput
(
"FilterZ"
,
"Filter tensor of input 2"
).
AsDispensable
();
AddInput
(
"ScaleZ"
,
"Scale tensor of input 2"
).
AsDispensable
();
AddInput
(
"BiasZ"
,
"Bias tensor of input 2"
).
AsDispensable
();
AddInput
(
"MeanZ"
,
"Mean tensor of input 2"
).
AsDispensable
();
AddInput
(
"VarZ"
,
"Variance tensor of input 2"
).
AsDispensable
();
AddOutput
(
"Y"
,
"The result of the resnet unit"
);
AddOutput
(
"BitMask"
,
"The bitmask generated after relu"
);
AddOutput
(
"ConvX"
,
"The output of input 1 after conv"
);
AddOutput
(
"SavedMeanX"
,
"Mean of input 1 in the current batch"
);
AddOutput
(
"SavedInvstdX"
,
"Invstd of input 1 in the current batch"
);
AddOutput
(
"RunningMeanX"
,
"Shared memory with MeanX"
);
AddOutput
(
"RunningVarX"
,
"Shared memory with VarX"
);
AddOutput
(
"ConvZ"
,
"The output of input 2 after conv"
).
AsDispensable
();
AddOutput
(
"SavedMeanZ"
,
"Mean of input 1 in the current batch"
)
.
AsDispensable
();
AddOutput
(
"SavedInvstdZ"
,
"Invstd of input 1 in the current batch"
)
.
AsDispensable
();
AddOutput
(
"RunningMeanZ"
,
"Shared memory with MeanZ"
).
AsDispensable
();
AddOutput
(
"RunningVarZ"
,
"Shared memory with VarZ"
).
AsDispensable
();
AddAttr
<
int
>
(
"stride"
,
""
).
SetDefault
(
1
);
AddAttr
<
int
>
(
"stride_z"
,
""
).
SetDefault
(
1
);
AddAttr
<
int
>
(
"padding"
,
""
).
SetDefault
(
0
);
AddAttr
<
int
>
(
"dilation"
,
""
).
SetDefault
(
1
);
AddAttr
<
int
>
(
"group"
,
""
).
SetDefault
(
1
);
AddAttr
<
float
>
(
"momentum"
,
""
).
SetDefault
(
0.9
);
AddAttr
<
float
>
(
"epsilon"
,
""
).
SetDefault
(
1e-5
);
AddAttr
<
std
::
string
>
(
"data_format"
,
""
).
SetDefault
(
"NHWC"
);
AddAttr
<
bool
>
(
"fuse_add"
,
""
).
SetDefault
(
false
);
AddAttr
<
bool
>
(
"has_shortcut"
,
""
).
SetDefault
(
false
);
AddAttr
<
bool
>
(
"use_global_stats"
,
""
).
SetDefault
(
false
);
AddAttr
<
bool
>
(
"is_test"
,
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true."
)
.
SetDefault
(
false
);
AddAttr
<
bool
>
(
"use_addto"
,
""
).
SetDefault
(
false
);
AddAttr
<
std
::
string
>
(
"act_type"
,
"The activation type to be fused."
)
.
SetDefault
(
"relu"
);
AddComment
(
R"DOC(
Fusion op of the basic unit of resnet block.
The implementation is based on the latest fusion op interface in cuDNN v8.0.
For more details:
https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnFusedOps_t
)DOC"
);
}
};
class
ResNetUnitGradOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
{
// check input
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"ResNetUnitGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"FilterX"
),
"Input"
,
"FilterX"
,
"ResNetUnitGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"ConvX"
),
"Input"
,
"ConvX"
,
"ResNetUnitGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"ScaleX"
),
"Input"
,
"ScaleX"
,
"ResNetUnitGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"BiasX"
),
"Input"
,
"BiasX"
,
"ResNetUnitGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"SavedMeanX"
),
"Input"
,
"SavedMeanX"
,
"ResNetUnitGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"SavedInvstdX"
),
"Input"
,
"SavedInvstdX"
,
"ResNetUnitGradOp"
);
bool
fuse_add
=
ctx
->
Attrs
().
Get
<
bool
>
(
"fuse_add"
);
bool
has_shortcut
=
ctx
->
Attrs
().
Get
<
bool
>
(
"has_shortcut"
);
if
(
fuse_add
||
has_shortcut
)
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Z"
),
"Input"
,
"Z"
,
"ResNetUnitGradOp"
);
}
if
(
has_shortcut
)
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"FilterZ"
),
"Input"
,
"FilterZ"
,
"ResNetUnitGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"ConvZ"
),
"Input"
,
"ConvZ"
,
"ResNetUnitGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"ScaleZ"
),
"Input"
,
"ScaleZ"
,
"ResNetUnitGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"BiasZ"
),
"Input"
,
"BiasZ"
,
"ResNetUnitGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"SavedMeanZ"
),
"Input"
,
"SavedMeanZ"
,
"ResNetUnitGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"SavedInvstdZ"
),
"Input"
,
"SavedInvstdZ"
,
"ResNetUnitGradOp"
);
}
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Y"
),
"Input"
,
"Y"
,
"ResNetUnitGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"BitMask"
),
"Input"
,
"BitMask"
,
"ResNetUnitGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Y"
)),
"Input"
,
framework
::
GradVarName
(
"Y"
),
"ResNetUnitGradOp"
);
// check output
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)),
"Output"
,
framework
::
GradVarName
(
"X"
),
"ResNetUnitGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"FilterX"
)),
"Output"
,
framework
::
GradVarName
(
"FilterX"
),
"ResNetUnitGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"ScaleX"
)),
"Output"
,
framework
::
GradVarName
(
"ScaleX"
),
"ResNetUnitGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"BiasX"
)),
"Output"
,
framework
::
GradVarName
(
"BiasX"
),
"ResNetUnitGradOp"
);
if
(
fuse_add
)
{
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Z"
)),
"Output"
,
framework
::
GradVarName
(
"Z"
),
"ResNetUnitGradOp"
);
}
if
(
has_shortcut
)
{
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"FilterZ"
)),
"Output"
,
framework
::
GradVarName
(
"FilterZ"
),
"ResNetUnitGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"ScaleZ"
)),
"Output"
,
framework
::
GradVarName
(
"ScaleZ"
),
"ResNetUnitGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"BiasZ"
)),
"Output"
,
framework
::
GradVarName
(
"BiasZ"
),
"ResNetUnitGradOp"
);
}
const
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
const
auto
filter_x_dims
=
ctx
->
GetInputDim
(
"FilterX"
);
const
auto
param_dims
=
ctx
->
GetInputDim
(
"ScaleX"
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
x_dims
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"FilterX"
),
filter_x_dims
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"ScaleX"
),
param_dims
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"BiasX"
),
param_dims
);
if
(
fuse_add
||
has_shortcut
)
{
const
auto
z_dims
=
ctx
->
GetInputDim
(
"Z"
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Z"
),
z_dims
);
}
if
(
has_shortcut
)
{
const
auto
filter_z_dims
=
ctx
->
GetInputDim
(
"FilterZ"
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"FilterZ"
),
filter_z_dims
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"ScaleZ"
),
param_dims
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"BiasZ"
),
param_dims
);
}
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
framework
::
GradVarName
(
"Y"
)),
platform
::
errors
::
NotFound
(
"Can not find Y@GRAD in the execution context."
));
framework
::
LibraryType
library
=
framework
::
LibraryType
::
kPlain
;
framework
::
DataLayout
layout
=
framework
::
DataLayout
::
kAnyLayout
;
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
GetPlace
(),
layout
,
library
);
}
};
template
<
typename
T
>
class
ResNetUnitGradOpMaker
:
public
framework
::
SingleGradOpMaker
<
T
>
{
public:
using
framework
::
SingleGradOpMaker
<
T
>::
SingleGradOpMaker
;
protected:
void
Apply
(
GradOpPtr
<
T
>
op
)
const
override
{
op
->
SetType
(
"resnet_unit_grad"
);
op
->
SetInput
(
"X"
,
this
->
Input
(
"X"
));
op
->
SetInput
(
"FilterX"
,
this
->
Input
(
"FilterX"
));
op
->
SetInput
(
"ConvX"
,
this
->
Output
(
"ConvX"
));
op
->
SetInput
(
"ScaleX"
,
this
->
Input
(
"ScaleX"
));
op
->
SetInput
(
"BiasX"
,
this
->
Input
(
"BiasX"
));
op
->
SetInput
(
"SavedMeanX"
,
this
->
Output
(
"SavedMeanX"
));
op
->
SetInput
(
"SavedInvstdX"
,
this
->
Output
(
"SavedInvstdX"
));
op
->
SetInput
(
"Z"
,
this
->
Input
(
"Z"
));
op
->
SetInput
(
"FilterZ"
,
this
->
Input
(
"FilterZ"
));
op
->
SetInput
(
"ConvZ"
,
this
->
Output
(
"ConvZ"
));
op
->
SetInput
(
"ScaleZ"
,
this
->
Input
(
"ScaleZ"
));
op
->
SetInput
(
"BiasZ"
,
this
->
Input
(
"BiasZ"
));
op
->
SetInput
(
"SavedMeanZ"
,
this
->
Output
(
"SavedMeanZ"
));
op
->
SetInput
(
"SavedInvstdZ"
,
this
->
Output
(
"SavedInvstdZ"
));
op
->
SetInput
(
"Y"
,
this
->
Output
(
"Y"
));
op
->
SetInput
(
"BitMask"
,
this
->
Output
(
"BitMask"
));
op
->
SetInput
(
framework
::
GradVarName
(
"Y"
),
this
->
OutputGrad
(
"Y"
));
op
->
SetAttrMap
(
this
->
Attrs
());
op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"FilterX"
),
this
->
InputGrad
(
"FilterX"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"ScaleX"
),
this
->
InputGrad
(
"ScaleX"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"BiasX"
),
this
->
InputGrad
(
"BiasX"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Z"
),
this
->
InputGrad
(
"Z"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"FilterZ"
),
this
->
InputGrad
(
"FilterZ"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"ScaleZ"
),
this
->
InputGrad
(
"ScaleZ"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"BiasZ"
),
this
->
InputGrad
(
"BiasZ"
));
}
};
class
ResNetUnitOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
GetInputOutputWithSameType
()
const
override
{
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{{
"X"
,
/*->*/
"Y"
}};
return
m
;
}
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(resnet_unit, ops::ResNetUnitOp, ops::ResNetUnitOpMaker,
                  ops::ResNetUnitOpInferVarType,
                  ops::ResNetUnitGradOpMaker<paddle::framework::OpDesc>,
                  ops::ResNetUnitGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(resnet_unit_grad, ops::ResNetUnitGradOp);
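For orientation, the sketch below shows roughly how a program might describe a resnet_unit op through the C++ OpDesc API, using only inputs, outputs, and attributes declared by ResNetUnitOpMaker above. It is a hedged illustration: the BlockDesc named block, the variable names, and the attribute values are assumptions, not code from this commit. Note how RunningMeanX/RunningVarX reuse the MeanX/VarX variables, matching the shared-memory check in ResNetUnitOp::InferShape.

// Hedged sketch: attach a resnet_unit op to an existing BlockDesc (assumed to
// exist as `block`); the string names refer to variables created elsewhere.
paddle::framework::OpDesc *op = block->AppendOp();
op->SetType("resnet_unit");
op->SetInput("X", {"x"});
op->SetInput("FilterX", {"filter_x"});
op->SetInput("ScaleX", {"scale_x"});
op->SetInput("BiasX", {"bias_x"});
op->SetInput("MeanX", {"mean_x"});
op->SetInput("VarX", {"var_x"});
op->SetOutput("Y", {"y"});
op->SetOutput("BitMask", {"bitmask"});
op->SetOutput("ConvX", {"conv_x"});
op->SetOutput("SavedMeanX", {"saved_mean_x"});
op->SetOutput("SavedInvstdX", {"saved_invstd_x"});
op->SetOutput("RunningMeanX", {"mean_x"});  // shares memory with MeanX
op->SetOutput("RunningVarX", {"var_x"});    // shares memory with VarX
op->SetAttr("stride", 1);
op->SetAttr("padding", 1);
op->SetAttr("act_type", std::string("relu"));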
paddle/fluid/operators/fused/resnet_unit_op.cu
0 → 100644
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h"
#include "paddle/fluid/operators/fused/cudnn_norm_conv.cu.h"
#include "paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
template
<
typename
T
>
class
ResNetUnitKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
true
,
platform
::
errors
::
PreconditionNotMet
(
"It must use CUDAPlace."
));
PADDLE_ENFORCE_EQ
(
platform
::
CudnnDataType
<
T
>::
type
,
CUDNN_DATA_HALF
,
platform
::
errors
::
Unavailable
(
"ResNetUnitOp only supports float16 for now."
));
// input x
const
Tensor
*
input_x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
Tensor
*
filter_x
=
ctx
.
Input
<
Tensor
>
(
"FilterX"
);
const
Tensor
*
scale_x
=
ctx
.
Input
<
Tensor
>
(
"ScaleX"
);
const
Tensor
*
bias_x
=
ctx
.
Input
<
Tensor
>
(
"BiasX"
);
// norm conv
Tensor
*
conv_out_x
=
ctx
.
Output
<
Tensor
>
(
"ConvX"
);
// bn finalize
Tensor
*
saved_mean_x
=
ctx
.
Output
<
Tensor
>
(
"SavedMeanX"
);
Tensor
*
saved_invstd_x
=
ctx
.
Output
<
Tensor
>
(
"SavedInvstdX"
);
Tensor
*
running_mean_x
=
ctx
.
Output
<
Tensor
>
(
"RunningMeanX"
);
Tensor
*
running_var_x
=
ctx
.
Output
<
Tensor
>
(
"RunningVarX"
);
// sbar
Tensor
*
output
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
Tensor
*
bitmask
=
ctx
.
Output
<
Tensor
>
(
"BitMask"
);
// attrs
int
padding
=
ctx
.
Attr
<
int
>
(
"padding"
);
int
stride
=
ctx
.
Attr
<
int
>
(
"stride"
);
int
stride_z
=
ctx
.
Attr
<
int
>
(
"stride_z"
);
int
dilation
=
ctx
.
Attr
<
int
>
(
"dilation"
);
int
group
=
ctx
.
Attr
<
int
>
(
"group"
);
double
eps
=
static_cast
<
double
>
(
ctx
.
Attr
<
float
>
(
"epsilon"
));
double
momentum
=
static_cast
<
double
>
(
ctx
.
Attr
<
float
>
(
"momentum"
));
bool
has_shortcut
=
ctx
.
Attr
<
bool
>
(
"has_shortcut"
);
bool
fuse_add
=
ctx
.
Attr
<
bool
>
(
"fuse_add"
);
bool
use_global_stats
=
ctx
.
Attr
<
bool
>
(
"use_global_stats"
);
bool
is_test
=
ctx
.
Attr
<
bool
>
(
"is_test"
);
bool
is_train
=
!
is_test
&&
!
use_global_stats
;
std
::
string
act_type
=
ctx
.
Attr
<
std
::
string
>
(
"act_type"
);
auto
input_x_shape
=
framework
::
vectorize
<
int
>
(
input_x
->
dims
());
auto
filter_x_shape
=
framework
::
vectorize
<
int
>
(
filter_x
->
dims
());
auto
param_dims
=
scale_x
->
dims
();
auto
param_shape
=
framework
::
vectorize
<
int
>
(
scale_x
->
dims
());
auto
output_shape
=
framework
::
vectorize
<
int
>
(
output
->
dims
());
auto
bitmask_shape
=
framework
::
vectorize
<
int
>
(
bitmask
->
dims
());
int
output_channel
=
filter_x_shape
[
0
];
int64_t
ele_count
=
std
::
accumulate
(
output_shape
.
begin
(),
output_shape
.
end
(),
1
,
std
::
multiplies
<
int
>
())
/
output_channel
;
auto
place
=
ctx
.
GetPlace
();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
// 1. Conv
Tensor
sum_x
;
Tensor
sum_of_squares_x
;
sum_x
.
Resize
(
param_dims
);
sum_of_squares_x
.
Resize
(
param_dims
);
CudnnNormConvolution
<
T
>
conv_x_op
(
dev_ctx
,
input_x_shape
,
filter_x_shape
,
output_shape
,
padding
,
stride
,
dilation
,
group
);
conv_x_op
.
Forward
(
dev_ctx
,
*
input_x
,
*
filter_x
,
conv_out_x
,
&
sum_x
,
&
sum_of_squares_x
);
// 2. BN
Tensor
equiv_scale_x
;
Tensor
equiv_bias_x
;
equiv_scale_x
.
Resize
(
param_dims
);
equiv_bias_x
.
Resize
(
param_dims
);
CudnnBNStatsFinalize
<
T
>
bn_x_op
(
dev_ctx
,
param_shape
);
bn_x_op
.
Forward
(
dev_ctx
,
sum_x
,
sum_of_squares_x
,
*
scale_x
,
*
bias_x
,
saved_mean_x
,
saved_invstd_x
,
running_mean_x
,
running_var_x
,
&
equiv_scale_x
,
&
equiv_bias_x
,
eps
,
momentum
,
ele_count
,
is_train
);
// 3. scale + bias + add + relu
CudnnScaleBiasAddRelu
<
T
>
sbar_op
(
dev_ctx
,
act_type
,
fuse_add
,
has_shortcut
,
output_shape
,
param_shape
,
bitmask_shape
);
if
(
has_shortcut
)
{
// input z
const
Tensor
*
input_z
=
ctx
.
Input
<
Tensor
>
(
"Z"
);
const
Tensor
*
filter_z
=
ctx
.
Input
<
Tensor
>
(
"FilterZ"
);
const
Tensor
*
scale_z
=
ctx
.
Input
<
Tensor
>
(
"ScaleZ"
);
const
Tensor
*
bias_z
=
ctx
.
Input
<
Tensor
>
(
"BiasZ"
);
// norm conv
Tensor
*
conv_out_z
=
ctx
.
Output
<
Tensor
>
(
"ConvZ"
);
// bn finalize
Tensor
*
saved_mean_z
=
ctx
.
Output
<
Tensor
>
(
"SavedMeanZ"
);
Tensor
*
saved_invstd_z
=
ctx
.
Output
<
Tensor
>
(
"SavedInvstdZ"
);
Tensor
*
running_mean_z
=
ctx
.
Output
<
Tensor
>
(
"RunningMeanZ"
);
Tensor
*
running_var_z
=
ctx
.
Output
<
Tensor
>
(
"RunningVarZ"
);
auto
input_z_shape
=
framework
::
vectorize
<
int
>
(
input_z
->
dims
());
auto
filter_z_shape
=
framework
::
vectorize
<
int
>
(
filter_z
->
dims
());
// 3.1 Conv for second input
Tensor
sum_z
;
Tensor
sum_of_squares_z
;
sum_z
.
Resize
(
param_dims
);
sum_of_squares_z
.
Resize
(
param_dims
);
CudnnNormConvolution
<
T
>
conv_z_op
(
dev_ctx
,
input_z_shape
,
filter_z_shape
,
output_shape
,
padding
,
stride_z
,
dilation
,
group
);
conv_z_op
.
Forward
(
dev_ctx
,
*
input_z
,
*
filter_z
,
conv_out_z
,
&
sum_z
,
&
sum_of_squares_z
);
// 3.2 BN for second input
Tensor
equiv_scale_z
;
Tensor
equiv_bias_z
;
equiv_scale_z
.
Resize
(
param_dims
);
equiv_bias_z
.
Resize
(
param_dims
);
CudnnBNStatsFinalize
<
T
>
bn_z_op
(
dev_ctx
,
param_shape
);
bn_z_op
.
Forward
(
dev_ctx
,
sum_z
,
sum_of_squares_z
,
*
scale_z
,
*
bias_z
,
saved_mean_z
,
saved_invstd_z
,
running_mean_z
,
running_var_z
,
&
equiv_scale_z
,
&
equiv_bias_z
,
eps
,
momentum
,
ele_count
,
is_train
);
// 3.3 sbar
sbar_op
.
Forward
(
dev_ctx
,
*
conv_out_x
,
equiv_scale_x
,
equiv_bias_x
,
conv_out_z
,
&
equiv_scale_z
,
&
equiv_bias_z
,
output
,
bitmask
);
}
else
{
const
Tensor
*
input_z
=
fuse_add
?
ctx
.
Input
<
Tensor
>
(
"Z"
)
:
nullptr
;
sbar_op
.
Forward
(
dev_ctx
,
*
conv_out_x
,
equiv_scale_x
,
equiv_bias_x
,
input_z
,
nullptr
,
nullptr
,
output
,
bitmask
);
}
}
};
template
<
typename
T
>
class
ResNetUnitGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
true
,
platform
::
errors
::
PreconditionNotMet
(
"It must use CUDAPlace."
));
PADDLE_ENFORCE_EQ
(
platform
::
CudnnDataType
<
T
>::
type
,
CUDNN_DATA_HALF
,
platform
::
errors
::
Unavailable
(
"ResNetUnitOp only supports float16 for now."
));
const
Tensor
*
y_grad
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
const
Tensor
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
Tensor
*
filter_x
=
ctx
.
Input
<
Tensor
>
(
"FilterX"
);
const
Tensor
*
scale_x
=
ctx
.
Input
<
Tensor
>
(
"ScaleX"
);
const
Tensor
*
bias_x
=
ctx
.
Input
<
Tensor
>
(
"BiasX"
);
const
Tensor
*
saved_mean_x
=
ctx
.
Input
<
Tensor
>
(
"SavedMeanX"
);
const
Tensor
*
saved_invstd_x
=
ctx
.
Input
<
Tensor
>
(
"SavedInvstdX"
);
const
Tensor
*
conv_out_x
=
ctx
.
Input
<
Tensor
>
(
"ConvX"
);
const
Tensor
*
output
=
ctx
.
Input
<
Tensor
>
(
"Y"
);
const
Tensor
*
bitmask
=
ctx
.
Input
<
Tensor
>
(
"BitMask"
);
Tensor
*
x_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
Tensor
*
filter_x_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"FilterX"
));
Tensor
*
scale_x_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"ScaleX"
));
Tensor
*
bias_x_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"BiasX"
));
int
padding
=
ctx
.
Attr
<
int
>
(
"padding"
);
int
stride
=
ctx
.
Attr
<
int
>
(
"stride"
);
int
stride_z
=
ctx
.
Attr
<
int
>
(
"stride_z"
);
int
dilation
=
ctx
.
Attr
<
int
>
(
"dilation"
);
int
group
=
ctx
.
Attr
<
int
>
(
"group"
);
double
eps
=
static_cast
<
double
>
(
ctx
.
Attr
<
float
>
(
"epsilon"
));
double
momentum
=
static_cast
<
double
>
(
ctx
.
Attr
<
float
>
(
"momentum"
));
bool
has_shortcut
=
ctx
.
Attr
<
bool
>
(
"has_shortcut"
);
bool
fuse_add
=
ctx
.
Attr
<
bool
>
(
"fuse_add"
);
bool
use_global_stats
=
ctx
.
Attr
<
bool
>
(
"use_global_stats"
);
std
::
string
act_type
=
ctx
.
Attr
<
std
::
string
>
(
"act_type"
);
auto
x_shape
=
framework
::
vectorize
<
int
>
(
x
->
dims
());
auto
filter_x_shape
=
framework
::
vectorize
<
int
>
(
filter_x
->
dims
());
auto
param_shape
=
framework
::
vectorize
<
int
>
(
scale_x
->
dims
());
auto
output_shape
=
framework
::
vectorize
<
int
>
(
output
->
dims
());
auto
bitmask_shape
=
framework
::
vectorize
<
int
>
(
bitmask
->
dims
());
auto
place
=
ctx
.
GetPlace
();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
// 1. Backward of BN (+ Add + Relu) for x, get conv_out_x_grad,
// scale_x_grad, bias_x_grad
Tensor
conv_out_x_grad
;
conv_out_x_grad
.
Resize
(
conv_out_x
->
dims
());
CudnnScaleBiasAddRelu
<
T
>
sbar_x_op
(
dev_ctx
,
act_type
,
fuse_add
,
has_shortcut
,
output_shape
,
param_shape
,
bitmask_shape
);
if
(
has_shortcut
)
{
// X Z
// | |
// NormConv NormConv
// | |
// BNStatsFinalize BNStatsFinalize
// \ /
// ScaleBiasAddRelu
// |
// Y
const
Tensor
*
z
=
ctx
.
Input
<
Tensor
>
(
"Z"
);
const
Tensor
*
filter_z
=
ctx
.
Input
<
Tensor
>
(
"FilterZ"
);
const
Tensor
*
scale_z
=
ctx
.
Input
<
Tensor
>
(
"ScaleZ"
);
const
Tensor
*
bias_z
=
ctx
.
Input
<
Tensor
>
(
"BiasZ"
);
const
Tensor
*
saved_mean_z
=
ctx
.
Input
<
Tensor
>
(
"SavedMeanZ"
);
const
Tensor
*
saved_invstd_z
=
ctx
.
Input
<
Tensor
>
(
"SavedInvstdZ"
);
const
Tensor
*
conv_out_z
=
ctx
.
Input
<
Tensor
>
(
"ConvZ"
);
Tensor
*
z_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Z"
));
Tensor
*
filter_z_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"FilterZ"
));
Tensor
*
scale_z_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"ScaleZ"
));
Tensor
*
bias_z_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"BiasZ"
));
// 1.1 Backward of BN + Add (+ Relu) for x, get conv_out_x_grad,
// scale_x_grad, bias_x_grad and z_grad_temp
Tensor
z_grad_temp
;
z_grad_temp
.
Resize
(
conv_out_z
->
dims
());
sbar_x_op
.
Backward
(
dev_ctx
,
*
y_grad
,
*
conv_out_x
,
*
scale_x
,
*
bias_x
,
*
saved_mean_x
,
*
saved_invstd_x
,
bitmask
,
&
conv_out_x_grad
,
&
z_grad_temp
,
scale_x_grad
,
bias_x_grad
,
eps
);
// 1.2 bn backward for z, get conv_out_z_grad, dscale_z, dbias_z
Tensor
conv_out_z_grad
;
conv_out_z_grad
.
Resize
(
conv_out_z
->
dims
());
CudnnScaleBiasAddRelu
<
T
>
sbar_z_op
(
dev_ctx
,
""
,
false
,
false
,
output_shape
,
param_shape
,
bitmask_shape
);
sbar_z_op
.
Backward
(
dev_ctx
,
z_grad_temp
,
*
conv_out_z
,
*
scale_z
,
*
bias_z
,
*
saved_mean_z
,
*
saved_invstd_z
,
nullptr
,
&
conv_out_z_grad
,
nullptr
,
scale_z_grad
,
bias_z_grad
,
eps
);
// 1.3 Backward of Conv for z, get z_grad and filter_z_grad
auto
z_shape
=
framework
::
vectorize
<
int
>
(
z
->
dims
());
auto
filter_z_shape
=
framework
::
vectorize
<
int
>
(
filter_z
->
dims
());
CudnnNormConvolutionGrad
<
T
>
conv_z_op
(
dev_ctx
,
z_shape
,
filter_z_shape
,
output_shape
,
padding
,
stride_z
,
dilation
,
group
);
conv_z_op
.
Backward
(
dev_ctx
,
*
z
,
*
filter_z
,
conv_out_z_grad
,
z_grad
,
filter_z_grad
);
}
else
{
// 1.1 Backward of BN (+ Add + Relu) for x, get conv_out_x_grad,
// scale_x_grad, bias_x_grad (and z_grad)
Tensor
*
z_grad
=
fuse_add
?
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Z"
))
:
nullptr
;
sbar_x_op
.
Backward
(
dev_ctx
,
*
y_grad
,
*
conv_out_x
,
*
scale_x
,
*
bias_x
,
*
saved_mean_x
,
*
saved_invstd_x
,
bitmask
,
&
conv_out_x_grad
,
z_grad
,
scale_x_grad
,
bias_x_grad
,
eps
);
}
// 2. Backward of Conv for x, get x_grad and filter_x_grad
bool
use_addto
=
ctx
.
Attr
<
bool
>
(
"use_addto"
);
CudnnNormConvolutionGrad
<
T
>
conv_x_op
(
dev_ctx
,
x_shape
,
filter_x_shape
,
output_shape
,
padding
,
stride
,
dilation
,
group
);
conv_x_op
.
Backward
(
dev_ctx
,
*
x
,
*
filter_x
,
conv_out_x_grad
,
x_grad
,
filter_x_grad
,
use_addto
);
}
};
}  // namespace operators
}  // namespace paddle

#if CUDNN_VERSION >= 8000
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(resnet_unit, ops::ResNetUnitKernel<plat::float16>);
REGISTER_OP_CUDA_KERNEL(resnet_unit_grad,
                        ops::ResNetUnitGradKernel<plat::float16>);
#endif
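One quantity worth calling out from the kernels above is ele_count: the number of elements that each channel's batch-norm statistics are averaged over, i.e. the product of the NHWC output shape divided by the channel count. The standalone sketch below is an illustration (not part of this commit) that mirrors the std::accumulate pattern used in the kernel.

#include <functional>
#include <numeric>
#include <vector>

// Elements per channel for NHWC batch-norm statistics: N * H * W.
int64_t EleCount(const std::vector<int> &output_shape, int output_channel) {
  return std::accumulate(output_shape.begin(), output_shape.end(), 1,
                         std::multiplies<int>()) /
         output_channel;
}
// e.g. EleCount({4, 56, 56, 32}, 32) == 4 * 56 * 56 == 12544.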
paddle/fluid/operators/math/pooling.cu
...
...
@@ -16,66 +16,141 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/operators/math/pooling.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/fast_divmod.h"
#include "paddle/fluid/platform/gpu_launch_config.h"
#ifdef __HIPCC__
#define POOLING_BLOCK_SIZE 256
#else
#define POOLING_BLOCK_SIZE 512
#endif
namespace paddle {
namespace operators {
namespace math {
struct
FastDivModForPooling
{
public:
platform
::
FastDivMod
channel
;
platform
::
FastDivMod
width
;
platform
::
FastDivMod
height
;
explicit
HOSTDEVICE
FastDivModForPooling
(
const
int
channels
,
const
int
output_width
,
const
int
output_height
)
{
channel
=
platform
::
FastDivMod
(
channels
);
width
=
platform
::
FastDivMod
(
output_width
);
height
=
platform
::
FastDivMod
(
output_height
);
}
};
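// Added note: FastDivMod trades a runtime integer division inside the kernel
// for a cheaper sequence precomputed from the divisor on the host. Since the
// divisors are fixed per launch, the struct is built once on the host and
// passed to the kernel by value, e.g. (shapes are illustrative only):
//   auto pool_divmods = FastDivModForPooling(/*channels=*/32,
//                                            /*output_width=*/56,
//                                            /*output_height=*/56);
// and then forwarded to KernelPool2D as its divmods argument below.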
struct
FastDivModForPoolingWithMoreStaff
{
public:
platform
::
FastDivMod
channel
;
platform
::
FastDivMod
width
;
platform
::
FastDivMod
height
;
platform
::
FastDivMod
ksize_w
;
platform
::
FastDivMod
ksize_h
;
platform
::
FastDivMod
stride_w
;
platform
::
FastDivMod
stride_h
;
explicit
HOSTDEVICE
FastDivModForPoolingWithMoreStaff
(
const
int
channels
,
const
int
input_width
,
const
int
input_height
,
const
int
ksize_width
,
const
int
ksize_height
,
const
int
stride_width
,
const
int
stride_height
)
{
channel
=
platform
::
FastDivMod
(
channels
);
width
=
platform
::
FastDivMod
(
input_width
);
height
=
platform
::
FastDivMod
(
input_height
);
ksize_w
=
platform
::
FastDivMod
(
ksize_width
);
ksize_h
=
platform
::
FastDivMod
(
ksize_height
);
stride_w
=
platform
::
FastDivMod
(
stride_width
);
stride_h
=
platform
::
FastDivMod
(
stride_height
);
}
};
template
<
typename
FastDivModForPooling
>
__device__
void
OffsetPreparationFor4Dimension
(
int
index
,
bool
channel_last
,
FastDivModForPooling
divmods
,
const
int
pad_width
,
const
int
pad_height
,
const
int
aux_width
,
const
int
aux_height
,
int
*
w_offset
,
int
*
h_offset
,
int
*
c_offset
,
int
*
stride
)
{
if
(
!
channel_last
)
{
/* NCHW */
auto
input_width_divmod
=
divmods
.
width
.
Divmod
(
index
);
auto
input_height_divmod
=
divmods
.
height
.
Divmod
(
input_width_divmod
.
val
[
0
]);
auto
channel_divmod
=
divmods
.
channel
.
Divmod
(
input_height_divmod
.
val
[
0
]);
*
w_offset
=
input_width_divmod
.
val
[
1
]
+
pad_width
;
*
h_offset
=
input_height_divmod
.
val
[
1
]
+
pad_height
;
*
c_offset
=
channel_divmod
.
val
[
1
];
*
stride
=
(
channel_divmod
.
val
[
0
]
*
divmods
.
channel
.
divisor
+
*
c_offset
)
*
aux_height
*
aux_width
;
}
else
{
/* NHWC */
auto
c_divmod
=
divmods
.
channel
.
Divmod
(
index
);
auto
input_width_divmod
=
divmods
.
width
.
Divmod
(
c_divmod
.
val
[
0
]);
auto
input_height_divmod
=
divmods
.
height
.
Divmod
(
input_width_divmod
.
val
[
0
]);
*
c_offset
=
c_divmod
.
val
[
1
];
*
w_offset
=
input_width_divmod
.
val
[
1
]
+
pad_width
;
*
h_offset
=
input_height_divmod
.
val
[
1
]
+
pad_height
;
*
stride
=
input_height_divmod
.
val
[
0
]
*
aux_height
*
aux_width
*
divmods
.
channel
.
divisor
;
}
}
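// Added worked example: in the NCHW branch the FastDivMod divisors peel the
// linear index from the innermost dimension outwards. With width = height = 56,
// channels = 32, zero padding and index = 12345:
//   w_offset = 12345 % 56        + 0 = 25
//   h_offset = (12345 / 56) % 56 + 0 = 52
//   c_offset = (12345 / 3136) % 32   = 3
//   *stride  = (batch * 32 + 3) * aux_height * aux_width, with batch = 0 here,
// i.e. the (batch, channel) pair selects which aux_height x aux_width plane of
// the auxiliary tensor the caller starts reading from.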
int
GetThreadsPerBlock
(
const
platform
::
CUDADeviceContext
&
ctx
,
int
threads_per_block
,
int64_t
numel
)
{
int
sm_count
=
ctx
.
GetSMCount
();
if
(
numel
/
(
sm_count
<<
1
)
<
threads_per_block
)
{
// Round the thread count up to a power of 2, so that the number of active
// blocks is about twice the SM count, for better performance.
threads_per_block
=
platform
::
RoundToPowerOfTwo
(
numel
/
(
sm_count
<<
1
));
}
else
if
(
numel
/
(
sm_count
<<
2
)
<
threads_per_block
)
{
// Round the thread count up to a power of 2, so that the number of active
// blocks is about 4 times the SM count, for better performance.
threads_per_block
=
platform
::
RoundToPowerOfTwo
(
numel
/
(
sm_count
<<
2
));
}
// The number of threads per block shall be at least 64.
return
std
::
max
(
64
,
threads_per_block
);
}
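// Added worked example: with threads_per_block = 512 (POOLING_BLOCK_SIZE on
// CUDA) and sm_count = 80, a problem of numel = 20000 gives
//   numel / (sm_count << 1) = 20000 / 160 = 125 < 512,
// so the block size becomes RoundToPowerOfTwo(125) = 128, keeping roughly two
// active blocks per SM; the final std::max(64, ...) enforces the 64-thread
// floor mentioned above.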
template
<
typename
PoolProcess
,
typename
T
>
__global__
void
KernelPool2D
(
const
int
nthreads
,
const
T
*
input_data
,
const
int
channels
,
const
int
input_height
,
const
int
input_width
,
const
int
output_height
,
const
int
output_width
,
const
int
ksize_height
,
const
int
ksize_width
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_height
,
const
int
padding_width
,
PoolProcess
pool_process
,
bool
exclusive
,
bool
adaptive
,
T
*
output_data
,
bool
channel_last
=
false
)
{
__global__
void
KernelPool2D
(
const
int
nthreads
,
const
T
*
input_data
,
const
int
channels
,
const
int
input_height
,
const
int
input_width
,
const
int
output_height
,
const
int
output_width
,
const
int
ksize_height
,
const
int
ksize_width
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_height
,
const
int
padding_width
,
FastDivModForPooling
divmods
,
PoolProcess
pool_process
,
bool
exclusive
,
bool
adaptive
,
T
*
output_data
,
bool
channel_last
=
false
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
nthreads
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
pw
,
ph
,
c
,
batch_idx
;
if
(
!
channel_last
)
{
/*NCHW*/
pw
=
index
%
output_width
;
ph
=
(
index
/
output_width
)
%
output_height
;
c
=
(
index
/
output_width
/
output_height
)
%
channels
;
batch_idx
=
index
/
output_width
/
output_height
/
channels
;
}
else
{
/*NHWC*/
c
=
index
%
channels
;
pw
=
(
index
/
channels
)
%
output_width
;
ph
=
(
index
/
channels
/
output_width
)
%
output_height
;
batch_idx
=
index
/
channels
/
output_width
/
output_height
;
}
int
hstart
,
hend
,
wstart
,
wend
;
int
w_offset
,
h_offset
,
c_offset
,
input_offset
;
OffsetPreparationFor4Dimension
<
FastDivModForPooling
>
(
index
,
channel_last
,
divmods
,
0
,
0
,
input_width
,
input_height
,
&
w_offset
,
&
h_offset
,
&
c_offset
,
&
input_offset
);
input_data
+=
input_offset
;
int
hstart
,
hend
;
int
wstart
,
wend
;
if
(
adaptive
)
{
hstart
=
AdaptStartIndex
(
ph
,
input_height
,
output_height
);
hend
=
AdaptEndIndex
(
ph
,
input_height
,
output_height
);
wstart
=
AdaptStartIndex
(
pw
,
input_width
,
output_width
);
wend
=
AdaptEndIndex
(
pw
,
input_width
,
output_width
);
hstart
=
AdaptStartIndex
(
h_offset
,
input_height
,
output_height
);
hend
=
AdaptEndIndex
(
h_offset
,
input_height
,
output_height
);
wstart
=
AdaptStartIndex
(
w_offset
,
input_width
,
output_width
);
wend
=
AdaptEndIndex
(
w_offset
,
input_width
,
output_width
);
}
else
{
hstart
=
ph
*
stride_height
-
padding_height
;
hstart
=
h_offset
*
stride_height
-
padding_height
;
hend
=
min
(
hstart
+
ksize_height
,
input_height
);
hstart
=
max
(
hstart
,
0
);
wstart
=
pw
*
stride_width
-
padding_width
;
wstart
=
w_offset
*
stride_width
-
padding_width
;
wend
=
min
(
wstart
+
ksize_width
,
input_width
);
wstart
=
max
(
wstart
,
0
);
}
if
(
!
channel_last
)
{
input_data
+=
(
batch_idx
*
channels
+
c
)
*
input_height
*
input_width
;
}
else
{
input_data
+=
batch_idx
*
input_height
*
input_width
*
channels
;
}
T
ele
=
pool_process
.
initial
();
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
int
w
=
wstart
;
w
<
wend
;
++
w
)
{
auto
input_idx
=
channel_last
?
(
h
*
input_width
+
w
)
*
channels
+
c
:
h
*
input_width
+
w
;
auto
input_idx
=
channel_last
?
(
h
*
input_width
+
w
)
*
channels
+
c_offset
:
h
*
input_width
+
w
;
pool_process
.
compute
(
input_data
[
input_idx
],
&
ele
);
}
}
...
...
@@ -85,91 +160,109 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data,
output_data
[
index
]
=
ele
;
}
}
template
<
typename
PoolProcess
,
typename
T
>
template
<
typename
T
,
typename
PoolProcess
>
__global__
void
KernelPool2DGrad
(
const
int
nthreads
,
const
T
*
input_data
,
const
T
*
output_data
,
const
T
*
output_grad
,
const
int
channels
,
const
int
input_height
,
const
int
input_width
,
const
int
output_height
,
const
int
output_width
,
const
int
ksize_height
,
const
int
ksize_width
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_height
,
const
int
padding_width
,
PoolProcess
pool_process
,
bool
exclusive
,
bool
adaptive
,
T
*
input_grad
,
bool
channel_last
=
false
)
{
const
int
nthreads
,
const
T
*
__restrict__
input_data
,
const
T
*
__restrict__
output_data
,
const
const
T
*
__restrict__
output_grad
,
const
int
output_width
,
const
int
output_height
,
const
int
input_width
,
const
int
input_height
,
const
int
ksize_width
,
const
int
ksize_height
,
const
int
stride_width
,
const
int
stride_height
,
const
int
padding_width
,
const
int
padding_height
,
FastDivModForPoolingWithMoreStaff
divmods
,
PoolProcess
pool_process
,
bool
exclusive
,
bool
adaptive
,
T
*
__restrict__
input_grad
,
bool
channel_last
=
false
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
nthreads
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
w_offset
,
h_offset
,
offsetC
,
batch_idx
;
if
(
!
channel_last
)
{
/* NCHW */
w_offset
=
index
%
input_width
+
padding_width
;
h_offset
=
(
index
/
input_width
)
%
input_height
+
padding_height
;
offsetC
=
(
index
/
input_width
/
input_height
)
%
channels
;
batch_idx
=
index
/
input_width
/
input_height
/
channels
;
}
else
{
/* NHWC */
offsetC
=
index
%
channels
;
w_offset
=
(
index
/
channels
)
%
               input_width + padding_width;
    h_offset = (index / channels / input_width) % input_height + padding_height;
    batch_idx = index / channels / input_width / input_height;
    T input = static_cast<T>(0);
    T input_grad_data = static_cast<T>(0);
    int phstart, phend, pwstart, pwend;
    int w_offset, h_offset, c_offset, output_offset;
    OffsetPreparationFor4Dimension<>(index, channel_last, divmods,
                                     padding_width, padding_height,
                                     output_width, output_height, &w_offset,
                                     &h_offset, &c_offset, &output_offset);
    if (pool_process.use_x) {
      input = input_data[index];
      output_data += output_offset;
    }
    output_grad += output_offset;
    int phstart, phend;
    int pwstart, pwend;
    if (adaptive) {
      phstart = AdaptStartIndex(h_offset, output_height, input_height);
      phend = AdaptEndIndex(h_offset, output_height, input_height);
      auto tmp_phend = divmods.height.Divmod((h_offset + 1) * output_height);
      auto tmp_pwend = divmods.width.Divmod((w_offset + 1) * output_width);
      phstart = divmods.height.Div(h_offset * output_height);
      pwstart = divmods.width.Div(w_offset * output_width);
      phend = tmp_phend.val[1] > 0 ? tmp_phend.val[0] + 1 : tmp_phend.val[0];
      pwend = tmp_pwend.val[1] > 0 ? tmp_pwend.val[0] + 1 : tmp_pwend.val[0];
      pwstart = AdaptStartIndex(w_offset, output_width, input_width);
      pwend = AdaptEndIndex(w_offset, output_width, input_width);
    } else {
      phstart = (h_offset < ksize_height)
                    ? 0
                    : (h_offset - ksize_height) / stride_height + 1;
      pwstart = (w_offset < ksize_width)
                    ? 0
                    : (w_offset - ksize_width) / stride_width + 1;
      phend = min(h_offset / stride_height + 1, output_height);
      pwend = min(w_offset / stride_width + 1, output_width);
    }
    T gradient = static_cast<T>(0.0);
    T input = input_data[index];
    int output_stride;
    if (!channel_last) {
      output_stride =
          (batch_idx * channels + offsetC) * output_height * output_width;
      for (int ph = phstart; ph < phend; ++ph) {
        for (int pw = pwstart; pw < pwend; ++pw) {
          auto ksize_w_divmod = divmods.ksize_w.Divmod(input_width);
          auto ksize_h_divmod = divmods.ksize_h.Divmod(input_height);
          auto tmp_width = ksize_w_divmod.val[1] > 0 ? ksize_w_divmod.val[0] + 1
                                                     : ksize_w_divmod.val[0];
          auto tmp_height = ksize_h_divmod.val[1] > 0
                                ? ksize_h_divmod.val[0] + 1
                                : ksize_h_divmod.val[0];
          int pool_size = tmp_height * tmp_width;
          int tmp_idx = ph * output_width + pw;
          int output_sub_idx =
              channel_last ? tmp_idx * divmods.channel.divisor + c_offset
                           : tmp_idx;
          T ouput_value = pool_process.use_x ? output_data[output_sub_idx]
                                             : static_cast<T>(0);
          pool_process.compute(input, ouput_value, output_grad[output_sub_idx],
                               static_cast<T>(1.0 / pool_size),
                               &input_grad_data);
        }
      }
    } else {
      output_stride = batch_idx * output_height * output_width * channels;
    }
    output_data += output_stride;
    output_grad += output_stride;
    for (int ph = phstart; ph < phend; ++ph) {
      for (int pw = pwstart; pw < pwend; ++pw) {
        int pool_size;
        if (adaptive) {
          pool_size =
              static_cast<int>(ceil(static_cast<double>(input_height) /
                                    ksize_height)) *
              static_cast<int>(ceil(static_cast<double>(input_width) /
                                    ksize_width));
        } else {
          int hstart = ph * stride_height - padding_height;
          int wstart = pw * stride_width - padding_width;
          int hend = min(hstart + ksize_height, input_height);
          int wend = min(wstart + ksize_width, input_width);
          hstart = max(hstart, 0);
          wstart = max(wstart, 0);
          pool_size = exclusive ? (hend - hstart) * (wend - wstart)
                                : ksize_height * ksize_width;
          auto stride_height_div = divmods.stride_h.Div(h_offset - ksize_height);
          auto stride_width_div = divmods.stride_w.Div(w_offset - ksize_width);
          phstart = (h_offset < ksize_height) ? 0 : stride_height_div + 1;
          pwstart = (w_offset < ksize_width) ? 0 : stride_width_div + 1;
          phend = min(divmods.stride_h.Div(h_offset) + 1, output_height);
          pwend = min(divmods.stride_w.Div(w_offset) + 1, output_width);
          if (exclusive) {
            for (int ph = phstart; ph < phend; ++ph) {
              for (int pw = pwstart; pw < pwend; ++pw) {
                int hstart = ph * stride_height - padding_height;
                int wstart = pw * stride_width - padding_width;
                int hend = min(hstart + ksize_height, input_height);
                int wend = min(wstart + ksize_width, input_width);
                hstart = max(hstart, 0);
                wstart = max(wstart, 0);
                int pool_size = (hend - hstart) * (wend - wstart);
                int tmp_idx = ph * output_width + pw;
                int output_sub_idx =
                    channel_last ? tmp_idx * divmods.channel.divisor + c_offset
                                 : tmp_idx;
                T ouput_value = pool_process.use_x
                                    ? output_data[output_sub_idx]
                                    : static_cast<T>(0);
                pool_process.compute(input, ouput_value,
                                     output_grad[output_sub_idx],
                                     static_cast<T>(1.0 / pool_size),
                                     &input_grad_data);
              }
            }
          } else {
            for (int ph = phstart; ph < phend; ++ph) {
              for (int pw = pwstart; pw < pwend; ++pw) {
                int pool_size = ksize_height * ksize_width;
                int tmp_idx = ph * output_width + pw;
                int output_sub_idx =
                    channel_last ? tmp_idx * divmods.channel.divisor + c_offset
                                 : tmp_idx;
                T ouput_value = pool_process.use_x
                                    ? output_data[output_sub_idx]
                                    : static_cast<T>(0);
                pool_process.compute(input, ouput_value,
                                     output_grad[output_sub_idx],
                                     static_cast<T>(1.0 / pool_size),
                                     &input_grad_data);
              }
            }
            int output_sub_idx = channel_last
                                     ? (ph * output_width + pw) * channels + offsetC
                                     : ph * output_width + pw;
            pool_process.compute(input, output_data[output_sub_idx],
                                 output_grad[output_sub_idx],
                                 static_cast<T>(1.0 / pool_size), &gradient);
          }
        }
        input_grad[index] = gradient;
        input_grad[index] = input_grad_data;
      }
    }
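The hunk above replaces the per-thread `/` and `%` arithmetic with precomputed FastDivMod helpers and a single OffsetPreparationFor4Dimension call. As a rough host-side illustration of the index decomposition those helpers perform (hypothetical names, not the Paddle API; the real kernel swaps the `/` and `%` below for fast div-mod objects):

    #include <cstdio>

    struct Offsets { int w, h, c, start; };

    // Split a flattened NCHW/NHWC index into (w, h, c) plus the start offset
    // of its (batch, channel) plane.
    Offsets DecomposeIndex(int index, bool channel_last, int channels,
                           int width, int height) {
      Offsets o;
      if (!channel_last) {  // NCHW
        o.w = index % width;
        o.h = (index / width) % height;
        o.c = (index / width / height) % channels;
        int n = index / width / height / channels;
        o.start = (n * channels + o.c) * height * width;
      } else {              // NHWC
        o.c = index % channels;
        o.w = (index / channels) % width;
        o.h = (index / channels / width) % height;
        int n = index / channels / width / height;
        o.start = n * height * width * channels;
      }
      return o;
    }

    int main() {
      Offsets o = DecomposeIndex(1234, /*channel_last=*/false, 3, 16, 16);
      std::printf("w=%d h=%d c=%d start=%d\n", o.w, o.h, o.c, o.start);
    }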
...
...
@@ -180,45 +273,32 @@ __global__ void KernelMaxPool2DGrad(
    const int input_width, const int output_height, const int output_width,
    const int ksize_height, const int ksize_width, const int stride_height,
    const int stride_width, const int padding_height, const int padding_width,
    T* input_grad, bool channel_last = false) {
    T* input_grad, FastDivModForPooling divmods, bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw, ph, c, batch_idx;
    if (!channel_last) { /* NCHW */
      pw = index % output_width;
      ph = (index / output_width) % output_height;
      c = (index / output_width / output_height) % channels;
      batch_idx = index / output_width / output_height / channels;
    } else { /* NHWC */
      c = index % channels;
      pw = (index / channels) % output_width;
      ph = (index / channels / output_width) % output_height;
      batch_idx = index / channels / output_width / output_height;
    }
    int hstart = ph * stride_height - padding_height;
    int w_offset, h_offset, c_offset, input_offset;
    OffsetPreparationFor4Dimension<FastDivModForPooling>(
        index, channel_last, divmods, 0, 0, input_width, input_height,
        &w_offset, &h_offset, &c_offset, &input_offset);
    input_data += input_offset;
    input_grad += input_offset;
    int hstart = h_offset * stride_height - padding_height;
    int hend = min(hstart + ksize_height, input_height);
    hstart = max(hstart, 0);
    int wstart = pw * stride_width - padding_width;
    int wstart = w_offset * stride_width - padding_width;
    int wend = min(wstart + ksize_width, input_width);
    wstart = max(wstart, 0);
    int input_stride;
    if (!channel_last) {
      input_stride = (batch_idx * channels + c) * input_height * input_width;
    } else {
      input_stride = batch_idx * input_height * input_width * channels;
    }
    input_data += input_stride;
    input_grad += input_stride;
    T ele = output_data[index];
    int maxIndex = -1;
    bool stop = false;
    for (int h = hstart; h < hend && !stop; ++h) {
      for (int w = wstart; w < wend && !stop; ++w) {
        int input_data_idx = channel_last ? (h * input_width + w) * channels + c
                                          : h * input_width + w;
        int input_data_idx = channel_last
                                 ? (h * input_width + w) * channels + c_offset
                                 : h * input_width + w;
        if (ele == input_data[input_data_idx]) {
          maxIndex = input_data_idx;
          stop = true;
...
...
@@ -264,10 +344,13 @@ void Pool2dDirectCUDAFunctor<PoolProcess, T>::operator()(
  dim3 threads(thread_num, 1);
  dim3 grid(blocks, 1);
  auto pool_divmods =
      FastDivModForPooling(input_channels, output_width, output_height);
  KernelPool2D<PoolProcess, T><<<grid, threads, 0, stream>>>(
      nthreads, input, input_channels, input_height, input_width,
      output_height, output_width, ksize_height, ksize_width, stride_height,
      stride_width, padding_height, padding_width, pool_compute, exclusive,
      adaptive, output);
      padding_height, padding_width, pool_divmods, pool_compute, exclusive,
      adaptive, output);
}
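All of the launchers in this file hand a 1-D grid to kernels that then walk the work with a grid-stride loop. A minimal standalone sketch of that pattern (the kernel below is illustrative only, not part of Paddle):

    // CUDA: each thread processes indices index, index + gridDim.x * blockDim.x, ...
    __global__ void KernelScale(int nthreads, const float* in, float scale,
                                float* out) {
      for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
           index += blockDim.x * gridDim.x) {
        out[index] = scale * in[index];  // strided subset per thread
      }
    }

    void LaunchScale(int nthreads, const float* in, float scale, float* out,
                     cudaStream_t stream) {
      int threads = 1024;                              // matches the launchers above
      int blocks = (nthreads + threads - 1) / threads; // ceil-divide the work
      KernelScale<<<blocks, threads, 0, stream>>>(nthreads, in, scale, out);
    }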
/*
...
...
@@ -311,11 +394,14 @@ class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);
    auto pool_divmods =
        FastDivModForPooling(input_channels, output_width, output_height);
    KernelPool2D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, input_channels, input_height, input_width,
        output_height, output_width, ksize_height, ksize_width, stride_height,
        stride_width, padding_height, padding_width, pool_process, exclusive,
        adaptive, output_data);
        stride_width, padding_height, padding_width, pool_divmods,
        pool_process, exclusive, adaptive, output_data);
  }
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const std::vector<int>& ksize,
...
...
@@ -357,11 +443,14 @@ class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);
    auto pool_divmods =
        FastDivModForPooling(input_channels, output_width, output_height);
    KernelPool2D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, input_channels, input_height, input_width,
        output_height, output_width, ksize_height, ksize_width, stride_height,
        stride_width, padding_height, padding_width, pool_process, exclusive,
        adaptive, output_data, channel_last);
        stride_width, padding_height, padding_width, pool_divmods,
        pool_process, exclusive, adaptive, output_data, channel_last);
  }
};
/*
...
...
@@ -402,15 +491,18 @@ class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
    int nthreads = batch_size * input_channels * input_height * input_width;
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);
    KernelPool2DGrad<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_height, input_width, output_height, output_width, ksize_height,
        ksize_width, stride_height, stride_width, padding_height,
        padding_width, pool_process, exclusive, adaptive, input_grad_data);
    int blocks = GetThreadsPerBlock(context, POOLING_BLOCK_SIZE, nthreads);
    int grids = (nthreads + blocks - 1) / blocks;
    auto pool_divmods = FastDivModForPoolingWithMoreStaff(
        input_channels, input_width, input_height, ksize_width, ksize_height,
        stride_width, stride_height);
    KernelPool2DGrad<T, PoolProcess><<<grids, blocks, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, output_width,
        output_height, input_width, input_height, ksize_width, ksize_height,
        stride_width, stride_height, padding_width, padding_height,
        pool_divmods, pool_process, exclusive, adaptive, input_grad_data);
  }
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
...
...
@@ -424,7 +516,6 @@ class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
    bool channel_last = (data_format == "NHWC");
    const int batch_size = input.dims()[0];
    const int input_channels = channel_last ? input.dims()[3] : input.dims()[1];
    const int input_height = channel_last ? input.dims()[1] : input.dims()[2];
    const int input_width = channel_last ? input.dims()[2] : input.dims()[3];
...
...
@@ -447,19 +538,22 @@ class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
    int nthreads = batch_size * input_channels * input_height * input_width;
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);
    KernelPool2DGrad<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_height, input_width, output_height, output_width, ksize_height,
        ksize_width, stride_height, stride_width, padding_height,
        padding_width, pool_process, exclusive, adaptive, input_grad_data,
        channel_last);
    int blocks = GetThreadsPerBlock(context, POOLING_BLOCK_SIZE, nthreads);
    int grids = (nthreads + blocks - 1) / blocks;
    auto pool_divmods = FastDivModForPoolingWithMoreStaff(
        input_channels, input_width, input_height, ksize_width, ksize_height,
        stride_width, stride_height);
    KernelPool2DGrad<T, PoolProcess><<<grids, blocks, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, output_width,
        output_height, input_width, input_height, ksize_width, ksize_height,
        stride_width, stride_height, padding_width, padding_height,
        pool_divmods, pool_process, exclusive, adaptive, input_grad_data,
        channel_last);
  }
};
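The new launch path sizes the block with GetThreadsPerBlock(context, POOLING_BLOCK_SIZE, nthreads) instead of a hard-coded 1024. A plausible re-implementation of that capping logic, written only to illustrate the idea (the actual helper lives elsewhere in this PR and may differ):

    #include <algorithm>
    #include <cstdint>

    constexpr int kPoolingBlockSize = 512;  // assumed stand-in for POOLING_BLOCK_SIZE

    // Never ask for more threads per block than there is work, and never
    // exceed the configured cap.
    inline int ThreadsPerBlockFor(int64_t nthreads) {
      int64_t threads = std::max<int64_t>(nthreads, 1);
      return static_cast<int>(std::min<int64_t>(threads, kPoolingBlockSize));
    }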
...
...
@@ -505,11 +599,13 @@ class MaxPool2dGradFunctor<platform::CUDADeviceContext, T> {
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);
    auto pool_divmods =
        FastDivModForPooling(input_channels, output_width, output_height);
    KernelMaxPool2DGrad<T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_height, input_width, output_height, output_width, ksize_height,
        ksize_width, stride_height, stride_width, padding_height,
        padding_width, input_grad_data);
        padding_width, input_grad_data, pool_divmods);
  }
  void operator()(const platform::CUDADeviceContext& context,
...
...
@@ -550,11 +646,14 @@ class MaxPool2dGradFunctor<platform::CUDADeviceContext, T> {
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);
    auto pool_divmods =
        FastDivModForPooling(input_channels, output_width, output_height);
    KernelMaxPool2DGrad<T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_height, input_width, output_height, output_width, ksize_height,
        ksize_width, stride_height, stride_width, padding_height,
        padding_width, input_grad_data, channel_last);
        padding_width, input_grad_data, pool_divmods, channel_last);
  }
};
...
...
@@ -689,35 +788,40 @@ __global__ void KernelPool3D(
}
}
template
<
typename
PoolProcess
,
typename
T
>
template
<
typename
T
,
typename
PoolProcess
>
__global__
void
KernelPool3DGrad
(
const
int
nthreads
,
const
T
*
input_data
,
const
T
*
output_data
,
const
T
*
output_grad
,
const
int
channels
,
const
int
input_depth
,
const
int
input_height
,
const
int
input_width
,
const
int
output_depth
,
const
int
output_height
,
const
int
output_width
,
const
int
ksize_depth
,
const
int
ksize_height
,
const
int
ksize_width
,
const
int
stride_depth
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_depth
,
const
int
padding_height
,
const
int
padding_width
,
PoolProcess
pool_process
,
bool
exclusive
,
bool
adaptive
,
T
*
input_grad
,
bool
channel_last
=
false
)
{
const
int
nthreads
,
const
T
*
__restrict__
input_data
,
const
T
*
__restrict__
output_data
,
const
T
*
__restrict__
output_grad
,
const
int
channels
,
const
int
input_depth
,
const
int
input_height
,
const
int
input_width
,
const
int
output_depth
,
const
int
output_height
,
const
int
output_width
,
const
int
ksize_depth
,
const
int
ksize_height
,
const
int
ksize_width
,
const
int
stride_depth
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_depth
,
const
int
padding_height
,
const
int
padding_width
,
PoolProcess
pool_process
,
bool
exclusive
,
bool
adaptive
,
T
*
input_grad
,
bool
channel_last
=
false
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
nthreads
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
w_offset
,
h_offset
,
d_offset
,
offsetC
,
batch_idx
;
int
w_offset
,
h_offset
,
d_offset
,
c_offset
,
batch_idx
,
output_stride
;
T
input
=
static_cast
<
T
>
(
0
);
if
(
!
channel_last
)
{
/* "NCDHW" */
w_offset
=
index
%
input_width
+
padding_width
;
h_offset
=
(
index
/
input_width
)
%
input_height
+
padding_height
;
d_offset
=
(
index
/
input_width
/
input_height
)
%
input_depth
+
padding_depth
;
offsetC
=
(
index
/
input_width
/
input_height
/
input_depth
)
%
channels
;
c_offset
=
(
index
/
input_width
/
input_height
/
input_depth
)
%
channels
;
batch_idx
=
index
/
input_width
/
input_height
/
input_depth
/
channels
;
output_stride
=
(
batch_idx
*
channels
+
c_offset
)
*
output_depth
*
output_height
*
output_width
;
}
else
{
/* "NDHWC" */
offsetC
=
index
%
channels
;
c_offset
=
index
%
channels
;
w_offset
=
(
index
/
channels
)
%
input_width
+
padding_width
;
h_offset
=
(
index
/
channels
/
input_width
)
%
input_height
+
padding_height
;
d_offset
=
(
index
/
channels
/
input_width
/
input_height
)
%
input_depth
+
padding_depth
;
batch_idx
=
index
/
channels
/
input_width
/
input_height
/
input_depth
;
output_stride
=
batch_idx
*
output_depth
*
output_height
*
output_width
*
channels
;
}
int
pdstart
,
pdend
;
...
...
@@ -746,20 +850,12 @@ __global__ void KernelPool3DGrad(
phend
=
min
((
h_offset
)
/
stride_height
+
1
,
output_height
);
pwend
=
min
((
w_offset
)
/
stride_width
+
1
,
output_width
);
}
T
gradient
=
static_cast
<
T
>
(
0.0
);
T
input
=
input_data
[
index
];
int
output_stride
;
if
(
!
channel_last
)
{
output_stride
=
(
batch_idx
*
channels
+
offsetC
)
*
output_depth
*
output_height
*
output_width
;
}
else
{
output_stride
=
batch_idx
*
output_depth
*
output_height
*
output_width
*
channels
;
if
(
pool_process
.
use_x
)
{
input
=
input_data
[
index
];
output_data
+=
output_stride
;
}
output_data
+=
output_stride
;
output_grad
+=
output_stride
;
T
input_grad_data
=
static_cast
<
T
>
(
0.0
);
for
(
int
pd
=
pdstart
;
pd
<
pdend
;
++
pd
)
{
for
(
int
ph
=
phstart
;
ph
<
phend
;
++
ph
)
{
...
...
@@ -792,16 +888,17 @@ __global__ void KernelPool3DGrad(
int
output_sub_idx
=
channel_last
?
((
pd
*
output_height
+
ph
)
*
output_width
+
pw
)
*
channels
+
offsetC
c_offset
:
(
pd
*
output_height
+
ph
)
*
output_width
+
pw
;
pool_process
.
compute
(
input
,
output_data
[
output_sub_idx
],
output_grad
[
output_sub_idx
],
static_cast
<
T
>
(
1.0
/
pool_size
),
&
gradient
);
T
ouput_value
=
pool_process
.
use_x
?
output_data
[
output_sub_idx
]
:
static_cast
<
T
>
(
0
);
pool_process
.
compute
(
input
,
ouput_value
,
output_grad
[
output_sub_idx
],
static_cast
<
T
>
(
1.0
/
pool_size
),
&
input_grad_data
);
}
}
}
input_grad
[
index
]
=
gradient
;
input_grad
[
index
]
=
input_grad_data
;
}
}
...
...
@@ -1088,7 +1185,7 @@ class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);
    KernelPool3DGrad<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
    KernelPool3DGrad<T, PoolProcess><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_depth, input_height, input_width, output_depth, output_height,
        output_width, ksize_depth, ksize_height, ksize_width, stride_depth,
...
...
@@ -1142,7 +1239,7 @@ class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);
    KernelPool3DGrad<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
    KernelPool3DGrad<T, PoolProcess><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_depth, input_height, input_width, output_depth, output_height,
        output_width, ksize_depth, ksize_height, ksize_width, stride_depth,
...
...
@@ -1315,33 +1412,33 @@ __global__ void KernelMaxPool2dWithIdx(
const
int
input_height
,
const
int
input_width
,
const
int
output_height
,
const
int
output_width
,
const
int
ksize_height
,
const
int
ksize_width
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_height
,
const
int
padding_width
,
bool
adaptive
,
T1
*
output_data
,
T2
*
mask_data
)
{
const
int
padding_width
,
bool
adaptive
,
T1
*
output_data
,
T2
*
mask_data
,
FastDivModForPooling
divmods
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
nthreads
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
pw
=
index
%
output_width
;
int
ph
=
(
index
/
output_width
)
%
output_height
;
int
c
=
(
index
/
output_width
/
output_height
)
%
channels
;
int
batch_idx
=
index
/
output_width
/
output_height
/
channels
;
int
hstart
,
hend
,
wstart
,
wend
;
int
w_offset
,
h_offset
,
c_offset
,
input_offset
;
OffsetPreparationFor4Dimension
<
FastDivModForPooling
>
(
index
,
false
,
divmods
,
0
,
0
,
input_width
,
input_height
,
&
w_offset
,
&
h_offset
,
&
c_offset
,
&
input_offset
);
input_data
+=
input_offset
;
int
hstart
,
hend
;
int
wstart
,
wend
;
if
(
adaptive
)
{
hstart
=
AdaptStartIndex
(
ph
,
input_height
,
output_height
);
hend
=
AdaptEndIndex
(
ph
,
input_height
,
output_height
);
hstart
=
AdaptStartIndex
(
h_offset
,
input_height
,
output_height
);
hend
=
AdaptEndIndex
(
h_offset
,
input_height
,
output_height
);
wstart
=
AdaptStartIndex
(
pw
,
input_width
,
output_width
);
wend
=
AdaptEndIndex
(
pw
,
input_width
,
output_width
);
wstart
=
AdaptStartIndex
(
w_offset
,
input_width
,
output_width
);
wend
=
AdaptEndIndex
(
w_offset
,
input_width
,
output_width
);
}
else
{
hstart
=
ph
*
stride_height
-
padding_height
;
hstart
=
h_offset
*
stride_height
-
padding_height
;
hend
=
min
(
hstart
+
ksize_height
,
input_height
);
hstart
=
max
(
hstart
,
0
);
wstart
=
pw
*
stride_width
-
padding_width
;
wstart
=
w_offset
*
stride_width
-
padding_width
;
wend
=
min
(
wstart
+
ksize_width
,
input_width
);
wstart
=
max
(
wstart
,
0
);
}
input_data
+=
(
batch_idx
*
channels
+
c
)
*
input_height
*
input_width
;
T1
ele
=
-
FLT_MAX
;
int
max_index
=
-
1
;
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
...
...
@@ -1365,16 +1462,17 @@ __global__ void KernelMaxPool2DWithIdxGrad(
const
int
output_height
,
const
int
output_width
,
const
int
ksize_height
,
const
int
ksize_width
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_height
,
const
int
padding_width
,
bool
adaptive
,
T1
*
input_grad
)
{
T1
*
input_grad
,
FastDivModForPooling
divmods
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
nthreads
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
w_offset
=
index
%
input_width
;
int
h_offset
=
(
index
/
input_width
)
%
input_height
;
int
offsetC
=
(
index
/
input_width
/
input_height
)
%
channels
;
int
batch_idx
=
index
/
input_width
/
input_height
/
channels
;
int
phstart
,
phend
,
pwstart
,
pwend
;
int
w_offset
,
h_offset
,
c_offset
,
output_offset
;
OffsetPreparationFor4Dimension
<
FastDivModForPooling
>
(
index
,
false
,
divmods
,
0
,
0
,
output_width
,
output_height
,
&
w_offset
,
&
h_offset
,
&
c_offset
,
&
output_offset
);
mask_data
+=
output_offset
;
output_grad
+=
output_offset
;
int
phstart
,
phend
;
int
pwstart
,
pwend
;
if
(
adaptive
)
{
phstart
=
h_offset
*
output_height
/
input_height
;
phend
=
...
...
@@ -1396,20 +1494,15 @@ __global__ void KernelMaxPool2DWithIdxGrad(
      pwend =
          min((w_offset + padding_width) / stride_width + 1, output_width);
    }
    T1 gradient = 0;
    T1 input_grad_data = 0;
    int input_current_featuremap_idx = h_offset * input_width + w_offset;
    int output_idx =
        (batch_idx * channels + offsetC) * output_height * output_width;
    mask_data += output_idx;
    output_grad += output_idx;
    for (int ph = phstart; ph < phend; ++ph) {
      for (int pw = pwstart; pw < pwend; ++pw) {
        if (mask_data[ph * output_width + pw] == input_current_featuremap_idx)
          gradient += output_grad[ph * output_width + pw];
          input_grad_data += output_grad[ph * output_width + pw];
      }
    }
    input_grad[index] = gradient;
    input_grad[index] = input_grad_data;
  }
}
...
...
@@ -1453,11 +1546,14 @@ class MaxPool2dWithIndexFunctor<platform::CUDADeviceContext, T1, T2> {
    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);
    auto pool_divmods =
        FastDivModForPooling(input_channels, output_width, output_height);
    KernelMaxPool2dWithIdx<T1, T2><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, input_channels, input_height, input_width,
        output_height, output_width, ksize_height, ksize_width, stride_height,
        stride_width, padding_height, padding_width, adaptive, output_data,
        mask_data);
        mask_data, pool_divmods);
  }
};
...
...
@@ -1497,11 +1593,13 @@ class MaxPool2dWithIndexGradFunctor<platform::CUDADeviceContext, T1, T2> {
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);
    auto pool_divmods =
        FastDivModForPooling(input_channels, input_width, input_height);
    KernelMaxPool2DWithIdxGrad<T1, T2><<<grid, threads, 0, context.stream()>>>(
        nthreads, output_grad_data, mask_data, input_channels, input_height,
        input_width, output_height, output_width, ksize_height, ksize_width,
        stride_height, stride_width, padding_height, padding_width, adaptive,
        input_grad_data);
        input_grad_data, pool_divmods);
  }
};
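For the masked (with-index) backward pass, each input element gathers gradient from every output window whose stored arg-max index points back at it. A CPU sketch of that rule for a single element and a single channel plane (toy shapes, illustrative only; the real kernel restricts the loop to the [phstart, phend) x [pwstart, pwend) windows that can actually cover the element):

    #include <vector>

    // Accumulate d(output)/d(input) for the input element at (h_offset, w_offset).
    float MaxPoolIdxGradAt(int h_offset, int w_offset, int input_width,
                           const std::vector<int>& mask,
                           const std::vector<float>& output_grad,
                           int output_height, int output_width) {
      int feature_idx = h_offset * input_width + w_offset;  // index stored by forward
      float grad = 0.f;
      for (int ph = 0; ph < output_height; ++ph) {
        for (int pw = 0; pw < output_width; ++pw) {
          if (mask[ph * output_width + pw] == feature_idx) {
            grad += output_grad[ph * output_width + pw];  // this window chose us
          }
        }
      }
      return grad;
    }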
...
...
@@ -1590,7 +1688,8 @@ __global__ void KernelMaxPool3DWithIdxGrad(
    int w_offset = index % input_width;
    int h_offset = (index / input_width) % input_height;
    int d_offset = (index / input_width / input_height) % input_depth;
    int offsetC = (index / input_width / input_height / input_depth) % channels;
    int c_offset =
        (index / input_width / input_height / input_depth) % channels;
    int batch_idx =
        index / input_width / input_height / input_depth / channels;
    int pdstart, pdend;
...
...
@@ -1625,10 +1724,10 @@ __global__ void KernelMaxPool3DWithIdxGrad(
      pwend =
          min((w_offset + padding_width) / stride_width + 1, output_width);
    }
    T1 gradient = 0;
    T1 input_grad_data = 0;
    int input_current_feature_map_idx =
        (d_offset * input_height + h_offset) * input_width + w_offset;
    int output_idx = (batch_idx * channels + offsetC) * output_depth *
    int output_idx = (batch_idx * channels + c_offset) * output_depth *
                     output_height * output_width;
    mask += output_idx;
    output_grad += output_idx;
...
...
@@ -1638,12 +1737,12 @@ __global__ void KernelMaxPool3DWithIdxGrad(
        for (int pw = pwstart; pw < pwend; ++pw) {
          if (mask[(pd * output_height + ph) * output_width + pw] ==
              input_current_feature_map_idx)
            gradient +=
            input_grad_data +=
                output_grad[(pd * output_height + ph) * output_width + pw];
        }
      }
    }
    input_grad[index] = gradient;
    input_grad[index] = input_grad_data;
  }
}
...
...
paddle/fluid/operators/math/pooling.h
...
...
@@ -68,8 +68,9 @@ class AvgPool {
template <class T>
class MaxPoolGrad {
 public:
  DEVICE inline void compute(const T& x, const T& y, const T& dy, T scale,
                             T* dx) {
  static constexpr bool use_x = true;
  HOSTDEVICE inline void compute(const T& x, const T& y, const T& dy, T scale,
                                 T* dx) {
    *dx += dy * static_cast<T>(x == y);
  }
};
...
...
@@ -77,8 +78,9 @@ class MaxPoolGrad {
template <class T>
class AvgPoolGrad {
 public:
  DEVICE inline void compute(const T& x, const T& y, const T& dy, T scale,
                             T* dx) {
  static constexpr bool use_x = false;
  HOSTDEVICE inline void compute(const T& x, const T& y, const T& dy, T scale,
                                 T* dx) {
    *dx += (scale * dy);
  }
};
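The new `use_x` flag tells the backward kernels whether the functor actually reads the forward input/output values: the max-pool gradient needs the `x == y` test, while the average-pool gradient only needs `dy` and the `1 / pool_size` scale, so the kernel can skip loading `output_data` entirely. Host-side toy versions of the two functors, just to spell that out (illustrative only, not the Paddle classes):

    #include <cassert>

    struct MaxPoolGradRef {
      static constexpr bool use_x = true;
      void compute(float x, float y, float dy, float scale, float* dx) const {
        *dx += dy * static_cast<float>(x == y);  // gradient flows only to the max
      }
    };

    struct AvgPoolGradRef {
      static constexpr bool use_x = false;
      void compute(float /*x*/, float /*y*/, float dy, float scale,
                   float* dx) const {
        *dx += scale * dy;  // x and y are never touched
      }
    };

    int main() {
      float dx = 0.f;
      MaxPoolGradRef().compute(2.f, 2.f, 1.f, 0.25f, &dx);  // x == y: passes dy
      assert(dx == 1.f);
      dx = 0.f;
      AvgPoolGradRef().compute(2.f, 3.f, 1.f, 0.25f, &dx);  // only scale * dy
      assert(dx == 0.25f);
      return 0;
    }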
...
...
paddle/fluid/operators/optimizers/lars_momentum_op.cc
...
...
@@ -13,46 +13,158 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/optimizers/lars_momentum_op.h"
#include "paddle/fluid/operators/optimizers/momentum_op.h"
namespace
paddle
{
namespace
operators
{
class
LarsMomentumOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInputs
(
"Param"
),
"Input"
,
"Param"
,
"LarsMomentum"
);
OP_INOUT_CHECK
(
ctx
->
HasInputs
(
"Grad"
),
"Input"
,
"Grad"
,
"LarsMomentum"
);
OP_INOUT_CHECK
(
ctx
->
HasInputs
(
"Velocity"
),
"Input"
,
"Velocity"
,
"LarsMomentum"
);
OP_INOUT_CHECK
(
ctx
->
HasInputs
(
"LearningRate"
),
"Input"
,
"LearningRate"
,
"LarsMomentum"
);
OP_INOUT_CHECK
(
ctx
->
HasOutputs
(
"ParamOut"
),
"Output"
,
"ParamOut"
,
"LarsMomentum"
);
OP_INOUT_CHECK
(
ctx
->
HasOutputs
(
"VelocityOut"
),
"Output"
,
"VelocityOut"
,
"LarsMomentum"
);
PADDLE_ENFORCE_EQ
(
ctx
->
GetInputsVarType
(
"Param"
).
front
(),
framework
::
proto
::
VarType
::
LOD_TENSOR
,
platform
::
errors
::
InvalidArgument
(
"The input var's type should be LoDTensor, but the received is %s"
,
ctx
->
GetInputsVarType
(
"Param"
).
front
()));
auto
lr_dims
=
ctx
->
GetInputsDim
(
"LearningRate"
);
auto
grad_dim
=
ctx
->
GetInputsDim
(
"Grad"
);
auto
param_dim
=
ctx
->
GetInputsDim
(
"Param"
);
auto
velocity_dim
=
ctx
->
GetInputsDim
(
"Velocity"
);
auto
lars_weight_decays
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
float
>>
(
"lars_weight_decay"
);
auto
multi_precision
=
ctx
->
Attrs
().
Get
<
bool
>
(
"multi_precision"
);
PADDLE_ENFORCE_EQ
(
param_dim
.
size
(),
grad_dim
.
size
(),
platform
::
errors
::
InvalidArgument
(
"Input(Param) and Input(Grad) of LarsMomentumOp should have "
"same quantity. But number of Param is [%d] and Grad is [%d]."
,
param_dim
.
size
(),
grad_dim
.
size
()));
PADDLE_ENFORCE_EQ
(
param_dim
.
size
(),
velocity_dim
.
size
(),
platform
::
errors
::
InvalidArgument
(
"Input(Param) and Input(Velocity) of LarsMomentumOp should "
"have same quantity. But number of Param is [%d] and Velocity "
"is [%d]."
,
param_dim
.
size
(),
velocity_dim
.
size
()));
PADDLE_ENFORCE_EQ
(
lars_weight_decays
.
size
(),
grad_dim
.
size
(),
platform
::
errors
::
InvalidArgument
(
"Attr(Lars_weight_decay) and "
"Input(Grad) of LarsMomentumOp should have same quantity. "
"But number of Lars_weight_decay is [%d] and Grad is [%d]."
,
lars_weight_decays
.
size
(),
grad_dim
.
size
()));
if
(
multi_precision
)
{
OP_INOUT_CHECK
(
ctx
->
HasInputs
(
"MasterParam"
),
"Input"
,
"MasterParam"
,
"LarsMomentumMultiPrecision"
);
OP_INOUT_CHECK
(
ctx
->
HasOutputs
(
"MasterParamOut"
),
"Output"
,
"MasterParamOut"
,
"LarsMomentumMultiPrecision"
);
}
for
(
size_t
i
=
0
;
i
<
lr_dims
.
size
();
++
i
)
{
PADDLE_ENFORCE_EQ
(
framework
::
product
(
lr_dims
[
i
]),
1
,
platform
::
errors
::
InvalidArgument
(
"Learning_rate should be a scalar. But Received "
"LearningRate's dim [%s]"
,
framework
::
product
(
lr_dims
[
i
])));
}
for
(
size_t
i
=
0
;
i
<
param_dim
.
size
();
++
i
)
{
PADDLE_ENFORCE_EQ
(
ctx
->
GetInputsVarType
(
"Grad"
)[
i
],
framework
::
proto
::
VarType
::
LOD_TENSOR
,
platform
::
errors
::
InvalidArgument
(
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s"
,
ctx
->
Inputs
(
"Grad"
)[
i
].
front
(),
ctx
->
GetInputsVarType
(
"Grad"
)[
i
]));
PADDLE_ENFORCE_EQ
(
param_dim
[
i
],
grad_dim
[
i
],
platform
::
errors
::
InvalidArgument
(
"Input(Param) and Input(Grad) input of LarsMomentumOp shall "
"have same dimension. But Param`s dim is [%s] and Grad's dim "
"is [%s]."
,
param_dim
[
i
],
grad_dim
[
i
]));
PADDLE_ENFORCE_EQ
(
param_dim
[
i
],
velocity_dim
[
i
],
platform
::
errors
::
InvalidArgument
(
"Input(Param) and Input(Velocity) of LarsMomentumOp shall have "
"same dimension. But Param dim [%s] differs with Velocity dim "
"[%s]."
,
param_dim
[
i
],
velocity_dim
[
i
]));
}
ctx
->
SetOutputsDim
(
"ParamOut"
,
param_dim
);
ctx
->
SetOutputsDim
(
"VelocityOut"
,
param_dim
);
if
(
ctx
->
HasOutputs
(
"MasterParamOut"
))
{
ctx
->
SetOutputsDim
(
"MasterParamOut"
,
param_dim
);
}
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
input_data_type
=
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"Param"
);
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
());
}
};
class
LarsMomentumOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"Param"
,
"(LoDTensor, default LoDTensor<float>) "
"Input parameter that has to be updated"
);
"Input parameter that has to be updated"
)
.
AsDuplicable
();
AddInput
(
"Grad"
,
"(LoDTensor, default LoDTensor<float>) "
"Input gradient of the parameter"
);
"Input gradient of the parameter"
)
.
AsDuplicable
();
AddInput
(
"Velocity"
,
"(LoDTensor, default LoDTensor<float>) "
"Input velocity (corresponding to the parameter) "
"that has to be updated"
);
"that has to be updated"
)
.
AsDuplicable
();
AddInput
(
"LearningRate"
,
"(LoDTensor, default LoDTensor<float>) "
"Input learning rate"
);
AddInput
(
"MasterParam"
,
"FP32 master weight for AMP."
).
AsDispensable
();
"Input learning rate"
)
.
AsDuplicable
();
AddInput
(
"MasterParam"
,
"FP32 master weight for AMP."
)
.
AsDuplicable
()
.
AsDispensable
();
AddOutput
(
"ParamOut"
,
"(LoDTensor) This output is updated parameter. "
"It shared memory with Input(Param)."
);
"It shared memory with Input(Param)."
)
.
AsDuplicable
();
AddOutput
(
"VelocityOut"
,
"(LoDTensor) This output is updated velocity. "
"It shared memory with Input(Velocity)."
);
"It shared memory with Input(Velocity)."
)
.
AsDuplicable
();
AddOutput
(
"MasterParamOut"
,
"The updated FP32 master weight for AMP. "
"It shared memory with Input(MasterParam)."
)
.
AsDuplicable
()
.
AsDispensable
();
AddAttr
<
float
>
(
"mu"
,
"(float) Momentum coefficient"
);
AddAttr
<
float
>
(
"lars_coeff"
,
"(float, default 0.001) LARS coefficient."
)
.
SetDefault
(
0.001
);
AddAttr
<
float
>
(
"lars_weight_decay"
,
"(float, default 0.0005) LARS weight decay"
)
.
SetDefault
(
0.0005
);
AddAttr
<
std
::
vector
<
float
>>
(
"lars_weight_decay"
,
"(std::vector<float>, default 0.0005) LARS weight decay params"
)
.
SetDefault
({
0.0005
});
AddAttr
<
float
>
(
"epsilon"
,
"(float, default 0.0) epsilon to avoid Division by Zero."
)
.
SetDefault
(
0.0
);
...
...
@@ -68,10 +180,8 @@ class LarsMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
    AddComment(R"DOC(
Lars Momentum Optimizer.
This optimizer uses LARS (https://arxiv.org/abs/1708.03888) to optimize each
weight with a local learning rate:
$$
local\_lr = \eta  *
    \frac{\left \| param \right \|}{\left \| grad \right \| + \beta *\left \| param \right \|} \\
...
...
@@ -79,10 +189,8 @@ velocity = mu * velocity +
           local\_lr * (grad + \beta * param) \\
param = param - velocity. \\
$$
Note that lars_weight_decay is used here to decay the weights, so you may not
need an extra L2 regularizer when using LARS.
)DOC");
  }
};
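A reference (non-vectorized) form of the update described in the DOC block, for a single parameter tensor. Function and variable names are illustrative, not Paddle API; the guard on the local learning rate mirrors the epsilon-protected ratio above:

    #include <cmath>
    #include <vector>

    void LarsStep(std::vector<float>* param, std::vector<float>* velocity,
                  const std::vector<float>& grad, float lr, float mu,
                  float lars_coeff, float lars_weight_decay, float epsilon) {
      float p_norm = 0.f, g_norm = 0.f;
      for (size_t i = 0; i < param->size(); ++i) {
        p_norm += (*param)[i] * (*param)[i];
        g_norm += grad[i] * grad[i];
      }
      p_norm = std::sqrt(p_norm);
      g_norm = std::sqrt(g_norm);
      // local_lr = eta * ||param|| / (||grad|| + beta * ||param|| + epsilon)
      float local_lr = lr;
      if (lars_weight_decay > 0.f && p_norm > 0.f && g_norm > 0.f) {
        local_lr = lr * lars_coeff * p_norm /
                   (g_norm + lars_weight_decay * p_norm + epsilon);
      }
      for (size_t i = 0; i < param->size(); ++i) {
        (*velocity)[i] = mu * (*velocity)[i] +
                         local_lr * (grad[i] + lars_weight_decay * (*param)[i]);
        (*param)[i] -= (*velocity)[i];
      }
    }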
...
...
@@ -96,7 +204,7 @@ class LarsMomentumOpVarTypeInference : public framework::VarTypeInference {
namespace ops = paddle::operators;
REGISTER_OPERATOR(
    lars_momentum, ops::MomentumOp, ops::LarsMomentumOpMaker,
    lars_momentum, ops::LarsMomentumOp, ops::LarsMomentumOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
    ops::LarsMomentumOpVarTypeInference);
...
...
paddle/fluid/operators/optimizers/lars_momentum_op.cu
...
...
@@ -14,7 +14,21 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/math/math_cuda_utils.h"
#include "paddle/fluid/operators/optimizers/lars_momentum_op.h"
#include "paddle/fluid/platform/fast_divmod.h"
#if CUDA_VERSION >= 11000
#include <cooperative_groups.h>
#endif

#ifdef __HIPCC__
#define LARS_BLOCK_SIZE 256
#else
#define LARS_BLOCK_SIZE 512
#endif

#define LARS_MAX_MERGED_OPS 60

namespace paddle {
namespace operators {
...
...
@@ -22,124 +36,472 @@ namespace operators {
template
<
typename
T
>
using
MultiPrecisionType
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
template
<
typename
T
,
typename
MT
>
__global__
void
MomentumLarsKernel
(
const
T
*
p
,
const
T
*
g
,
const
MT
*
v
,
const
MultiPrecisionType
<
T
>*
learning_rate
,
const
MT
mu
,
const
int64_t
num
,
const
MT
lars_coeff
,
const
MT
lars_weight_decay
,
const
MultiPrecisionType
<
T
>*
p_norm
,
const
MultiPrecisionType
<
T
>*
g_norm
,
T
*
p_out
,
MT
*
v_out
,
const
MT
epsilon
,
const
MT
*
master_p
,
MT
*
master_p_out
,
const
MultiPrecisionType
<
T
>
rescale_grad
)
{
const
MT
lr
=
static_cast
<
MT
>
(
learning_rate
[
0
]);
MT
local_lr
=
lr
;
const
MT
p_n
=
static_cast
<
MT
>
(
p_norm
[
0
]);
const
MT
g_n
=
static_cast
<
MT
>
(
g_norm
[
0
]);
__device__
__forceinline__
float
Sqrt
(
float
x
)
{
return
sqrtf
(
x
);
}
__device__
__forceinline__
double
Sqrt
(
double
x
)
{
return
sqrt
(
x
);
}
__device__
__forceinline__
float
Fma
(
float
x
,
float
y
,
float
z
)
{
return
fmaf
(
x
,
y
,
z
);
}
__device__
__forceinline__
double
Fma
(
double
x
,
double
y
,
double
z
)
{
return
fma
(
x
,
y
,
z
);
}
template
<
typename
T
>
class
LarsThreadConfig
{
public:
int
grid_for_norm
;
int
grid_for_lars
;
#if CUDA_VERSION >= 11000
if
(
lars_weight_decay
>
static_cast
<
MT
>
(
0
)
&&
p_n
>
static_cast
<
MT
>
(
0
)
&&
g_n
>
static_cast
<
MT
>
(
0
))
{
local_lr
=
lr
*
lars_coeff
*
p_n
/
(
g_n
+
lars_weight_decay
*
p_n
+
epsilon
);
private:
int
grid_stride
;
public:
explicit
LarsThreadConfig
(
int64_t
numel
,
int
sm_num
,
int
num_blocks_per_sm
)
{
int
grid
=
(
numel
+
LARS_BLOCK_SIZE
-
1
)
/
LARS_BLOCK_SIZE
;
grid_for_lars
=
std
::
min
(
std
::
min
(
sm_num
*
num_blocks_per_sm
,
grid
),
LARS_BLOCK_SIZE
);
grid_stride
=
LARS_BLOCK_SIZE
*
grid_for_lars
;
}
CUDA_KERNEL_LOOP
(
i
,
num
)
{
MT
grad
=
static_cast
<
MT
>
(
g
[
i
])
*
static_cast
<
MT
>
(
rescale_grad
);
MT
param
=
master_p
?
master_p
[
i
]
:
static_cast
<
MT
>
(
p
[
i
]);
MT
v_new
=
v
[
i
]
*
mu
+
local_lr
*
(
grad
+
lars_weight_decay
*
param
);
MT
p_new
=
param
-
v_new
;
int
GetRepeatTimes
(
int64_t
numel
)
{
return
(
numel
+
grid_stride
-
1
)
/
grid_stride
-
1
;
}
#else
int
repeat_times
;
explicit
LarsThreadConfig
(
const
int64_t
numel
)
{
int
grid
=
(
numel
+
LARS_BLOCK_SIZE
-
1
)
/
LARS_BLOCK_SIZE
;
grid_for_norm
=
std
::
min
(
grid
,
LARS_BLOCK_SIZE
);
const
int
grid_stride
=
grid_for_norm
*
LARS_BLOCK_SIZE
;
repeat_times
=
(
numel
+
grid_stride
-
1
)
/
grid_stride
-
1
;
// Determine to read 4 fp16 or float data once, but 2 double data once.
grid_for_lars
=
std
::
is_same
<
double
,
T
>::
value
?
(
numel
+
(
LARS_BLOCK_SIZE
<<
1
)
-
1
)
/
(
LARS_BLOCK_SIZE
<<
1
)
:
(
numel
+
(
LARS_BLOCK_SIZE
<<
2
)
-
1
)
/
(
LARS_BLOCK_SIZE
<<
2
);
}
#endif
};
template
<
typename
T
,
typename
MT
,
int
VecSize
,
bool
IsAmp
=
false
>
__device__
inline
void
VectorizeLarsUpdate
(
const
T
*
__restrict__
grad
,
const
MT
*
param
,
const
MT
*
velocity
,
T
*
param_out
,
MT
*
velocity_out
,
const
MT
mu
,
MT
local_lr
,
const
MT
lars_weight_decay
,
const
MT
rescale_grad
,
const
int
tid
,
const
int
grid_stride
,
const
int
numel
,
MT
*
master_param_out
=
nullptr
)
{
using
VecType
=
paddle
::
platform
::
AlignedVector
<
T
,
VecSize
>
;
using
VecMType
=
paddle
::
platform
::
AlignedVector
<
MT
,
VecSize
>
;
int
main
=
numel
>>
(
VecSize
>>
1
);
int
tail_offset
=
main
*
VecSize
;
v_out
[
i
]
=
v_new
;
p_out
[
i
]
=
static_cast
<
T
>
(
p_new
);
if
(
master_p_out
)
master_p_out
[
i
]
=
p_new
;
const
VecType
*
grad_vec
=
reinterpret_cast
<
const
VecType
*>
(
grad
);
const
VecMType
*
param_vec
=
reinterpret_cast
<
const
VecMType
*>
(
param
);
const
VecMType
*
velocity_vec
=
reinterpret_cast
<
const
VecMType
*>
(
velocity
);
VecType
*
param_out_vec
=
reinterpret_cast
<
VecType
*>
(
param_out
);
VecMType
*
velocity_out_vec
=
reinterpret_cast
<
VecMType
*>
(
velocity_out
);
VecMType
*
master_param_out_vec
;
if
(
IsAmp
)
{
master_param_out_vec
=
reinterpret_cast
<
VecMType
*>
(
master_param_out
);
}
for
(
int
i
=
tid
;
i
<
main
;
i
+=
grid_stride
)
{
VecType
param_out_tmp
;
VecMType
velocity_tmp
,
param_tmp
;
VecType
grad_data
=
grad_vec
[
i
];
VecMType
param_data
=
param_vec
[
i
];
VecMType
velocity_data
=
velocity_vec
[
i
];
#pragma unroll
for
(
int
j
=
0
;
j
<
VecSize
;
++
j
)
{
MT
grad_val
=
static_cast
<
MT
>
(
grad_data
[
j
])
*
rescale_grad
;
velocity_tmp
[
j
]
=
Fma
(
velocity_data
[
j
],
mu
,
local_lr
*
Fma
(
lars_weight_decay
,
param_data
[
j
],
grad_val
));
param_tmp
[
j
]
=
param_data
[
j
]
-
velocity_tmp
[
j
];
param_out_tmp
[
j
]
=
static_cast
<
T
>
(
param_tmp
[
j
]);
}
param_out_vec
[
i
]
=
param_out_tmp
;
velocity_out_vec
[
i
]
=
velocity_tmp
;
if
(
IsAmp
)
{
master_param_out_vec
[
i
]
=
param_tmp
;
}
}
for
(
int
i
=
tid
+
tail_offset
;
i
<
numel
;
i
+=
grid_stride
)
{
MT
grad_val
=
static_cast
<
MT
>
(
grad
[
i
])
*
rescale_grad
;
MT
param_val
=
param
[
i
];
MT
velocity_tmp
=
Fma
(
velocity
[
i
],
mu
,
local_lr
*
Fma
(
lars_weight_decay
,
param_val
,
grad_val
));
MT
param_tmp
=
param_val
-
velocity_tmp
;
param_out
[
i
]
=
static_cast
<
T
>
(
param_tmp
);
velocity_out
[
i
]
=
velocity_tmp
;
if
(
IsAmp
)
{
master_param_out
[
i
]
=
param_tmp
;
}
}
}
template
<
typename
DeviceContext
,
typename
T
>
class
LarsMomentumOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
using
MPDType
=
MultiPrecisionType
<
T
>
;
#if CUDA_VERSION >= 11000
/* Once CUDA_VERSION is beyond 11, cooperative_groups can be involved in without
--rdc=true compile flag, then L2_norm kernel can be set with __device__ and
cooperative_groups::grid_group also can be involved. Otherwise, adding this
flag may affect much, L2_norm kernel shall be set with __global__.*/
// TODO(limingshu): declaration of cooperative_groups wapper is invalid in host.
template
<
typename
T
,
typename
MT
>
__forceinline__
__device__
void
L2NormKernel
(
const
cooperative_groups
::
grid_group
*
cg
,
#else
template
<
typename
T
,
typename
MT
>
__global__
void
L2NormKernel
(
#endif
const
T
*
p_data
,
const
T
*
__restrict__
g_data
,
MT
*
__restrict__
p_buffer
,
MT
*
__restrict__
g_buffer
,
const
int64_t
numel
,
const
int
repeat_times
,
const
MT
rescale_grad
,
const
int
thresh
=
0
,
MT
*
__restrict__
p_n
=
nullptr
,
MT
*
__restrict__
g_n
=
nullptr
)
{
__shared__
MT
s_buffer
[
2
];
int
tid
=
threadIdx
.
x
+
blockDim
.
x
*
blockIdx
.
x
;
int
grid_stride
=
LARS_BLOCK_SIZE
*
gridDim
.
x
;
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
bool
multi_precision
=
ctx
.
Attr
<
bool
>
(
"multi_precision"
);
if
(
multi_precision
)
{
InnerCompute
<
MPDType
>
(
ctx
,
multi_precision
);
MT
p_tmp
=
static_cast
<
MT
>
(
0
);
MT
g_tmp
=
static_cast
<
MT
>
(
0
);
while
(
tid
<
numel
)
{
MT
tmp0
=
static_cast
<
MT
>
(
p_data
[
tid
]);
MT
tmp1
=
static_cast
<
MT
>
(
g_data
[
tid
]);
p_tmp
+=
(
tmp0
*
tmp0
);
g_tmp
+=
(
tmp1
*
tmp1
);
tid
+=
grid_stride
;
}
p_tmp
=
math
::
blockReduceSum
<
MT
>
(
p_tmp
,
FINAL_MASK
);
g_tmp
=
math
::
blockReduceSum
<
MT
>
(
g_tmp
,
FINAL_MASK
);
if
(
threadIdx
.
x
==
0
)
{
p_buffer
[
blockIdx
.
x
]
=
p_tmp
;
g_buffer
[
blockIdx
.
x
]
=
g_tmp
;
}
#if CUDA_VERSION >= 11000
cg
->
sync
();
// Grid sync for writring partial result to gloabl memory
MT
p_part_sum
=
threadIdx
.
x
<
gridDim
.
x
?
p_buffer
[
threadIdx
.
x
]
:
0
;
MT
g_part_sum
=
threadIdx
.
x
<
gridDim
.
x
?
g_buffer
[
threadIdx
.
x
]
:
0
;
MT
tmp0
=
math
::
blockReduceSum
<
MT
>
(
p_part_sum
,
FINAL_MASK
);
MT
tmp1
=
math
::
blockReduceSum
<
MT
>
(
g_part_sum
,
FINAL_MASK
);
if
(
threadIdx
.
x
==
0
)
{
s_buffer
[
0
]
=
tmp0
;
s_buffer
[
1
]
=
tmp1
;
}
__syncthreads
();
*
p_n
=
Sqrt
(
s_buffer
[
0
]);
*
g_n
=
rescale_grad
*
Sqrt
(
s_buffer
[
1
]);
#endif
}
template
<
typename
T
,
typename
MT
>
__forceinline__
__device__
void
MomentumUpdate
(
const
T
*
param
,
const
T
*
__restrict__
grad
,
const
MT
*
velocity
,
T
*
param_out
,
MT
*
velocity_out
,
const
MT
*
master_param
,
MT
*
master_param_out
,
const
MT
*
__restrict__
learning_rate
,
const
MT
mu
,
const
MT
lars_weight_decay
,
const
MT
lars_coeff
,
const
MT
epsilon
,
const
MT
rescale_grad
,
const
MT
param_norm
,
const
MT
grad_norm
,
const
int
tid
,
const
int
grid_stride
,
const
int64_t
numel
,
const
bool
is_amp
)
{
const
MT
lr
=
learning_rate
[
0
];
MT
local_lr
=
lr
;
if
(
lars_weight_decay
>
static_cast
<
MT
>
(
0
))
{
local_lr
=
lr
*
lars_coeff
*
param_norm
/
(
fma
(
lars_weight_decay
,
param_norm
,
grad_norm
)
+
epsilon
);
}
if
(
is_amp
)
{
VectorizeLarsUpdate
<
T
,
MT
,
/*VecSize=*/
4
,
/*IsAmp=*/
true
>
(
grad
,
master_param
,
velocity
,
param_out
,
velocity_out
,
mu
,
local_lr
,
lars_weight_decay
,
rescale_grad
,
tid
,
grid_stride
,
numel
,
master_param_out
);
}
else
{
if
(
std
::
is_same
<
T
,
float
>::
value
||
std
::
is_same
<
T
,
paddle
::
platform
::
float16
>::
value
)
{
/* TODO(limingshu): pointer cast may damage memory accessing for fp16 */
VectorizeLarsUpdate
<
T
,
MT
,
/*VecSize=*/
4
,
/*IsAmp=*/
false
>
(
grad
,
reinterpret_cast
<
const
MT
*>
(
param
),
velocity
,
param_out
,
velocity_out
,
mu
,
local_lr
,
lars_weight_decay
,
rescale_grad
,
tid
,
grid_stride
,
numel
);
}
else
{
InnerCompute
<
T
>
(
ctx
,
multi_precision
);
VectorizeLarsUpdate
<
T
,
MT
,
/*VecSize=*/
2
,
/*IsAmp=*/
false
>
(
grad
,
reinterpret_cast
<
const
MT
*>
(
param
),
velocity
,
param_out
,
velocity_out
,
mu
,
local_lr
,
lars_weight_decay
,
rescale_grad
,
tid
,
grid_stride
,
numel
);
}
}
}
private:
template
<
typename
MT
>
void
InnerCompute
(
const
framework
::
ExecutionContext
&
ctx
,
const
bool
multi_precision
)
const
{
auto
param_out
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"ParamOut"
);
auto
velocity_out
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"VelocityOut"
);
auto
param
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Param"
);
auto
velocity
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Velocity"
);
auto
grad
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Grad"
);
auto
learning_rate
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"LearningRate"
);
const
framework
::
Tensor
*
master_param
=
nullptr
;
framework
::
Tensor
*
master_param_out
=
nullptr
;
if
(
multi_precision
)
{
bool
has_master
=
ctx
.
HasInput
(
"MasterParam"
)
&&
ctx
.
HasOutput
(
"MasterParamOut"
);
PADDLE_ENFORCE_EQ
(
has_master
,
true
,
platform
::
errors
::
InvalidArgument
(
"The Input(MasterParam) and Output(MasterParamOut) "
"should not be null when "
"the attr `multi_precision` is true"
));
master_param
=
ctx
.
Input
<
framework
::
Tensor
>
(
"MasterParam"
);
master_param_out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"MasterParamOut"
);
}
#if CUDA_VERSION >= 11000
template
<
typename
T
,
typename
MT
>
struct
LarsParamWarpper
{
int64_t
numel_arr
[
LARS_MAX_MERGED_OPS
];
int
repeat_arr
[
LARS_MAX_MERGED_OPS
];
const
T
*
__restrict__
g_arr
[
LARS_MAX_MERGED_OPS
];
const
MT
*
__restrict__
lr_arr
[
LARS_MAX_MERGED_OPS
];
T
*
__restrict__
p_out_arr
[
LARS_MAX_MERGED_OPS
];
MT
*
__restrict__
v_out_arr
[
LARS_MAX_MERGED_OPS
];
MT
*
__restrict__
master_p_out_arr
[
LARS_MAX_MERGED_OPS
];
MT
weight_decay_arr
[
LARS_MAX_MERGED_OPS
];
};
const
MT
*
master_p
=
multi_precision
?
master_param
->
data
<
MT
>
()
:
nullptr
;
MT
*
master_p_out
=
multi_precision
?
master_param_out
->
mutable_data
<
MT
>
(
ctx
.
GetPlace
())
:
nullptr
;
template
<
typename
T
,
typename
MT
>
__global__
void
MergedMomentumLarsKernel
(
LarsParamWarpper
<
T
,
MT
>
lars_warpper
,
MT
*
__restrict__
p_buffer
,
MT
*
__restrict__
g_buffer
,
const
int
op_num
,
const
MT
mu
,
const
MT
lars_coeff
,
const
MT
epsilon
,
const
MT
rescale_grad
,
const
bool
is_amp
)
{
int
grid_stride
=
gridDim
.
x
*
LARS_BLOCK_SIZE
;
int
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
const
cooperative_groups
::
grid_group
cg
=
cooperative_groups
::
this_grid
();
for
(
int
i
=
0
;
i
<
op_num
;
++
i
)
{
int
numel
=
lars_warpper
.
numel_arr
[
i
];
MT
param_norm
=
static_cast
<
MT
>
(
0
);
MT
grad_norm
=
static_cast
<
MT
>
(
0
);
L2NormKernel
<
T
,
MT
>
(
&
cg
,
lars_warpper
.
p_out_arr
[
i
],
lars_warpper
.
g_arr
[
i
],
p_buffer
,
g_buffer
,
numel
,
lars_warpper
.
repeat_arr
[
i
],
rescale_grad
,
0
,
&
param_norm
,
&
grad_norm
);
MomentumUpdate
<
T
,
MT
>
(
lars_warpper
.
p_out_arr
[
i
],
lars_warpper
.
g_arr
[
i
],
lars_warpper
.
v_out_arr
[
i
],
lars_warpper
.
p_out_arr
[
i
],
lars_warpper
.
v_out_arr
[
i
],
lars_warpper
.
master_p_out_arr
[
i
],
lars_warpper
.
master_p_out_arr
[
i
],
lars_warpper
.
lr_arr
[
i
],
mu
,
lars_warpper
.
weight_decay_arr
[
i
],
lars_coeff
,
epsilon
,
rescale_grad
,
param_norm
,
grad_norm
,
tid
,
grid_stride
,
numel
,
is_amp
);
}
}
#endif
T
*
p_out
=
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
MT
*
v_out
=
velocity_out
->
mutable_data
<
MT
>
(
ctx
.
GetPlace
());
template
<
typename
T
,
typename
MT
>
__global__
void
MomentumLarsKernel
(
const
T
*
param
,
const
T
*
__restrict__
grad
,
const
MT
*
velocity
,
T
*
param_out
,
MT
*
velocity_out
,
const
MT
*
master_param
,
MT
*
master_param_out
,
const
MT
*
__restrict__
learning_rate
,
MT
*
__restrict__
p_buffer
,
MT
*
__restrict__
g_buffer
,
const
MT
mu
,
const
MT
lars_coeff
,
const
MT
lars_weight_decay
,
const
MT
epsilon
,
const
MT
rescale_grad
,
const
int
repeat_times
,
const
int
thresh
,
const
int64_t
numel
,
const
bool
is_amp
)
{
int
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
grid_stride
=
gridDim
.
x
*
LARS_BLOCK_SIZE
;
#if CUDA_VERSION >= 11000
const
cooperative_groups
::
grid_group
cg
=
cooperative_groups
::
this_grid
();
MT
param_norm
=
static_cast
<
MT
>
(
0
);
MT
grad_norm
=
static_cast
<
MT
>
(
0
);
L2NormKernel
<
T
,
MT
>
(
&
cg
,
param
,
grad
,
p_buffer
,
g_buffer
,
numel
,
repeat_times
,
rescale_grad
,
gridDim
.
x
,
&
param_norm
,
&
grad_norm
);
#else
const
MT
rescale_grad_pow
=
rescale_grad
*
rescale_grad
;
MT
param_part_norm
=
threadIdx
.
x
<
thresh
?
p_buffer
[
threadIdx
.
x
]
:
0
;
MT
grad_part_norm
=
threadIdx
.
x
<
thresh
?
g_buffer
[
threadIdx
.
x
]
:
0
;
__syncthreads
();
MT
param_norm
=
Sqrt
(
math
::
blockReduceSum
<
MT
>
(
param_part_norm
,
FINAL_MASK
));
MT
grad_norm
=
Sqrt
(
rescale_grad_pow
*
math
::
blockReduceSum
<
MT
>
(
grad_part_norm
,
FINAL_MASK
));
#endif
MomentumUpdate
<
T
,
MT
>
(
param
,
grad
,
velocity
,
param_out
,
velocity_out
,
master_param
,
master_param_out
,
learning_rate
,
mu
,
lars_weight_decay
,
lars_coeff
,
epsilon
,
rescale_grad
,
param_norm
,
grad_norm
,
tid
,
grid_stride
,
numel
,
is_amp
);
}
template
<
typename
T
,
typename
MT
>
inline
void
SeparatedLarsMomentumOpCUDAKernel
(
const
platform
::
CUDADeviceContext
&
cuda_ctx
,
const
T
*
param_data
,
T
*
param_out_data
,
const
MT
*
velocity_data
,
MT
*
velocity_out_data
,
const
T
*
grad_data
,
const
MT
*
lr
,
MT
*
p_buffer
,
MT
*
g_buffer
,
const
MT
mu
,
const
MT
lars_coeff
,
const
MT
weight_decay
,
const
MT
epsilon
,
const
MT
rescale_grad
,
const
int64_t
numel
,
const
MT
*
master_param_data
,
MT
*
master_out_data
,
const
bool
is_amp
)
{
LarsThreadConfig
<
T
>
lars_thread_config
(
numel
);
L2NormKernel
<
T
,
MT
><<<
lars_thread_config
.
grid_for_norm
,
LARS_BLOCK_SIZE
,
0
,
cuda_ctx
.
stream
()
>>>
(
param_data
,
grad_data
,
p_buffer
,
g_buffer
,
numel
,
lars_thread_config
.
repeat_times
,
rescale_grad
);
MomentumLarsKernel
<
T
,
MT
><<<
lars_thread_config
.
grid_for_lars
,
LARS_BLOCK_SIZE
,
0
,
cuda_ctx
.
stream
()
>>>
(
param_data
,
grad_data
,
velocity_data
,
param_out_data
,
velocity_out_data
,
master_param_data
,
master_out_data
,
lr
,
p_buffer
,
g_buffer
,
mu
,
lars_coeff
,
weight_decay
,
epsilon
,
rescale_grad
,
0
,
lars_thread_config
.
grid_for_norm
,
numel
,
is_amp
);
}
template
<
typename
DeviceContext
,
typename
T
>
class
LarsMomentumOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
using
MT
=
MultiPrecisionType
<
T
>
;
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
int
num_blocks_per_sm
=
0
;
bool
multi_precision
=
ctx
.
Attr
<
bool
>
(
"multi_precision"
);
auto
&
cuda_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
int
sm_num
=
cuda_ctx
.
GetSMCount
();
framework
::
Tensor
tmp_buffer_t
=
ctx
.
AllocateTmpTensor
<
MT
,
platform
::
CUDADeviceContext
>
(
{
LARS_BLOCK_SIZE
<<
1
},
cuda_ctx
);
auto
*
p_buffer
=
tmp_buffer_t
.
mutable_data
<
MT
>
(
ctx
.
GetPlace
());
auto
*
g_buffer
=
p_buffer
+
LARS_BLOCK_SIZE
;
MT
mu
=
static_cast
<
MT
>
(
ctx
.
Attr
<
float
>
(
"mu"
));
MT
lars_coeff
=
static_cast
<
MT
>
(
ctx
.
Attr
<
float
>
(
"lars_coeff"
));
MT
lars_weight_decay
=
static_cast
<
MT
>
(
ctx
.
Attr
<
float
>
(
"lars_weight_decay"
));
MT
epsilon
=
static_cast
<
MT
>
(
ctx
.
Attr
<
float
>
(
"epsilon"
));
MPDType
rescale_grad
=
static_cast
<
MPDType
>
(
ctx
.
Attr
<
float
>
(
"rescale_grad"
));
auto
*
p
=
param
->
data
<
T
>
();
auto
*
g
=
grad
->
data
<
T
>
();
auto
*
v
=
velocity
->
data
<
MT
>
();
auto
*
lr
=
learning_rate
->
data
<
MPDType
>
();
int
block
=
512
;
int
grid
=
(
param
->
numel
()
+
block
-
1
)
/
block
;
auto
eigen_p
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
param
);
auto
eigen_g
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
grad
);
// calculate norms using eigein and launch the kernel.
framework
::
Tensor
p_norm_t
,
g_norm_t
;
p_norm_t
.
Resize
({
1
});
g_norm_t
.
Resize
({
1
});
auto
*
p_norm_data
=
p_norm_t
.
mutable_data
<
MPDType
>
(
ctx
.
GetPlace
());
auto
*
g_norm_data
=
g_norm_t
.
mutable_data
<
MPDType
>
(
ctx
.
GetPlace
());
auto
ep_norm
=
framework
::
EigenScalar
<
MPDType
>::
From
(
p_norm_t
);
auto
eg_norm
=
framework
::
EigenScalar
<
MPDType
>::
From
(
g_norm_t
);
auto
*
place
=
ctx
.
template
device_context
<
DeviceContext
>().
eigen_device
();
// eigen unsupport fp16 l2-norm
ep_norm
.
device
(
*
place
)
=
eigen_p
.
template
cast
<
MPDType
>().
square
().
sum
().
sqrt
();
eg_norm
.
device
(
*
place
)
=
(
eigen_g
.
template
cast
<
MPDType
>()
*
rescale_grad
).
square
().
sum
().
sqrt
();
MomentumLarsKernel
<
T
,
MT
><<<
grid
,
block
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
p
,
g
,
v
,
lr
,
mu
,
param
->
numel
(),
lars_coeff
,
lars_weight_decay
,
p_norm_data
,
g_norm_data
,
p_out
,
v_out
,
epsilon
,
master_p
,
master_p_out
,
rescale_grad
);
MT
rescale_grad
=
static_cast
<
MT
>
(
ctx
.
Attr
<
float
>
(
"rescale_grad"
));
auto
weight_decay_arr
=
ctx
.
Attr
<
std
::
vector
<
float
>>
(
"lars_weight_decay"
);
auto
grad
=
ctx
.
MultiInput
<
framework
::
LoDTensor
>
(
"Grad"
);
auto
param
=
ctx
.
MultiInput
<
framework
::
LoDTensor
>
(
"Param"
);
auto
velocity
=
ctx
.
MultiInput
<
framework
::
LoDTensor
>
(
"Velocity"
);
auto
param_out
=
ctx
.
MultiOutput
<
framework
::
LoDTensor
>
(
"ParamOut"
);
auto
velocity_out
=
ctx
.
MultiOutput
<
framework
::
LoDTensor
>
(
"VelocityOut"
);
auto
learning_rate
=
ctx
.
MultiInput
<
framework
::
LoDTensor
>
(
"LearningRate"
);
auto
master_param
=
ctx
.
MultiInput
<
framework
::
LoDTensor
>
(
"MasterParam"
);
auto
master_param_out
=
ctx
.
MultiOutput
<
framework
::
LoDTensor
>
(
"MasterParamOut"
);
int
op_num
=
grad
.
size
();
#if CUDA_VERSION >= 11000
if
(
op_num
>
1
)
{
LarsParamWarpper
<
T
,
MT
>
lars_warpper
;
PADDLE_ENFORCE_LT
(
op_num
,
LARS_MAX_MERGED_OPS
,
platform
::
errors
::
InvalidArgument
(
"The maximum number of merged-ops supported is (%d), but"
"lars op required for trainning this model is (%d)
\n
"
,
LARS_MAX_MERGED_OPS
,
op_num
));
/* Implementation of lars optimizer consists of following two steps:
1. Figure out the L2 norm statistic result of grad data and param data.
2. Update param and velocity with usage of L2 norm statistic result.
Step1 and step2 can be merged with api provided by nvida
cudaLaunchCooperativeKernel:
- The thread quantity shall less than pyhsical SM limited threads
- Launche as thread-block can synchronizlly execute. */
cudaOccupancyMaxActiveBlocksPerMultiprocessor
(
&
num_blocks_per_sm
,
MergedMomentumLarsKernel
<
T
,
MT
>
,
LARS_BLOCK_SIZE
,
sizeof
(
MT
)
<<
1
);
size_t
total_numel
=
0
;
for
(
int
i
=
0
;
i
<
op_num
;
++
i
)
{
size_t
temp_numel
=
param
[
i
]
->
numel
();
total_numel
+=
temp_numel
;
lars_warpper
.
numel_arr
[
i
]
=
temp_numel
;
lars_warpper
.
g_arr
[
i
]
=
grad
[
i
]
->
data
<
T
>
();
lars_warpper
.
lr_arr
[
i
]
=
learning_rate
[
i
]
->
data
<
MT
>
();
lars_warpper
.
p_out_arr
[
i
]
=
param_out
[
i
]
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
lars_warpper
.
v_out_arr
          [i] = velocity_out[i]->mutable_data<MT>(ctx.GetPlace());
      lars_warpper.weight_decay_arr[i] = static_cast<MT>(weight_decay_arr[i]);
      PADDLE_ENFORCE_EQ(
          param[i]->data<T>(), lars_warpper.p_out_arr[i],
          platform::errors::InvalidArgument(
              "Input(Param) and Output(ParamOut) must be the same Tensors."));
      PADDLE_ENFORCE_EQ(velocity[i]->data<MT>(), lars_warpper.v_out_arr[i],
                        platform::errors::InvalidArgument(
                            "Input(Velocity) and Output(VelocityOut) must be "
                            "the same Tensors."));
    }
    int64_t avg_numel = total_numel / op_num;
    LarsThreadConfig<float> lars_thread_config(avg_numel, sm_num,
                                               num_blocks_per_sm);
    for (int i = 0; i < op_num; ++i) {
      lars_warpper.repeat_arr[i] =
          lars_thread_config.GetRepeatTimes(lars_warpper.numel_arr[i]);
    }
    if (multi_precision) {
      for (int i = 0; i < op_num; ++i) {
        lars_warpper.master_p_out_arr[i] =
            master_param_out[i]->mutable_data<MT>(ctx.GetPlace());
        PADDLE_ENFORCE_EQ(master_param[i]->data<MT>(),
                          lars_warpper.master_p_out_arr[i],
                          platform::errors::InvalidArgument(
                              "Input(MasterParam) and Output(MasterParamOut) "
                              "must be the same Tensors."));
      }
    }
    void* cuda_param[] = {reinterpret_cast<void*>(&lars_warpper),
                          reinterpret_cast<void*>(&p_buffer),
                          reinterpret_cast<void*>(&g_buffer),
                          reinterpret_cast<void*>(&op_num),
                          reinterpret_cast<void*>(&mu),
                          reinterpret_cast<void*>(&lars_coeff),
                          reinterpret_cast<void*>(&epsilon),
                          reinterpret_cast<void*>(&rescale_grad),
                          reinterpret_cast<void*>(&multi_precision)};
    // Launch threads on all SMs; the threads of each block cooperate
    // synchronously.
    cudaLaunchCooperativeKernel(
        reinterpret_cast<void*>(MergedMomentumLarsKernel<T, MT>),
        lars_thread_config.grid_for_lars, LARS_BLOCK_SIZE, cuda_param, 0,
        cuda_ctx.stream());
  } else {
    auto* param_data = param[0]->data<T>();
    auto* grad_data = grad[0]->data<T>();
    auto* velocity_data = velocity[0]->data<MT>();
    auto* lr = learning_rate[0]->data<MT>();
    auto* param_out_data = param_out[0]->mutable_data<T>(ctx.GetPlace());
    auto* velocity_out_data = velocity_out[0]->mutable_data<MT>(ctx.GetPlace());
    const MT* master_param_data =
        multi_precision ? master_param[0]->data<MT>() : nullptr;
    MT* master_param_out_data =
        multi_precision ? master_param_out[0]->mutable_data<MT>(ctx.GetPlace())
                        : nullptr;
    int64_t numel = param[0]->numel();
    MT lars_weight_decay = weight_decay_arr[0];

    // Figure out how many blocks can be active on each SM.
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm,
                                                  MomentumLarsKernel<T, MT>,
                                                  LARS_BLOCK_SIZE,
                                                  sizeof(MT) << 1);
    LarsThreadConfig<float> lars_thread_config(numel, sm_num,
                                               num_blocks_per_sm);
    int repeat_times = lars_thread_config.GetRepeatTimes(numel);
    int thresh = 0;
    void* cuda_param[] = {
        reinterpret_cast<void*>(&param_data),
        reinterpret_cast<void*>(&grad_data),
        reinterpret_cast<void*>(&velocity_data),
        reinterpret_cast<void*>(&param_out_data),
        reinterpret_cast<void*>(&velocity_out_data),
        reinterpret_cast<void*>(&master_param_data),
        reinterpret_cast<void*>(&master_param_out_data),
        reinterpret_cast<void*>(&lr),
        reinterpret_cast<void*>(&p_buffer),
        reinterpret_cast<void*>(&g_buffer),
        reinterpret_cast<void*>(&mu),
        reinterpret_cast<void*>(&lars_coeff),
        reinterpret_cast<void*>(&lars_weight_decay),
        reinterpret_cast<void*>(&epsilon),
        reinterpret_cast<void*>(&rescale_grad),
        reinterpret_cast<void*>(&repeat_times),
        reinterpret_cast<void*>(&thresh),  // Just a placeholder
        reinterpret_cast<void*>(&numel),
        reinterpret_cast<void*>(&multi_precision)};
    // Launch threads on all SMs.
    cudaLaunchCooperativeKernel(
        reinterpret_cast<void*>(MomentumLarsKernel<T, MT>),
        lars_thread_config.grid_for_lars, LARS_BLOCK_SIZE, cuda_param, 0,
        cuda_ctx.stream());
  }
#else
  for (int i = 0; i < op_num; ++i) {
    const MT* master_param_data =
        multi_precision ? master_param[i]->data<MT>() : nullptr;
    MT* master_param_out_data =
        multi_precision ? master_param_out[i]->mutable_data<MT>(ctx.GetPlace())
                        : nullptr;
    SeparatedLarsMomentumOpCUDAKernel<T, MT>(
        cuda_ctx, param[i]->data<T>(),
        param_out[i]->mutable_data<T>(ctx.GetPlace()), velocity[i]->data<MT>(),
        velocity_out[i]->mutable_data<MT>(ctx.GetPlace()), grad[i]->data<T>(),
        learning_rate[i]->data<MT>(), p_buffer, g_buffer, mu, lars_coeff,
        weight_decay_arr[i], epsilon, rescale_grad, param[i]->numel(),
        master_param_data, master_param_out_data, multi_precision);
  }
#endif
  }
};
...
...
paddle/fluid/operators/optimizers/lars_momentum_op.h
100755 → 100644
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
...
...
@@ -23,54 +23,48 @@ template <typename T>
 class LarsMomentumOpKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto param_out = ctx.Output<framework::LoDTensor>("ParamOut");
-    auto velocity_out = ctx.Output<framework::LoDTensor>("VelocityOut");
-    auto param = ctx.Input<framework::LoDTensor>("Param");
-    auto velocity = ctx.Input<framework::LoDTensor>("Velocity");
-    auto learning_rate = ctx.Input<framework::LoDTensor>("LearningRate");
-    auto* grad_var = ctx.InputVar("Grad");
-    // only support dense for now.
-    PADDLE_ENFORCE_EQ(grad_var->IsType<framework::LoDTensor>(), true,
-                      platform::errors::InvalidArgument(
-                          "The Var(%s)'s type should be LoDTensor, "
-                          "but the received is %s",
-                          ctx.InputNames("Grad").front(),
-                          framework::ToTypeName(grad_var->Type())));
-    auto grad = ctx.Input<framework::LoDTensor>("Grad");
-    param_out->mutable_data<T>(ctx.GetPlace());
-    velocity_out->mutable_data<T>(ctx.GetPlace());
+    auto param_out = ctx.MultiOutput<framework::LoDTensor>("ParamOut");
+    auto velocity_out = ctx.MultiOutput<framework::LoDTensor>("VelocityOut");
+    auto param = ctx.MultiInput<framework::LoDTensor>("Param");
+    auto velocity = ctx.MultiInput<framework::LoDTensor>("Velocity");
+    auto learning_rate = ctx.MultiInput<framework::LoDTensor>("LearningRate");
+    auto grad = ctx.MultiInput<framework::LoDTensor>("Grad");
+    auto weight_decay_arr = ctx.Attr<std::vector<float>>("lars_weight_decay");

     T mu = static_cast<T>(ctx.Attr<float>("mu"));
     T lars_coeff = ctx.Attr<float>("lars_coeff");
-    T lars_weight_decay = ctx.Attr<float>("lars_weight_decay");
     T epsilon = ctx.Attr<float>("epsilon");

-    auto p_out = framework::EigenVector<T>::Flatten(*param_out);
-    auto v_out = framework::EigenVector<T>::Flatten(*velocity_out);
-    auto p = framework::EigenVector<T>::Flatten(*param);
-    auto v = framework::EigenVector<T>::Flatten(*velocity);
-    auto g = framework::EigenVector<T>::Flatten(*grad);
-    auto* lr = learning_rate->data<T>();
-    framework::Tensor p_norm_t, g_norm_t;
-    p_norm_t.Resize({1});
-    g_norm_t.Resize({1});
-    p_norm_t.mutable_data<T>(ctx.GetPlace());
-    g_norm_t.mutable_data<T>(ctx.GetPlace());
-    auto ep_norm = framework::EigenScalar<T>::From(p_norm_t);
-    auto eg_norm = framework::EigenScalar<T>::From(g_norm_t);
-    ep_norm = p.square().sum().sqrt();
-    eg_norm = g.square().sum().sqrt();
-    T local_lr = lr[0];
-    if (lars_weight_decay > 0 && ep_norm(0) > 0 && eg_norm(0) > 0) {
-      local_lr = lr[0] * lars_coeff * ep_norm(0) /
-                 (eg_norm(0) + lars_weight_decay * ep_norm(0) + epsilon);
-    }
-    v_out = v * mu + local_lr * (g + lars_weight_decay * p);
-    p_out = p - v_out;
+    int op_num = param.size();
+    for (int i = 0; i < op_num; ++i) {
+      auto* lr = learning_rate[i]->data<T>();
+      T lars_weight_decay = weight_decay_arr[i];
+      param_out[i]->mutable_data<T>(ctx.GetPlace());
+      velocity_out[i]->mutable_data<T>(ctx.GetPlace());
+      auto p_out = framework::EigenVector<T>::Flatten(*(param_out[i]));
+      auto v_out = framework::EigenVector<T>::Flatten(*(velocity_out[i]));
+      auto p = framework::EigenVector<T>::Flatten(*(param[i]));
+      auto v = framework::EigenVector<T>::Flatten(*(velocity[i]));
+      auto g = framework::EigenVector<T>::Flatten(*(grad[i]));
+      framework::Tensor p_norm_t, g_norm_t;
+      p_norm_t.Resize({1});
+      g_norm_t.Resize({1});
+      p_norm_t.mutable_data<T>(ctx.GetPlace());
+      g_norm_t.mutable_data<T>(ctx.GetPlace());
+      auto ep_norm = framework::EigenScalar<T>::From(p_norm_t);
+      auto eg_norm = framework::EigenScalar<T>::From(g_norm_t);
+      ep_norm = p.square().sum().sqrt();
+      eg_norm = g.square().sum().sqrt();
+      T local_lr = lr[0];
+      if (lars_weight_decay > 0 && ep_norm(0) > 0 && eg_norm(0) > 0) {
+        local_lr = lr[0] * lars_coeff * ep_norm(0) /
+                   (eg_norm(0) + lars_weight_decay * ep_norm(0) + epsilon);
+      }
+      v_out = v * mu + local_lr * (g + lars_weight_decay * p);
+      p_out = p - v_out;
+    }
   }
 };
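For readers checking the math, the per-parameter LARS update in the loop above can be mirrored in a few lines of NumPy. This is only an illustrative sketch of the formula; the function name and sample values are made up and not part of the operator:

import numpy as np

def lars_update(p, g, v, lr, mu, lars_coeff, lars_weight_decay, epsilon):
    # Trust-ratio ("local") learning rate, as computed in the kernel above.
    p_norm = np.sqrt(np.sum(p * p))
    g_norm = np.sqrt(np.sum(g * g))
    local_lr = lr
    if lars_weight_decay > 0 and p_norm > 0 and g_norm > 0:
        local_lr = lr * lars_coeff * p_norm / (
            g_norm + lars_weight_decay * p_norm + epsilon)
    v_out = v * mu + local_lr * (g + lars_weight_decay * p)
    p_out = p - v_out
    return p_out, v_out

p = np.random.rand(8).astype('float32')
g = np.random.rand(8).astype('float32')
v = np.zeros(8, dtype='float32')
p_out, v_out = lars_update(p, g, v, lr=0.01, mu=0.9,
                           lars_coeff=0.001, lars_weight_decay=0.0005,
                           epsilon=0.0)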
...
...
paddle/fluid/operators/optimizers/merged_momentum_op.cc
0 → 100644
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/optimizers/merged_momentum_op.h"
namespace paddle {
namespace operators {

class MergedMomentumOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext *ctx) const override {}

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto param_dtype =
        framework::OperatorWithKernel::IndicateVarDataType(ctx, "Param");
    return framework::OpKernelType(param_dtype, ctx.GetPlace());
  }
};

class MergedMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Param",
             "(Tensor, default Tensor<float>) "
             "Input parameter that has to be updated")
        .AsDuplicable();
    AddInput("Grad",
             "(Tensor, default Tensor<float>) "
             "Input gradient of the parameter")
        .AsDuplicable();
    AddInput("Velocity",
             "(Tensor, default Tensor<float>) "
             "Input velocity (corresponding to the parameter) "
             "that has to be updated")
        .AsDuplicable();
    AddInput("LearningRate",
             "(Tensor, default Tensor<float>) "
             "Input learning rate");
    AddInput("MasterParam", "FP32 master weight for AMP.")
        .AsDispensable()
        .AsDuplicable();
    AddOutput("ParamOut",
              "(Tensor) This output is the updated parameter. "
              "It shares memory with Input(Param).")
        .AsDuplicable();
    AddOutput("VelocityOut",
              "(Tensor) This output is the updated velocity. "
              "It shares memory with Input(Velocity).")
        .AsDuplicable();
    AddOutput("MasterParamOut",
              "The updated FP32 master weight for AMP. "
              "It shares memory with Input(MasterParam).")
        .AsDispensable()
        .AsDuplicable();
    AddAttr<float>("mu", "(float) Momentum coefficient");
    AddAttr<bool>("multi_precision",
                  "(bool, default false) "
                  "Whether to use multi-precision during weight updating.")
        .SetDefault(false);
    AddAttr<float>(
        "rescale_grad",
        "(float, default 1.0) Multiply the gradient with `rescale_grad` "
        "before updating. Often choose to be `1.0/batch_size`.")
        .SetDefault(1.0f);
    AddComment(R"DOC(Merged Momentum Optimizer.)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_WITHOUT_GRADIENT(merged_momentum, ops::MergedMomentumOp,
                             ops::MergedMomentumOpMaker);

REGISTER_OP_CPU_KERNEL(
    merged_momentum,
    ops::MergedMomentumOpKernel<plat::CPUDeviceContext, float>,
    ops::MergedMomentumOpKernel<plat::CPUDeviceContext, double>);
paddle/fluid/operators/optimizers/merged_momentum_op.cu
0 → 100644
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/optimizers/merged_momentum_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(
    merged_momentum,
    ops::MergedMomentumOpKernel<plat::CUDADeviceContext, plat::float16>,
    ops::MergedMomentumOpKernel<plat::CUDADeviceContext, float>,
    ops::MergedMomentumOpKernel<plat::CUDADeviceContext, double>);
paddle/fluid/operators/optimizers/merged_momentum_op.h
0 → 100644
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace operators {

template <typename MT, uint32_t kParamNum, bool kHasMasterParams>
struct MergedMomentumMasterParams {
  MT *PADDLE_RESTRICT master_params[kParamNum];

  HOSTDEVICE MT *MasterParam(size_t idx) const { return master_params[idx]; }
  HOSTDEVICE void SetMasterParam(size_t idx, MT *p) { master_params[idx] = p; }
};

template <typename MT, uint32_t kParamNum>
struct MergedMomentumMasterParams<MT, kParamNum, false> {
  HOSTDEVICE constexpr MT *MasterParam(size_t) const { return nullptr; }
  HOSTDEVICE constexpr void SetMasterParam(size_t, MT *) {}
};

template <typename T, typename MT, bool kHasMasterParams,
          uint32_t kParamNum = kHasMasterParams ? 55 : 110>
struct MergedMomentumKernelParam
    : public MergedMomentumMasterParams<MT, kParamNum, kHasMasterParams> {
  static constexpr auto N = kParamNum;
  size_t sizes[N];
  T *PADDLE_RESTRICT params[N];
  const T *PADDLE_RESTRICT grads[N];
  MT *PADDLE_RESTRICT velocitys[N];
  const MT *PADDLE_RESTRICT lr;
  MT mu;
  MT rescale_grad;
  uint32_t param_num;

  HOSTDEVICE void operator()(size_t i) const {
    const auto lr_val = *lr;
    for (uint32_t idx = 0; idx < param_num; ++idx) {
      auto size = sizes[idx];
      if (i >= size) continue;

      auto param_p = params[idx];
      auto grad_p = grads[idx];
      auto velocity_p = velocitys[idx];
      auto master_param_p = this->MasterParam(idx);

      const MT param =
          master_param_p ? master_param_p[i] : static_cast<MT>(param_p[i]);
      const MT grad = static_cast<MT>(grad_p[i]) * rescale_grad;
      const MT velocity = velocity_p[i];
      const MT velocity_out = velocity * mu + grad;
      const MT param_out = param - lr_val * velocity_out;
      velocity_p[i] = velocity_out;
      param_p[i] = static_cast<T>(param_out);
      if (master_param_p) {
        master_param_p[i] = param_out;
      }
    }
  }
};

template <typename DeviceContext, typename T>
class MergedMomentumOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto params = ctx.MultiInput<framework::Tensor>("Param");
    auto params_out = ctx.MultiOutput<framework::Tensor>("ParamOut");
    size_t n = params.size();
    PADDLE_ENFORCE_EQ(
        n, params_out.size(),
        platform::errors::InvalidArgument(
            "Output(ParamOut) number must be equal to Input(Param) number."));
    for (size_t i = 0; i < n; ++i) {
      PADDLE_ENFORCE_EQ(
          params[i], params_out[i],
          platform::errors::InvalidArgument(
              "Input(Param) and Output(ParamOut) must be the same Tensors."));
    }

    auto grads = ctx.MultiInput<framework::Tensor>("Grad");
    PADDLE_ENFORCE_EQ(
        n, grads.size(),
        platform::errors::InvalidArgument(
            "Input(Grad) number must be equal to Input(Param) number."));

    auto velocitys = ctx.MultiInput<framework::Tensor>("Velocity");
    PADDLE_ENFORCE_EQ(n, velocitys.size(),
                      platform::errors::InvalidArgument(
                          "Input(Velocity) number must be equal to "
                          "Input(Param) number."));

    auto velocitys_out = ctx.MultiOutput<framework::Tensor>("VelocityOut");
    PADDLE_ENFORCE_EQ(
        n, velocitys_out.size(),
        platform::errors::InvalidArgument("Output(VelocityOut) number must be "
                                          "equal to Input(Param) number."));
    for (size_t i = 0; i < n; ++i) {
      PADDLE_ENFORCE_EQ(velocitys[i], velocitys_out[i],
                        platform::errors::InvalidArgument(
                            "Input(Velocity) and Output(VelocityOut) must be "
                            "the same Tensors."));
    }

    auto master_params = ctx.MultiInput<framework::Tensor>("MasterParam");
    auto master_params_out =
        ctx.MultiOutput<framework::Tensor>("MasterParamOut");
    auto multi_precision = ctx.Attr<bool>("multi_precision");
    if (multi_precision) {
      PADDLE_ENFORCE_EQ(
          n, master_params.size(),
          platform::errors::InvalidArgument("Input(MasterParam) number must be "
                                            "equal to Input(Param) number."));
      PADDLE_ENFORCE_EQ(n, master_params_out.size(),
                        platform::errors::InvalidArgument(
                            "Output(MasterParamOut) number must be equal to "
                            "Input(MasterParam) number."));
      for (size_t i = 0; i < n; ++i) {
        PADDLE_ENFORCE_EQ(master_params[i], master_params_out[i],
                          platform::errors::InvalidArgument(
                              "Input(MasterParam) and Output(MasterParamOut) "
                              "must be the same Tensors."));
        PADDLE_ENFORCE_NOT_NULL(master_params[i],
                                platform::errors::InvalidArgument(
                                    "Input(MasterParam) must be provided when "
                                    "multi_precision=True."));
      }
    } else {
      master_params.clear();
      master_params_out.clear();
    }

    auto lr = ctx.Input<framework::Tensor>("LearningRate");
    auto mu = ctx.Attr<float>("mu");
    auto rescale_grad = ctx.Attr<float>("rescale_grad");
    using MPType = typename operators::details::MPTypeTrait<T>::Type;

    auto &dev_ctx = ctx.template device_context<DeviceContext>();

#define PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(kMultiPrecision)                \
  MergedMomentumKernelParam<T, MPType, kMultiPrecision> kernel_params;       \
  constexpr auto kMaxMergedNum = decltype(kernel_params)::N;                 \
  size_t kernel_num = (n + kMaxMergedNum - 1) / kMaxMergedNum;               \
  kernel_params.mu = static_cast<MPType>(mu);                                \
  kernel_params.rescale_grad = static_cast<MPType>(rescale_grad);            \
  kernel_params.lr = lr->data<MPType>();                                     \
  for (size_t i = 0; i < kernel_num; ++i) {                                  \
    size_t start = i * kMaxMergedNum;                                        \
    size_t end = std::min((i + 1) * kMaxMergedNum, n);                       \
    kernel_params.param_num = static_cast<uint32_t>(end - start);            \
    size_t max_size = 0;                                                     \
    for (size_t j = 0; j < kernel_params.param_num; ++j) {                   \
      auto size = static_cast<size_t>(params_out[j + start]->numel());       \
      max_size = std::max(max_size, size);                                   \
      kernel_params.sizes[j] = size;                                         \
      kernel_params.params[j] = params_out[j + start]->data<T>();            \
      kernel_params.grads[j] = grads[j + start]->data<T>();                  \
      kernel_params.velocitys[j] = velocitys_out[j + start]->data<MPType>(); \
      kernel_params.SetMasterParam(                                          \
          j, kMultiPrecision ? master_params_out[j + start]->data<MPType>()  \
                             : nullptr);                                     \
    }                                                                        \
    platform::ForRange<DeviceContext> for_range(dev_ctx, max_size);          \
    for_range(kernel_params);                                                \
    VLOG(10) << "Launch MergedMomentum kernel " << i << " "                  \
             << kernel_params.param_num;                                     \
  }

    if (multi_precision) {
      PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(true);
    } else {
      PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(false);
    }

#undef PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL
  }
};

}  // namespace operators
}  // namespace paddle
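The merged kernel applies the same plain momentum rule to every parameter slot of a chunk, switching between FP16 parameters with FP32 master weights and a single-precision path. A small NumPy sketch of that rule, list-wise over several parameters (illustrative only; the function and variable names are hypothetical):

import numpy as np

def merged_momentum_step(params, grads, velocitys, lr, mu, rescale_grad=1.0):
    # Same update as MergedMomentumKernelParam::operator(), applied per tensor.
    for p, g, v in zip(params, grads, velocitys):
        g = g * rescale_grad
        v_out = v * mu + g
        p_out = p - lr * v_out
        v[...] = v_out
        p[...] = p_out

params = [np.random.rand(4).astype('float32') for _ in range(3)]
grads = [np.random.rand(4).astype('float32') for _ in range(3)]
velocitys = [np.zeros(4, dtype='float32') for _ in range(3)]
merged_momentum_step(params, grads, velocitys, lr=0.1, mu=0.9)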
paddle/fluid/operators/optimizers/momentum_op.h
...
...
@@ -173,14 +173,15 @@ class CPUDenseMomentumFunctor {
   }
 };

-template <typename T, typename MT, typename UpdateMethod>
+template <typename T, typename MT, RegularizationType kRegType,
+          typename UpdateMethod>
 class DenseMomentumFunctor;

 // NOTE(dzh) for performance.
 // avoid if/else inside the kernel, implement GPU UseNesterov/NoNesterov as
 // two functors.
-template <typename T, typename MT>
-class DenseMomentumFunctor<T, MT, UseNesterov> {
+template <typename T, typename MT, RegularizationType kRegType>
+class DenseMomentumFunctor<T, MT, kRegType, UseNesterov> {
  private:
   const T* param_;
   const T* grad_;
...
...
@@ -193,7 +194,6 @@ class DenseMomentumFunctor<T, MT, UseNesterov> {
   T* param_out_;
   MT* velocity_out_;
   MT* master_param_out_;
-  const RegularizationType regularization_flag_;
   const MT regularization_coeff_;

  public:
...
...
@@ -201,7 +201,6 @@ class DenseMomentumFunctor<T, MT, UseNesterov> {
                       const MultiPrecisionType<MT>* learning_rate,
                       const MT* master_param, const MT mu,
                       const MT rescale_grad, const int64_t num,
-                      const RegularizationType regularization_flag,
                       const MT regularization_coeff, T* param_out,
                       MT* velocity_out, MT* master_param_out)
       : param_(param),
...
...
@@ -215,7 +214,6 @@ class DenseMomentumFunctor<T, MT, UseNesterov> {
         param_out_(param_out),
         velocity_out_(velocity_out),
         master_param_out_(master_param_out),
-        regularization_flag_(regularization_flag),
         regularization_coeff_(regularization_coeff) {}
   inline HOSTDEVICE void operator()(size_t i) const {
     // put memory access in register
...
...
@@ -225,9 +223,9 @@ class DenseMomentumFunctor<T, MT, UseNesterov> {
     const MT lr = static_cast<MT>(lr_[0]);
     const MT velocity = velocity_[i];
-    grad = regularization_flag_ == RegularizationType::kL2DECAY
-               ? grad + regularization_coeff_ * param
-               : grad;
+    if (kRegType == RegularizationType::kL2DECAY) {
+      grad += regularization_coeff_ * param;
+    }
     MT velocity_out = velocity * mu_ + grad;
     MT param_out = param - (grad + velocity_out * mu_) * lr;
...
...
@@ -240,8 +238,8 @@ class DenseMomentumFunctor<T, MT, UseNesterov> {
   }
 };

-template <typename T, typename MT>
-class DenseMomentumFunctor<T, MT, NoNesterov> {
+template <typename T, typename MT, RegularizationType kRegType>
+class DenseMomentumFunctor<T, MT, kRegType, NoNesterov> {
  private:
   const T* param_;
   const T* grad_;
...
...
@@ -254,7 +252,6 @@ class DenseMomentumFunctor<T, MT, NoNesterov> {
   T* param_out_;
   MT* velocity_out_;
   MT* master_param_out_;
-  const RegularizationType regularization_flag_;
   const MT regularization_coeff_;

  public:
...
...
@@ -262,7 +259,6 @@ class DenseMomentumFunctor<T, MT, NoNesterov> {
                       const MultiPrecisionType<MT>* learning_rate,
                       const MT* master_param, const MT mu,
                       const MT rescale_grad, const int64_t num,
-                      const RegularizationType regularization_flag,
                       const MT regularization_coeff, T* param_out,
                       MT* velocity_out, MT* master_param_out)
       : param_(param),
...
...
@@ -276,7 +272,6 @@ class DenseMomentumFunctor<T, MT, NoNesterov> {
         param_out_(param_out),
         velocity_out_(velocity_out),
         master_param_out_(master_param_out),
-        regularization_flag_(regularization_flag),
         regularization_coeff_(regularization_coeff) {}
   inline HOSTDEVICE void operator()(size_t i) const {
     // put memory access in register
...
...
@@ -286,9 +281,9 @@ class DenseMomentumFunctor<T, MT, NoNesterov> {
     const MT lr = static_cast<MT>(lr_[0]);
     const MT velocity = velocity_[i];
-    grad = regularization_flag_ == RegularizationType::kL2DECAY
-               ? grad + regularization_coeff_ * param
-               : grad;
+    if (kRegType == RegularizationType::kL2DECAY) {
+      grad += regularization_coeff_ * param;
+    }
     MT velocity_out = velocity * mu_ + grad;
     MT param_out = param - lr * velocity_out;
...
...
@@ -522,23 +517,31 @@ class MomentumOpKernel : public framework::OpKernel<T> {
       platform::ForRange<DeviceContext> for_range(
           static_cast<const DeviceContext&>(ctx.device_context()),
           param->numel());
-      if (use_nesterov) {
-        DenseMomentumFunctor<T, MT, UseNesterov> functor(
-            param->data<T>(), grad->data<T>(), velocity->data<MT>(),
-            learning_rate->data<MPDType>(), master_in_data, mu, rescale_grad,
-            param->numel(), regularization_flag, regularization_coeff,
-            param_out->mutable_data<T>(ctx.GetPlace()),
-            velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data);
-        for_range(functor);
+#define PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(__nesterov, __reg_type)       \
+  DenseMomentumFunctor<T, MT, __reg_type, __nesterov> functor(            \
+      param->data<T>(), grad->data<T>(), velocity->data<MT>(),            \
+      learning_rate->data<MPDType>(), master_in_data, mu, rescale_grad,   \
+      param->numel(), regularization_coeff,                               \
+      param_out->mutable_data<T>(ctx.GetPlace()),                         \
+      velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data);   \
+  for_range(functor);
+
+      if (use_nesterov) {
+        if (regularization_flag == RegularizationType::kL2DECAY) {
+          PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(UseNesterov,
+                                              RegularizationType::kL2DECAY);
+        } else {
+          PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(UseNesterov,
+                                              RegularizationType::kNONE);
+        }
       } else {
-        DenseMomentumFunctor<T, MT, NoNesterov> functor(
-            param->data<T>(), grad->data<T>(), velocity->data<MT>(),
-            learning_rate->data<MPDType>(), master_in_data, mu, rescale_grad,
-            param->numel(), regularization_flag, regularization_coeff,
-            param_out->mutable_data<T>(ctx.GetPlace()),
-            velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data);
-        for_range(functor);
+        if (regularization_flag == RegularizationType::kL2DECAY) {
+          PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(NoNesterov,
+                                              RegularizationType::kL2DECAY);
+        } else {
+          PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(NoNesterov,
+                                              RegularizationType::kNONE);
+        }
       }
     }
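The kRegType/UseNesterov template split above only hoists branches out of the inner loop; the update itself is unchanged. A compact Python sketch of the two variants with optional L2 decay folded into the gradient (illustrative only, not the kernel itself):

def dense_momentum(p, g, v, lr, mu, use_nesterov, reg_coeff=0.0):
    # Optional L2 decay folded into the gradient, as the kL2DECAY branch does.
    if reg_coeff > 0.0:
        g = g + reg_coeff * p
    v_out = v * mu + g
    if use_nesterov:
        p_out = p - (g + v_out * mu) * lr
    else:
        p_out = p - lr * v_out
    return p_out, v_out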
...
...
paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cc
0 → 100644
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {

class Pow2DecayWithLinearWarmupOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext *ctx) const override {
    auto dim = framework::make_ddim({1});
    ctx->SetOutputDim("LearningRateOut", dim);
    ctx->SetOutputDim("StepOut", dim);
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto data_type =
        OperatorWithKernel::IndicateVarDataType(ctx, "LearningRate");
    return framework::OpKernelType(data_type, ctx.device_context());
  }
};

class Pow2DecayWithLinearWarmupOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("LearningRate", "(Tensor) The input learning rate Tensor.");
    AddInput("Step", "(Tensor) The input global step Tensor.");
    AddOutput("LearningRateOut",
              "(Tensor) The output learning rate Tensor. Same with "
              "Input(LearningRate).");
    AddOutput("StepOut",
              "(Tensor) The output step Tensor. Same with Input(Step).");
    AddAttr<int64_t>("warmup_steps", "(int64_t) The warmup steps.");
    AddAttr<int64_t>(
        "total_steps",
        "(int64_t) The total steps for changing the learning rate.");
    AddAttr<float>("base_lr",
                   "(float) The final learning rate value after warmup.");
    AddAttr<float>("end_lr",
                   "(float) The final learning rate value after total_steps.");
    AddComment(R"DOC(
The Pow2DecayWithLinearWarmup learning rate scheduler.

When step_num < warmup_steps, lr = base_lr * step_num / warmup_steps

When warmup_steps <= step_num <= total_steps,
   factor = 1 - (step_num - warmup_steps) / (total_steps - warmup_steps)
   lr = (base_lr - end_lr) * factor * factor + end_lr

When step_num > total_steps, lr = end_lr
)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_WITHOUT_GRADIENT(pow2_decay_with_linear_warmup,
                             ops::Pow2DecayWithLinearWarmupOp,
                             ops::Pow2DecayWithLinearWarmupOpMaker);
REGISTER_OP_CPU_KERNEL(
    pow2_decay_with_linear_warmup,
    ops::Pow2DecayWithLinearWarmupOpKernel<plat::CPUDeviceContext, double>,
    ops::Pow2DecayWithLinearWarmupOpKernel<plat::CPUDeviceContext, float>);
paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cu
0 → 100644
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(
    pow2_decay_with_linear_warmup,
    ops::Pow2DecayWithLinearWarmupOpKernel<plat::CUDADeviceContext, double>,
    ops::Pow2DecayWithLinearWarmupOpKernel<plat::CUDADeviceContext, float>);
paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h
0 → 100644
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace operators {

template <typename T, typename AttrT>
struct Pow2DecayWithLinearWarmupFunctor {
  template <typename U>
  using RestrictPtr = U *PADDLE_RESTRICT;

 public:
  HOSTDEVICE Pow2DecayWithLinearWarmupFunctor(RestrictPtr<T> lr,
                                              RestrictPtr<int64_t> step,
                                              size_t warmup_steps,
                                              size_t total_steps,
                                              AttrT base_lr, AttrT end_lr)
      : lr_(lr),
        step_(step),
        warmup_steps_(warmup_steps),
        total_steps_(total_steps),
        base_lr_(base_lr),
        end_lr_(end_lr) {}

  HOSTDEVICE void operator()(size_t) const {
    size_t step = static_cast<size_t>(*step_) + 1;
    *step_ = static_cast<int64_t>(step);
    if (step <= warmup_steps_) {
      auto new_lr = static_cast<double>(step) / warmup_steps_ * base_lr_;
      *lr_ = static_cast<T>(new_lr);
    } else if (step < total_steps_) {
      auto factor = 1 -
                    static_cast<double>(step - warmup_steps_) /
                        (total_steps_ - warmup_steps_);
      auto new_lr =
          static_cast<double>(base_lr_ - end_lr_) * (factor * factor) + end_lr_;
      *lr_ = static_cast<T>(new_lr);
    } else {
      *lr_ = static_cast<T>(end_lr_);
    }
  }

 private:
  RestrictPtr<T> lr_;
  RestrictPtr<int64_t> step_;
  size_t warmup_steps_;
  size_t total_steps_;
  AttrT base_lr_;
  AttrT end_lr_;
};

template <typename DeviceContext, typename T>
class Pow2DecayWithLinearWarmupOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const {
    const auto *lr = ctx.Input<framework::Tensor>("LearningRate");
    const auto *step = ctx.Input<framework::Tensor>("Step");
    auto *lr_out = ctx.Output<framework::Tensor>("LearningRateOut");
    auto *step_out = ctx.Output<framework::Tensor>("StepOut");
    PADDLE_ENFORCE_EQ(
        lr, lr_out,
        platform::errors::InvalidArgument("Input(LearningRate) and "
                                          "Output(LearningRateOut) "
                                          "must be the same."));
    PADDLE_ENFORCE_NOT_NULL(lr,
                            platform::errors::InvalidArgument(
                                "Input(LearningRate) should not be nullptr."));
    PADDLE_ENFORCE_EQ(step, step_out,
                      platform::errors::InvalidArgument(
                          "Input(Step) and Output(StepOut) must be the same."));
    PADDLE_ENFORCE_NOT_NULL(step, platform::errors::InvalidArgument(
                                      "Input(Step) should not be nullptr."));
    PADDLE_ENFORCE_EQ(step->IsInitialized(), true,
                      platform::errors::InvalidArgument(
                          "Input(Step) must be initialized."));

    auto warmup_steps = static_cast<size_t>(ctx.Attr<int64_t>("warmup_steps"));
    auto total_steps = static_cast<size_t>(ctx.Attr<int64_t>("total_steps"));
    PADDLE_ENFORCE_LE(warmup_steps, total_steps,
                      platform::errors::InvalidArgument(
                          "warmup_steps must not be larger than total_steps."));
    auto base_lr = ctx.Attr<float>("base_lr");
    auto end_lr = ctx.Attr<float>("end_lr");

    auto *lr_data = lr_out->data<T>();
    auto *step_data = step_out->data<int64_t>();
    auto &dev_ctx = ctx.template device_context<DeviceContext>();
    platform::ForRange<DeviceContext> for_range(dev_ctx, 1);
    using AttrT = double;
    Pow2DecayWithLinearWarmupFunctor<T, AttrT> functor(
        lr_data, step_data, warmup_steps, total_steps,
        static_cast<AttrT>(base_lr), static_cast<AttrT>(end_lr));
    for_range(functor);
  }
};

}  // namespace operators
}  // namespace paddle
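The functor above encodes exactly the schedule documented in the op maker. A plain-Python version of the same piecewise formula, useful for sanity-checking values (illustrative only; the function name is a stand-in):

def pow2_decay_with_linear_warmup_lr(step, warmup_steps, total_steps,
                                     base_lr, end_lr):
    # step is 1-based, matching the functor, which increments before use.
    if step <= warmup_steps:
        return base_lr * step / warmup_steps
    if step < total_steps:
        factor = 1.0 - float(step - warmup_steps) / (total_steps - warmup_steps)
        return (base_lr - end_lr) * factor * factor + end_lr
    return end_lr

# e.g. tabulate the first 100 values of a 10-step warmup schedule
lrs = [pow2_decay_with_linear_warmup_lr(s, 10, 100, 0.1, 0.0)
       for s in range(1, 101)]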
paddle/fluid/platform/CMakeLists.txt
...
...
@@ -59,9 +59,14 @@ cc_library(cpu_info SRCS cpu_info.cc DEPS ${CPU_INFO_DEPS})
cc_test(cpu_info_test SRCS cpu_info_test.cc DEPS cpu_info)

IF(WITH_GPU)
    nv_library(cuda_graph SRCS cuda_graph.cc DEPS enforce allocator_facade)
    nv_library(gpu_info SRCS gpu_info.cc DEPS gflags glog enforce monitor dynload_cuda)
    nv_library(cuda_profiler SRCS cuda_profiler.cc DEPS enforce)
    nv_library(cuda_graph_with_memory_pool SRCS cuda_graph_with_memory_pool.cc DEPS device_context allocator_facade cuda_graph)
ELSE()
    cc_library(cuda_graph_with_memory_pool SRCS cuda_graph_with_memory_pool.cc DEPS device_context allocator_facade)
ENDIF()

IF(WITH_ROCM)
    hip_library(gpu_info SRCS gpu_info.cc DEPS gflags glog enforce monitor dynload_cuda)
ENDIF()
...
...
paddle/fluid/platform/cuda_graph.cc
0 → 100644
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/cuda_graph.h"
namespace paddle {
namespace platform {

std::unique_ptr<CUDAGraph> CUDAGraph::capturing_graph_{nullptr};

void CUDAGraph::Reset() {
  if (is_reset_) return;
#if CUDA_VERSION >= 10010
  if (graph_) {
    PADDLE_ENFORCE_CUDA_SUCCESS(cudaGraphDestroy(graph_));
    graph_ = nullptr;
  }
  if (exec_graph_) {
    PADDLE_ENFORCE_CUDA_SUCCESS(cudaGraphExecDestroy(exec_graph_));
    exec_graph_ = nullptr;
  }
#endif
  // Callbacks should be invoked in reverse order, because a later-added
  // callback may rely on an earlier-added one.
  for (auto iter = callbacks_.rbegin(); iter != callbacks_.rend(); ++iter) {
    (*iter)();
  }
  callbacks_.clear();
  is_reset_ = true;
}

void CUDAGraph::Replay() {
#if CUDA_VERSION >= 10010
  PADDLE_ENFORCE_EQ(is_reset_, false,
                    errors::PermissionDenied(
                        "Cannot replay the CUDA Graph after reset is called."));
  PADDLE_ENFORCE_NOT_NULL(exec_graph_,
                          errors::PermissionDenied(
                              "CUDA Graph must be captured before replaying."));
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaGraphLaunch(exec_graph_, stream_));
#endif
}

void CUDAGraph::BeginCapture(platform::CUDAPlace place, cudaStream_t stream,
                             cudaStreamCaptureMode mode) {
  ThrowErrorIfNotSupportCUDAGraph();
  PADDLE_ENFORCE_EQ(IsCapturing(), false,
                    errors::PermissionDenied(
                        "CUDA Graphs can only be captured one at a time."));
  PADDLE_ENFORCE_NOT_NULL(
      stream, errors::PermissionDenied(
                  "CUDA Graph cannot be captured in default CUDA stream 0."));
  capturing_graph_.reset(new CUDAGraph());
  capturing_graph_->place_ = place;
  capturing_graph_->stream_ = stream;

  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaStreamBeginCapture(capturing_graph_->stream_, mode));
  cudaStreamCaptureStatus status;
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamGetCaptureInfo(
      capturing_graph_->stream_, &status, &(capturing_graph_->id_)));
  PADDLE_ENFORCE_EQ(IsValidCapturing(), true,
                    platform::errors::PermissionDenied(
                        "CUDA Graph should not be invalidated."));
  VLOG(10) << "Begin to capture CUDA Graph with ID " << capturing_graph_->id_;
}

std::unique_ptr<CUDAGraph> CUDAGraph::EndCapture() {
  ThrowErrorIfNotSupportCUDAGraph();
#if CUDA_VERSION >= 10010
  PADDLE_ENFORCE_EQ(IsCapturing(), true,
                    errors::PermissionDenied("No CUDA Graph is capturing."));
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamEndCapture(
      capturing_graph_->stream_, &(capturing_graph_->graph_)));
  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaGraphInstantiate(&(capturing_graph_->exec_graph_),
                           capturing_graph_->graph_, nullptr, nullptr, 0));
  VLOG(10) << "End to capture CUDA Graph with ID " << capturing_graph_->id_;
  return std::move(capturing_graph_);
#endif
}

bool CUDAGraph::IsValidCapturing() {
  if (!IsCapturing()) return false;
  cudaStreamCaptureStatus status;
  CUDAGraphID id;
  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaStreamGetCaptureInfo(capturing_graph_->stream_, &status, &id));
  return status == cudaStreamCaptureStatusActive;
}

}  // namespace platform
}  // namespace paddle
paddle/fluid/platform/cuda_graph.h
0 → 100644
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <memory>
#include <mutex>
#include "cuda.h" // NOLINT
#include "cuda_runtime.h" // NOLINT
#include "paddle/fluid/platform/type_defs.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace platform {

#if CUDA_VERSION >= 10010
static void ThrowErrorIfNotSupportCUDAGraph() {}
#else
enum cudaStreamCaptureMode {
  cudaStreamCaptureModeGlobal = 0,
  cudaStreamCaptureModeThreadLocal = 1,
  cudaStreamCaptureModeRelaxed = 2
};
static void ThrowErrorIfNotSupportCUDAGraph() {
  PADDLE_THROW(platform::errors::Unimplemented(
      "CUDA Graph is only supported when CUDA version >= 10.1"));
}
#endif

// NOTE: Currently, capturing CUDA Graphs in parallel is not supported.
// NOTE: Do not use this class directly, because it should be used together
// with the memory pool.
class CUDAGraph {
  DISABLE_COPY_AND_ASSIGN(CUDAGraph);

  // Since the constructor throws an error if CUDA_VERSION < 10010,
  // the non-static methods of CUDAGraph need not check CUDA_VERSION again.
  CUDAGraph() { ThrowErrorIfNotSupportCUDAGraph(); }

 public:
  ~CUDAGraph() { Reset(); }

  CUDAGraphID ID() const { return id_; }

  void Replay();

  void Reset();

  void AddResetCallback(std::function<void()> callback) {
    std::lock_guard<std::mutex> guard(mtx_);
    callbacks_.push_back(std::move(callback));
  }

  static void BeginCapture(platform::CUDAPlace place, cudaStream_t stream,
                           cudaStreamCaptureMode mode);
  static std::unique_ptr<CUDAGraph> EndCapture();
  static void AddResetCallbackDuringCapturing(std::function<void()> callback) {
    capturing_graph_->AddResetCallback(std::move(callback));
  }

  // No need to add a CUDA_VERSION macro here, because capturing_graph_ would
  // always be nullptr (the constructor throws an error).
  static bool IsCapturing() { return capturing_graph_ != nullptr; }

  static CUDAGraphID CapturingID() { return capturing_graph_->id_; }

  static platform::CUDAPlace CapturingPlace() {
    return capturing_graph_->place_;
  }

  // This API can be used to debug which GPU operation is not
  // supported during CUDA Graph capturing.
  static bool IsValidCapturing();

 private:
#if CUDA_VERSION >= 10010
  cudaGraph_t graph_{nullptr};
  cudaGraphExec_t exec_graph_{nullptr};
#endif
  cudaStream_t stream_{nullptr};
  platform::CUDAPlace place_;
  CUDAGraphID id_{0};

  std::vector<std::function<void()>> callbacks_;
  bool is_reset_{false};
  std::mutex mtx_;

  static std::unique_ptr<CUDAGraph> capturing_graph_;
};

#if CUDA_VERSION >= 10010
class CUDAGraphCaptureModeGuard {
  DISABLE_COPY_AND_ASSIGN(CUDAGraphCaptureModeGuard);

 public:
  explicit CUDAGraphCaptureModeGuard(
      cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed) {
    if (UNLIKELY(CUDAGraph::IsCapturing())) {
      PADDLE_ENFORCE_CUDA_SUCCESS(cudaThreadExchangeStreamCaptureMode(&mode));
      // After cudaThreadExchangeStreamCaptureMode is called,
      // the variable "mode" is set to the previous capture mode.
      old_mode_ = mode;
    }
  }

  ~CUDAGraphCaptureModeGuard() PADDLE_MAY_THROW {
    if (UNLIKELY(CUDAGraph::IsCapturing())) {
      PADDLE_ENFORCE_CUDA_SUCCESS(
          cudaThreadExchangeStreamCaptureMode(&old_mode_));
    }
  }

 private:
  cudaStreamCaptureMode old_mode_;
};
#else
class CUDAGraphCaptureModeGuard {
  DISABLE_COPY_AND_ASSIGN(CUDAGraphCaptureModeGuard);

 public:
  explicit CUDAGraphCaptureModeGuard(
      cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed) {}
};
#endif

}  // namespace platform
}  // namespace paddle
paddle/fluid/platform/cuda_graph_with_memory_pool.cc
0 → 100644
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace platform {

#ifdef PADDLE_WITH_CUDA
void BeginCUDAGraphCapture(platform::CUDAPlace place,
                           cudaStreamCaptureMode mode) {
  auto *dev_ctx = platform::DeviceContextPool::Instance().GetByPlace(place);
  dev_ctx->cudnn_workspace_handle().ResetWorkspace();

  auto stream = dev_ctx->stream();
  CUDAGraph::BeginCapture(place, stream, mode);
  auto id = CUDAGraph::CapturingID();
  memory::allocation::AllocatorFacade::Instance().PrepareMemoryPoolForCUDAGraph(
      id);
  AddResetCallbackIfCapturingCUDAGraph([id] {
    memory::allocation::AllocatorFacade::Instance().RemoveMemoryPoolOfCUDAGraph(
        id);
  });
}

std::unique_ptr<CUDAGraph> EndCUDAGraphCapture() {
  auto place = CUDAGraph::CapturingPlace();
  auto *dev_ctx = platform::DeviceContextPool::Instance().GetByPlace(place);
  dev_ctx->cudnn_workspace_handle().ResetWorkspace();
  return CUDAGraph::EndCapture();
}
#endif

}  // namespace platform
}  // namespace paddle
paddle/fluid/platform/cuda_graph_with_memory_pool.h
0 → 100644
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_graph.h"
#endif
namespace paddle {
namespace platform {

// NOTE: These APIs are not thread-safe.
#ifdef PADDLE_WITH_CUDA
void BeginCUDAGraphCapture(platform::CUDAPlace place,
                           cudaStreamCaptureMode mode);
std::unique_ptr<CUDAGraph> EndCUDAGraphCapture();
#endif

inline bool IsCUDAGraphCapturing() {
#ifdef PADDLE_WITH_CUDA
  return CUDAGraph::IsCapturing();
#else
  return false;
#endif
}

inline platform::CUDAPlace CUDAGraphCapturingPlace() {
#ifdef PADDLE_WITH_CUDA
  return CUDAGraph::CapturingPlace();
#else
  PADDLE_THROW(platform::errors::Unimplemented(
      "CUDA Graph is only supported on NVIDIA GPU device."));
#endif
}

// Add a reset callback if a CUDA Graph is being captured.
// Otherwise, invoke the callback directly.
template <typename Callback>
inline void AddResetCallbackIfCapturingCUDAGraph(Callback &&callback) {
#ifdef PADDLE_WITH_CUDA
  if (UNLIKELY(IsCUDAGraphCapturing())) {
    return CUDAGraph::AddResetCallbackDuringCapturing(
        std::forward<Callback>(callback));
  }
#endif
  callback();
}

}  // namespace platform
}  // namespace paddle
paddle/fluid/platform/cudnn_desc.h
...
...
@@ -44,6 +44,9 @@ inline cudnnDataType_t ToCudnnDataType(const T& t) {

 inline std::vector<int> TransformDimOrder(const std::vector<int>& dims) {
   std::vector<int> transformed_dims(dims.begin(), dims.end());
+  if (dims.size() < 4) {
+    return transformed_dims;
+  }
   int H, W, D, C;
   if (dims.size() == 4) {
     H = dims[1];
...
...
@@ -155,8 +158,8 @@ class TensorDescriptor {
         dims_with_group.data(), strides.data()));
   }

-  void set(const Tensor& tensor, const cudnnTensorFormat_t format) {
-    auto dims = framework::vectorize<int>(tensor.dims());
+  void set(const std::vector<int>& dims, const cudnnTensorFormat_t format,
+           const cudnnDataType_t dtype) {
     std::vector<int> transformed_dims;
     if (format == CUDNN_TENSOR_NHWC) {
       transformed_dims = TransformDimOrder(dims);
...
...
@@ -164,8 +167,14 @@ class TensorDescriptor {
       transformed_dims = dims;
     }
     PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnSetTensorNdDescriptorEx(
-        desc_.get(), format, ToCudnnDataType(tensor.type()),
-        transformed_dims.size(), transformed_dims.data()));
+        desc_.get(), format, dtype, transformed_dims.size(),
+        transformed_dims.data()));
+  }
+
+  void set(const Tensor& tensor, const cudnnTensorFormat_t format) {
+    auto dims = framework::vectorize<int>(tensor.dims());
+    auto dtype = ToCudnnDataType(tensor.type());
+    set(dims, format, dtype);
   }

  private:
...
...
@@ -191,9 +200,8 @@ class FilterDescriptor {
   T* desc() { return desc_.get(); }
   T* desc() const { return desc_.get(); }

-  void set(const Tensor& tensor, const cudnnTensorFormat_t format,
-           const int groups = 1) {
-    auto dims = framework::vectorize<int>(tensor.dims());
+  void set(const std::vector<int>& dims, const cudnnTensorFormat_t format,
+           const cudnnDataType_t dtype, const int groups = 1) {
     std::vector<int> transformed_dims;
     if (format == CUDNN_TENSOR_NHWC) {
       transformed_dims = TransformDimOrder(dims);
...
...
@@ -204,8 +212,15 @@ class FilterDescriptor {
       transformed_dims[1] = transformed_dims[1] / groups;
     }
     PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnSetFilterNdDescriptor(
-        desc_.get(), ToCudnnDataType(tensor.type()), format,
-        transformed_dims.size(), transformed_dims.data()));
+        desc_.get(), dtype, format, transformed_dims.size(),
+        transformed_dims.data()));
+  }
+
+  void set(const Tensor& tensor, const cudnnTensorFormat_t format,
+           const int groups = 1) {
+    auto dims = framework::vectorize<int>(tensor.dims());
+    auto dtype = ToCudnnDataType(tensor.type());
+    set(dims, format, dtype, groups);
   }

  private:
...
...
paddle/fluid/platform/dynload/cudnn.h
...
...
@@ -180,7 +180,18 @@ CUDNN_DNN_ROUTINE_EACH_AFTER_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#endif
#if CUDNN_VERSION >= 8000
-#define CUDNN_DNN_ROUTINE_EACH_R8(__macro) __macro(cudnnSetRNNDescriptor_v8);
+#define CUDNN_DNN_ROUTINE_EACH_R8(__macro)             \
+  __macro(cudnnSetRNNDescriptor_v8);                   \
+  __macro(cudnnCreateFusedOpsPlan);                    \
+  __macro(cudnnCreateFusedOpsConstParamPack);          \
+  __macro(cudnnCreateFusedOpsVariantParamPack);        \
+  __macro(cudnnDestroyFusedOpsPlan);                   \
+  __macro(cudnnDestroyFusedOpsConstParamPack);         \
+  __macro(cudnnDestroyFusedOpsVariantParamPack);       \
+  __macro(cudnnFusedOpsExecute);                       \
+  __macro(cudnnSetFusedOpsConstParamPackAttribute);    \
+  __macro(cudnnSetFusedOpsVariantParamPackAttribute);  \
+  __macro(cudnnMakeFusedOpsPlan);
 CUDNN_DNN_ROUTINE_EACH_R8(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#endif
...
...
paddle/fluid/platform/gpu_info.cc
...
...
@@ -22,6 +22,7 @@ limitations under the License. */
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/dynload/miopen.h"
#else
#include "paddle/fluid/platform/cuda_graph.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
#endif
#include "paddle/fluid/memory/malloc.h"
...
...
@@ -557,6 +558,7 @@ class RecordedCudaMallocHelper {
#ifdef PADDLE_WITH_HIP
     auto result = hipMalloc(ptr, size);
 #else
+    CUDAGraphCaptureModeGuard capture_mode_guard;
     auto result = cudaMalloc(ptr, size);
 #endif
     if (result == gpuSuccess) {
...
...
paddle/fluid/platform/macros.h
...
...
@@ -30,3 +30,9 @@ limitations under the License. */
#define FLT_MAX __FLT_MAX__
#endif // __FLT_MAX__
#endif // PADDLE_WITH_MUSL
#if defined(__NVCC__) || defined(__HIPCC__)
#define PADDLE_RESTRICT __restrict__
#else
#define PADDLE_RESTRICT
#endif
paddle/fluid/platform/type_defs.h
...
...
@@ -36,4 +36,5 @@ using gpuEvent_t = cudaEvent_t;
 using gpuDeviceProp = cudaDeviceProp;
 #endif

+using CUDAGraphID = unsigned long long;  // NOLINT
 }  // namespace paddle
paddle/fluid/pybind/CMakeLists.txt
...
...
@@ -7,7 +7,7 @@ set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapp
feed_fetch_method pass generate_pass pass_builder parallel_executor profiler layer tracer engine scope_pool
analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context
gloo_wrapper infer_io_utils heter_wrapper generator op_version_registry ps_gpu_wrapper custom_operator
-  cost_model)
+  cost_model cuda_graph_with_memory_pool)

 if (WITH_PSCORE)
   set(PYBIND_DEPS ${PYBIND_DEPS} ps_service)
...
...
paddle/fluid/pybind/pybind.cc
...
...
@@ -125,6 +125,8 @@ limitations under the License. */
#include "paddle/fluid/platform/xpu/xpu_info.h"
#endif
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
#ifdef PADDLE_WITH_CRYPTO
#include "paddle/fluid/pybind/crypto.h"
#endif
...
...
@@ -485,6 +487,17 @@ static int GetNCCLVersion() {
}
#endif
+template <typename PlaceType>
+static void TensorCopyFrom(framework::Tensor *dst,
+                           const framework::Tensor &src,
+                           const PlaceType &place, int64_t batch_size) {
+  if (batch_size < 0) {
+    framework::TensorCopy(src, place, dst);
+  } else {
+    auto sliced = src.Slice(0, batch_size);
+    framework::TensorCopy(sliced, place, dst);
+  }
+}
+
 #ifdef PADDLE_WITH_AVX
 PYBIND11_MODULE(core_avx, m) {
#else
...
...
@@ -520,6 +533,19 @@ PYBIND11_MODULE(core_noavx, m) {
   m.def("nccl_version", &GetNCCLVersion);
 #endif

+  m.def("is_cuda_graph_capturing", &platform::IsCUDAGraphCapturing);
+#ifdef PADDLE_WITH_CUDA
+  py::class_<platform::CUDAGraph>(m, "CUDAGraph")
+      .def_static("begin_capture",
+                  [](platform::CUDAPlace place, int mode) {
+                    platform::BeginCUDAGraphCapture(
+                        place, static_cast<cudaStreamCaptureMode>(mode));
+                  })
+      .def_static("end_capture", &platform::EndCUDAGraphCapture)
+      .def("replay", &platform::CUDAGraph::Replay)
+      .def("reset", &platform::CUDAGraph::Reset);
+#endif
+
   m.def("wait_device", [](const platform::Place &place) {
     platform::DeviceContextPool::Instance().Get(place)->Wait();
   });
...
...
@@ -721,6 +747,18 @@ PYBIND11_MODULE(core_noavx, m) {
                     paddle::framework::proto::VarType::Type type) {
             return reinterpret_cast<uintptr_t>(self.mutable_data(place, type));
           })
+      .def("_copy_from", &TensorCopyFrom<paddle::platform::CPUPlace>,
+           py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1)
+      .def("_copy_from", &TensorCopyFrom<paddle::platform::XPUPlace>,
+           py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1)
+      .def("_copy_from", &TensorCopyFrom<paddle::platform::CUDAPlace>,
+           py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1)
+      .def("_copy_from", &TensorCopyFrom<paddle::platform::NPUPlace>,
+           py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1)
+      .def("_copy_from", &TensorCopyFrom<paddle::platform::CUDAPinnedPlace>,
+           py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1)
+      .def("_copy_from", &TensorCopyFrom<paddle::platform::Place>,
+           py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1)
       .def("set", SetTensorFromPyArray<paddle::platform::CPUPlace>,
            py::arg("array"), py::arg("place"), py::arg("zero_copy") = false)
       .def("set", SetTensorFromPyArray<paddle::platform::XPUPlace>,
...
...
@@ -2301,7 +2339,14 @@ All parameter, weight, gradient are variables in Paddle.
   m.def("op_support_gpu", OpSupportGPU);
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   m.def("get_cuda_device_count", platform::GetCUDADeviceCount);
-  m.def("cuda_empty_cache", platform::EmptyCache);
+  m.def("cuda_empty_cache", [] {
+    for (int dev_id : platform::GetSelectedDevices()) {
+      auto *dev_ctx = platform::DeviceContextPool::Instance().GetByPlace(
+          platform::CUDAPlace(dev_id));
+      dev_ctx->cudnn_workspace_handle().ResetWorkspace();
+    }
+    platform::EmptyCache();
+  });
   m.def("get_device_properties",
         [](int id) -> const gpuDeviceProp & {
           return platform::GetDeviceProperties(id);
...
...
@@ -3213,6 +3258,13 @@ All parameter, weight, gradient are variables in Paddle.
            [](BuildStrategy &self, bool fix_op_run_order) {
              self.fix_op_run_order_ = fix_op_run_order;
            })
+      .def_property(
+          "allow_cuda_graph_capture",
+          [](const BuildStrategy &self) {
+            return self.allow_cuda_graph_capture_;
+          },
+          [](BuildStrategy &self, bool allow_cuda_graph_capture) {
+            self.allow_cuda_graph_capture_ = allow_cuda_graph_capture;
+          })
       .def("_copy",
            [](const BuildStrategy &self) {
              auto new_bs = self;
...
...
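Taken together, the new `allow_cuda_graph_capture` build-strategy property and the `CUDAGraph` bindings above are the knobs this commit exposes to Python. A rough sketch of setting the flag from user code (the surrounding program construction and compilation are assumed and not shown in this diff):

import paddle

paddle.enable_static()
build_strategy = paddle.static.BuildStrategy()
build_strategy.allow_cuda_graph_capture = True  # property added in this diff
# compiled_prog = paddle.static.CompiledProgram(main_prog).with_data_parallel(
#     loss_name=loss.name, build_strategy=build_strategy)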
python/paddle/device/cuda/graphs.py
0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid.core import is_compiled_with_cuda, is_compiled_with_rocm, CUDAPlace

if is_compiled_with_cuda() and not is_compiled_with_rocm():
    from paddle.fluid.core import CUDAGraph as CoreCUDAGraph

    class CUDAGraph:
        def __init__(self, place=None, mode="thread_local"):
            ALL_MODES = ["global", "thread_local", "relaxed"]
            self._graph = None
            if place is None:
                place = CUDAPlace(0)
            self._place = place
            assert mode in ALL_MODES
            self._mode = ALL_MODES.index(mode)

        def capture_begin(self):
            CoreCUDAGraph.begin_capture(self._place, self._mode)

        def capture_end(self):
            self._graph = CoreCUDAGraph.end_capture()

        def replay(self):
            self._graph.replay()

        def reset(self):
            self._graph.reset()
else:

    class CUDAGraph:
        def __init__(self, place=None, mode="thread_local"):
            raise NotImplementedError()

        def capture_begin(self):
            raise NotImplementedError()

        def capture_end(self):
            raise NotImplementedError()

        def replay(self):
            raise NotImplementedError()

        def reset(self):
            raise NotImplementedError()
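A minimal usage sketch of the wrapper above (assuming a CUDA build of Paddle running in dygraph mode; the tensors and ops are placeholders, not part of this diff):

import paddle
from paddle.device.cuda.graphs import CUDAGraph

paddle.set_device('gpu')
x = paddle.rand([4, 4])

graph = CUDAGraph(mode="thread_local")
graph.capture_begin()
y = x * 2 + 1          # kernels launched here are recorded into the graph
graph.capture_end()

graph.replay()          # re-runs the captured kernels
graph.reset()           # releases the graph resources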
python/paddle/fluid/contrib/layers/nn.py
...
...
@@ -1932,3 +1932,38 @@ def fused_bn_add_act(x,
        attrs=attrs)
    return batch_norm_out


def pow2_decay_with_linear_warmup(warmup_steps,
                                  total_steps,
                                  base_lr,
                                  end_lr,
                                  dtype='float32',
                                  name=None):
    if paddle.fluid.in_dygraph_mode():
        raise NotImplementedError(
            "pow2_decay_with_linear_warmup does not support dygraph mode yet.")
    helper = LayerHelper("pow2_decay_with_linear_warmup", **locals())
    lr = helper.create_global_variable(persistable=True, dtype=dtype, shape=[1])
    helper.set_variable_initializer(
        lr, Constant(value=float(base_lr) / warmup_steps))

    step = helper.create_global_variable(
        persistable=True, dtype='int64', shape=[1])
    helper.set_variable_initializer(step, Constant(value=0))
    assert warmup_steps <= total_steps, "warmup_steps cannot be larger than total_steps"

    helper.append_op(
        type="pow2_decay_with_linear_warmup",
        inputs={"LearningRate": lr,
                "Step": step},
        outputs={"LearningRateOut": lr,
                 "StepOut": step},
        attrs={
            "warmup_steps": warmup_steps,
            "total_steps": total_steps,
            "base_lr": base_lr,
            "end_lr": end_lr,
        })
    return lr
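A short usage sketch of the new helper under static graph (the numbers are arbitrary; the test added later in this PR follows the same pattern):

import paddle
from paddle.fluid.contrib.layers.nn import pow2_decay_with_linear_warmup

paddle.enable_static()
main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    # lr is a persistable [1]-shaped variable updated in place on every run
    lr = pow2_decay_with_linear_warmup(
        warmup_steps=30, total_steps=100, base_lr=0.02, end_lr=0.001)

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup)
for _ in range(5):
    print(exe.run(main, fetch_list=[lr])[0])  # one scheduler step per run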
python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
...
...
@@ -127,11 +127,9 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
    num_cast_ops = 0

    for in_name in op.input_names:
        if src_dtype == core.VarDesc.VarType.FP32 and op.type in [
                'batch_norm', 'fused_bn_add_activation', 'layer_norm'
        ]:
            if in_name not in {'X', 'Z'}:
                continue
        if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(op, in_name):
            continue
        for in_var_name in op.input(in_name):
            in_var = block._find_var_recursive(in_var_name)
            if in_var.type not in _valid_types or in_var.dtype == dest_dtype:
...
...
@@ -184,9 +182,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
            op._set_attr('in_dtype', dest_dtype)
    if src_dtype == core.VarDesc.VarType.FP32 and dest_dtype == core.VarDesc.VarType.FP16:
        for out_name in op.output_names:
            if op.type in [
                    'batch_norm', 'fused_bn_add_activation', 'layer_norm'
            ] and out_name != 'Y':
            if _keep_fp32_output(op, out_name):
                continue
            for out_var_name in op.output(out_name):
                out_var = block.var(out_var_name)
...
...
@@ -401,9 +397,7 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
                keep_fp32_ops.add(op)
                continue  # processed below
            for in_name in op.input_names:
                if op.type in {
                        'batch_norm', 'fused_bn_add_activation', 'layer_norm'
                } and in_name not in {'X', 'Z'}:
                if _keep_fp32_input(op, in_name):
                    continue
                for in_var_name in op.input(in_name):
                    in_var = None
...
...
@@ -431,9 +425,7 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
                            format(op.type, in_var_name, in_var.dtype))
            for out_name in op.output_names:
                if op.type in {
                        'batch_norm', 'fused_bn_add_activation', 'layer_norm'
                } and out_name != 'Y':
                if _keep_fp32_output(op, out_name):
                    continue
                for out_var_name in op.output(out_name):
                    out_var = None
...
...
python/paddle/fluid/executor.py
...
...
@@ -1041,9 +1041,15 @@ class Executor(object):
                lr_value = lr_sheduler()
                lr_var = program._program.global_block().vars[lr_sheduler._var_name]
                lr_tensor = _as_lodtensor(lr_value, core.CPUPlace(), lr_var.dtype)
                exe.feed_and_split_tensor_into_local_scopes({
                    lr_sheduler._var_name: lr_tensor
                })
                if core.is_cuda_graph_capturing():
                    warnings.warn(
                        "Caution!!! When capturing CUDA Graph, the learning rate scheduler would not "
                        "take any effect! Please set the learning rate manually before each batch!")
                else:
                    exe.feed_and_split_tensor_into_local_scopes({
                        lr_sheduler._var_name: lr_tensor
                    })

            fetch_var_names = list(map(_to_name_str, fetch_list))
            tensors = exe.run(fetch_var_names, return_merged)._move_to_list()
...
...
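As the warning above says, once CUDA Graph capture starts the scheduler no longer feeds the learning rate; the value has to be written into the LR tensor by hand before each replay. A hedged sketch follows, with `scope`, `main`, `lr` (a paddle.optimizer.lr scheduler), `cuda_graph` and `place` assumed to come from the surrounding training loop (see test_cuda_graph.py below for the full version):

import numpy as np

lr_t = scope.var(main.global_block().var(lr._var_name).name).get_tensor()
lr_t.set(np.array([lr()], dtype='float32'), place)  # write the current LR value
cuda_graph.replay()
lr.step()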
python/paddle/fluid/memory_analysis.py
0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import core
import numpy as np


def get_var_and_memory_size(block, var_name, batch_size=None):
    var = block._find_var_recursive(var_name)
    assert var is not None, "Variable {} cannot be found".format(var_name)
    assert var.type == core.VarDesc.VarType.LOD_TENSOR, \
        "Variable {} is not Tensor".format(var_name)
    shape = list(var.shape)
    if not shape:
        return var, 0

    has_none = False
    for i, s in enumerate(shape):
        if s is None or s < 0:
            assert not has_none
            shape[i] = batch_size
            has_none = True
    assert all([s >= 0 for s in shape]), "shape {} is not deterministic".format(shape)
    mem_size = int(np.prod(shape)) * core.size_of_dtype(var.dtype)
    return var, mem_size


def pre_allocate_memory(size, place):
    t = core.LoDTensor()
    t._set_dims([size])
    t._mutable_data(place, core.VarDesc.VarType.INT8)
    del t


# NOTE: does not consider inplace yet.
def get_max_memory_info(program, batch_size=None):
    assert program.num_blocks == 1, "only support to analysis program with only one block"
    cur_tmp_mem = 0
    max_tmp_mem = 0
    max_persistable_mem = 0
    visited_vars = set()
    alived_vars = []

    block = program.global_block()
    gc_vars = core._get_eager_deletion_vars(program.desc, [])[0]
    for i, op in enumerate(block.ops):
        var_names = op.input_arg_names + op.output_arg_names
        for var_name in var_names:
            if var_name in visited_vars:
                continue
            visited_vars.add(var_name)
            var, mem_size = get_var_and_memory_size(block, var_name, batch_size)
            if var.persistable:
                max_persistable_mem += mem_size
            else:
                cur_tmp_mem += mem_size
                max_tmp_mem = max(max_tmp_mem, cur_tmp_mem)

        cur_gc_vars = gc_vars[i]
        for var_name in var_names:
            if var_name not in cur_gc_vars:
                continue
            _, mem_size = get_var_and_memory_size(block, var_name, batch_size)
            cur_tmp_mem -= mem_size
    return max_tmp_mem, max_persistable_mem
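A quick sketch of how the two helpers above combine, assuming a static-graph program `main_prog` has already been built (the unit test added below does essentially the same):

import paddle
from paddle.fluid.memory_analysis import get_max_memory_info, pre_allocate_memory

# peak temporary and persistable memory (in bytes) for a given batch size
max_tmp, max_persistable = get_max_memory_info(main_prog, batch_size=32)

# optionally warm up the allocator so the first iteration does not pay
# for growing the memory pool
if paddle.is_compiled_with_cuda():
    pre_allocate_memory(max_tmp + max_persistable, paddle.CUDAPlace(0))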
python/paddle/fluid/optimizer.py
...
...
@@ -2064,8 +2064,9 @@ class LarsMomentumOptimizer(Optimizer):
        attrs = {
            "mu": self._momentum,
            "lars_coeff": self._lars_coeff,
            "lars_weight_decay": _lars_weight_decay,
            "lars_weight_decay": [_lars_weight_decay],
            "multi_precision": find_master,
            "epsilon": self._epsilon,
            "rescale_grad": self._rescale_grad
        }
...
...
@@ -2084,7 +2085,7 @@ class LarsMomentumOptimizer(Optimizer):
        # create the momentum optimize op
        momentum_op = block.append_op(
            type=self.type,
            type=self.type if _lars_weight_decay != 0.0 else 'momentum',
            inputs=inputs,
            outputs=outputs,
            attrs=attrs,
...
...
python/paddle/fluid/tests/unittests/test_cuda_graph.py
0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
from paddle.device.cuda.graphs import CUDAGraph
import unittest
import numpy as np
from paddle.fluid.dygraph.base import switch_to_static_graph
from simple_nets import simple_fc_net_with_inputs


class TestCUDAGraph(unittest.TestCase):
    def setUp(self):
        if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm():
            fluid.set_flags({
                'FLAGS_allocator_strategy': 'auto_growth',
                'FLAGS_sync_nccl_allreduce': False,
                'FLAGS_cudnn_deterministic': True
            })

    def random_tensor(self, shape):
        return paddle.to_tensor(
            np.random.randint(low=0, high=10, size=shape).astype("float32"))

    @switch_to_static_graph
    def test_cuda_graph_static_graph(self):
        if not paddle.is_compiled_with_cuda() or paddle.is_compiled_with_rocm():
            return
        seed = 100
        loss_cuda_graph = self.cuda_graph_static_graph_main(seed, use_cuda_graph=True)
        loss_no_cuda_graph = self.cuda_graph_static_graph_main(seed, use_cuda_graph=False)
        self.assertEqual(loss_cuda_graph, loss_no_cuda_graph)

    def cuda_graph_static_graph_main(self, seed, use_cuda_graph):
        batch_size = 1
        class_num = 10
        image_shape = [batch_size, 784]
        label_shape = [batch_size, 1]

        paddle.seed(seed)
        np.random.seed(seed)
        startup = paddle.static.Program()
        main = paddle.static.Program()
        with paddle.static.program_guard(main, startup):
            image = paddle.static.data(
                name="image", shape=image_shape, dtype='float32')
            label = paddle.static.data(
                name="label", shape=label_shape, dtype='int64')
            image.persistable = True
            label.persistable = True
            loss = simple_fc_net_with_inputs(image, label, class_num)
            loss.persistable = True
            lr = paddle.optimizer.lr.PiecewiseDecay(
                boundaries=[2, 3, 4], values=[0.01, 0.02, 0.03, 0.04])
            optimizer = paddle.optimizer.SGD(learning_rate=lr)
            optimizer.minimize(loss)
        place = paddle.CUDAPlace(0)
        exe = paddle.static.Executor(place)
        scope = paddle.static.Scope()
        with paddle.static.scope_guard(scope):
            exe.run(startup)
            build_strategy = paddle.static.BuildStrategy()
            build_strategy.allow_cuda_graph_capture = True
            build_strategy.fix_op_run_order = True
            build_strategy.fuse_all_optimizer_ops = True
            compiled_program = paddle.static.CompiledProgram(main).with_data_parallel(
                loss_name=loss.name, build_strategy=build_strategy, places=place)
            image_t = scope.var(image.name).get_tensor()
            label_t = scope.var(label.name).get_tensor()
            loss_t = scope.var(loss.name).get_tensor()
            lr_var = main.global_block().var(lr._var_name)
            self.assertTrue(lr_var.persistable)
            lr_t = scope.var(lr_var.name).get_tensor()
            cuda_graph = None
            for batch_id in range(20):
                image_t.set(np.random.rand(*image_shape).astype('float32'), place)
                label_t.set(
                    np.random.randint(
                        low=0, high=class_num, size=label_shape, dtype='int64'),
                    place)

                if batch_id == 1 and use_cuda_graph:
                    cuda_graph = CUDAGraph(place, mode="global")
                    cuda_graph.capture_begin()
                    exe.run(compiled_program)
                    cuda_graph.capture_end()

                if cuda_graph:
                    lr_t.set(np.array([lr()], dtype='float32'), place)
                    cuda_graph.replay()
                else:
                    exe.run(compiled_program)
                lr.step()
            if cuda_graph:
                cuda_graph.reset()
        return np.array(loss_t)

    def test_cuda_graph_dynamic_graph(self):
        if not paddle.is_compiled_with_cuda() or paddle.is_compiled_with_rocm():
            return
        shape = [2, 3]
        x = self.random_tensor(shape)
        z = self.random_tensor(shape)

        g = CUDAGraph()
        g.capture_begin()
        y = x + 10
        z.add_(x)
        g.capture_end()

        for _ in range(10):
            z_np_init = z.numpy()
            x_new = self.random_tensor(shape)
            x.copy_(x_new, False)
            g.replay()
            x_np = x_new.numpy()
            y_np = y.numpy()
            z_np = z.numpy()
            self.assertTrue((y_np - x_np == 10).all())
            self.assertTrue((z_np - z_np_init == x_np).all())

        g.reset()


if __name__ == "__main__":
    unittest.main()
python/paddle/fluid/tests/unittests/test_fleet_lars_meta_optimizer.py
...
...
@@ -103,7 +103,7 @@ class TestFleetLarsMetaOptimizer(unittest.TestCase):
                'op_role_var')[0] or ".b" in op.attr('op_role_var')[0])
        ]
        for op in ops_without_wd:
            self.assertEqual(op.attr('lars_weight_decay'), 0)
            self.assertEqual(op.attr('lars_weight_decay')[0], 0)

    def test_lars_apply_with_amp(self):
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
...
...
python/paddle/fluid/tests/unittests/test_memory_analysis.py
0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
from paddle.fluid.memory_analysis import pre_allocate_memory, get_max_memory_info
from simple_nets import simple_fc_net


class TestMemoryAnalysis(unittest.TestCase):
    def setUp(self):
        paddle.enable_static()

    def test_get_memory_info(self):
        loss = simple_fc_net()
        optimizer = paddle.optimizer.Adam(learning_rate=1e-3)
        optimizer.minimize(loss)
        main_prog = paddle.static.default_main_program()
        max_tmp_mem_1, max_persitable_mem_1 = get_max_memory_info(
            main_prog, batch_size=32)
        self.assertGreater(max_tmp_mem_1, 0)
        self.assertGreater(max_persitable_mem_1, 0)
        max_tmp_mem_2, max_persitable_mem_2 = get_max_memory_info(
            main_prog, batch_size=64)
        self.assertEqual(max_persitable_mem_1, max_persitable_mem_2)
        self.assertLess(max_tmp_mem_1, max_tmp_mem_2)


class TestPreAllocateMemory(unittest.TestCase):
    def setUp(self):
        paddle.enable_static()

    def test_pre_allocate(self):
        size = 32 * 1024 * 1024
        pre_allocate_memory(size, paddle.CPUPlace())
        if paddle.is_compiled_with_cuda():
            pre_allocate_memory(size, paddle.CUDAPlace(0))


if __name__ == "__main__":
    unittest.main()
python/paddle/fluid/tests/unittests/test_merged_momentum_op.py
0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import numpy as np
from paddle.fluid.layer_helper import LayerHelper
from collections import OrderedDict


def run_momentum_op(params,
                    grads,
                    velocitys,
                    master_params,
                    learning_rate,
                    place,
                    multi_precision,
                    mu=0.9,
                    rescale_grad=0.01,
                    use_merged=False):
    assert len(params) == len(grads)
    assert len(params) == len(velocitys)
    if multi_precision:
        assert len(params) == len(master_params)
    op_type = 'merged_momentum' if use_merged else 'momentum'
    main = paddle.static.Program()
    startup = paddle.static.Program()
    with paddle.static.program_guard(main, startup):
        helper = LayerHelper(op_type, **locals())
        attrs = {
            'mu': mu,
            'multi_precision': multi_precision,
            'rescale_grad': rescale_grad,
        }

        param_vars = [
            helper.create_variable(
                persistable=True, shape=p.shape, dtype=p.dtype) for p in params
        ]
        grad_vars = [
            helper.create_variable(shape=g.shape, dtype=g.dtype) for g in grads
        ]
        velocity_vars = [
            helper.create_variable(
                persistable=True, shape=v.shape, dtype=v.dtype)
            for v in velocitys
        ]
        lr_var = helper.create_variable(
            persistable=True,
            shape=learning_rate.shape,
            dtype=learning_rate.dtype)

        feed_dict = OrderedDict()

        feed_dict.update(
            OrderedDict([(p_var.name, p_val)
                         for p_var, p_val in zip(param_vars, params)]))
        feed_dict.update(
            OrderedDict([(v_var.name, v_val)
                         for v_var, v_val in zip(velocity_vars, velocitys)]))
        fetch_list = list(feed_dict.keys())

        feed_dict.update(
            OrderedDict([(g_var.name, g_val)
                         for g_var, g_val in zip(grad_vars, grads)]))
        feed_dict.update({lr_var.name: learning_rate})

        if multi_precision:
            master_param_vars = [
                helper.create_variable(
                    persistable=True, shape=p.shape, dtype=p.dtype)
                for p in master_params
            ]
            feed_dict.update(
                OrderedDict([(mp_var.name, mp_val)
                             for mp_var, mp_val in zip(master_param_vars,
                                                       master_params)]))
            # CPUPlace does not use MasterParam
            if isinstance(place, paddle.CUDAPlace):
                fetch_list = fetch_list + [
                    mp_var.name for mp_var in master_param_vars
                ]
        else:
            master_param_vars = None

        if not use_merged:
            for i, (p, g,
                    v) in enumerate(zip(param_vars, grad_vars, velocity_vars)):
                inputs = {
                    'Param': p,
                    'Grad': g,
                    'Velocity': v,
                    'LearningRate': lr_var,
                }
                outputs = {'ParamOut': p, 'VelocityOut': v}
                if multi_precision:
                    inputs['MasterParam'] = master_param_vars[i]
                    outputs['MasterParamOut'] = master_param_vars[i]
                helper.append_op(
                    type=op_type, inputs=inputs, outputs=outputs, attrs=attrs)
        else:
            inputs = {
                'Param': param_vars,
                'Grad': grad_vars,
                'Velocity': velocity_vars,
                'LearningRate': lr_var,
            }
            outputs = {'ParamOut': param_vars, 'VelocityOut': velocity_vars}
            if multi_precision:
                inputs['MasterParam'] = master_param_vars
                outputs['MasterParamOut'] = master_param_vars
            helper.append_op(
                type=op_type, inputs=inputs, outputs=outputs, attrs=attrs)

    exe = paddle.static.Executor(place)
    with paddle.static.scope_guard(paddle.static.Scope()):
        exe.run(startup)
        return exe.run(main, feed=feed_dict, fetch_list=fetch_list)


class TestMergedMomentum(unittest.TestCase):
    def setUp(self):
        paddle.enable_static()
        self.shapes = [[3, 4], [2, 7], [5, 6], [7, 8]]
        self.seed = 10

    def gen_rand_data(self, shapes, dtype):
        return [np.random.random(s).astype(dtype) for s in shapes]

    def prepare_data(self, shapes, multi_precision, seed, place):
        np.random.seed(seed)
        mp_dtype = np.float32
        dtype = np.float16 if multi_precision and isinstance(
            place, paddle.CUDAPlace) else np.float32
        params = self.gen_rand_data(shapes, dtype)
        grads = self.gen_rand_data(shapes, dtype)
        velocitys = self.gen_rand_data(shapes, mp_dtype)
        learning_rate = self.gen_rand_data([[1]], mp_dtype)[0]
        if multi_precision:
            master_params = [p.astype(mp_dtype) for p in params]
        else:
            master_params = None
        return params, grads, velocitys, master_params, learning_rate

    def check_with_place(self, place, multi_precision):
        params, grads, velocitys, master_params, learning_rate = self.prepare_data(
            self.shapes, multi_precision, self.seed, place)

        def run_op(use_merged):
            # FIXME(zengjinle): CPU Momentum Op does not support rescale_grad
            rescale_grad = 1.0 if isinstance(place, paddle.CPUPlace) else 0.01
            return run_momentum_op(
                params,
                grads,
                velocitys,
                master_params,
                learning_rate,
                place,
                multi_precision,
                rescale_grad=rescale_grad,
                use_merged=use_merged)

        outs1 = run_op(True)
        outs2 = run_op(False)
        self.assertEqual(len(outs1), len(outs2))
        for i, (out1, out2) in enumerate(zip(outs1, outs2)):
            if isinstance(place, paddle.CUDAPlace):
                self.assertTrue(np.array_equal(out1, out2))
            else:
                self.assertTrue(np.allclose(out1, out2, atol=1e-7))

    def get_places(self):
        places = [paddle.CPUPlace()]
        if paddle.is_compiled_with_cuda():
            places.append(paddle.CUDAPlace(0))
        return places

    def test_main(self):
        for multi_precision in [False, True]:
            for place in self.get_places():
                self.check_with_place(place, multi_precision)


if __name__ == "__main__":
    unittest.main()
python/paddle/fluid/tests/unittests/test_momentum_op.py
...
...
@@ -138,50 +138,70 @@ class TestMomentumOp2(OpTest):
                     "core is not compiled with CUDA")
class TestLarsMomentumOpWithMP(OpTest):
    def setUp(self):
        self.config()
        self.op_type = "lars_momentum"
        master_param = np.random.random((123, 321)).astype("float32")
        param = master_param.astype("float16")
        grad = np.random.random((123, 321)).astype("float16")
        velocity = np.zeros((123, 321)).astype("float32")
        learning_rate = np.array([0.001]).astype("float32")
        mu = 0.0001
        lars_coeff = 0.001
        lars_weight_decay = 0.0005
        rescale_grad = 1.0

        params = []
        grads = []
        velocitys = []
        learning_rates = []
        master_params = []
        param_outs = []
        velocity_outs = []
        master_param_outs = []
        for i in range(self.params_num):
            master_param = np.random.random((123, 321)).astype("float32")
            param = master_param.astype("float16")
            grad = np.random.random((123, 321)).astype("float16")
            velocity = np.zeros((123, 321)).astype("float32")
            learning_rate = np.array([0.001]).astype("float32")
            fp32_grad = grad.astype("float32")
            pnorm = np.sqrt(np.square(master_param).sum())
            gnorm = np.sqrt(np.square(fp32_grad).sum())
            local_lr = learning_rate * lars_coeff * pnorm / (
                gnorm + lars_weight_decay * pnorm)
            fp32_grad = fp32_grad * rescale_grad
            velocity_out = mu * velocity + local_lr * (
                fp32_grad + lars_weight_decay * master_param)
            p_new = master_param - velocity_out
            param_out = p_new.astype("float16")
            master_param_out = p_new

            params.append(("SubParam_" + str(i), param))
            grads.append(("SubGrad_" + str(i), grad))
            velocitys.append(("SubVelocity_" + str(i), velocity))
            learning_rates.append(("SubLearning_rate_" + str(i), learning_rate))
            velocity_outs.append(("SubVelocity_out_" + str(i), velocity_out))
            param_outs.append(("SubParam_out_" + str(i), param_out))
            master_params.append(("SubMasterParam_" + str(i), master_param))
            master_param_outs.append(
                ("SubMasterParamOut_" + str(i), master_param_out))

        self.inputs = {
            'Param': param,
            'Grad': grad,
            'Velocity': velocity,
            'LearningRate': learning_rate,
            'MasterParam': master_param,
            'Param': params,
            'Grad': grads,
            'Velocity': velocitys,
            'LearningRate': learning_rates,
            'MasterParam': master_params,
        }

        self.attrs = {
            'mu': mu,
            'lars_coeff': lars_coeff,
            'lars_weight_decay': lars_weight_decay,
            'lars_weight_decay': [lars_weight_decay],
            'multi_precision': True,
            'rescale_grad': rescale_grad
        }

        fp32_grad = grad.astype("float32")
        pnorm = np.sqrt(np.square(master_param).sum())
        gnorm = np.sqrt(np.square(fp32_grad).sum())
        local_lr = learning_rate * lars_coeff * pnorm / (
            gnorm + lars_weight_decay * pnorm)
        fp32_grad = fp32_grad * rescale_grad
        velocity_out = mu * velocity + local_lr * (
            fp32_grad + lars_weight_decay * master_param)
        p_new = master_param - velocity_out
        param_out = p_new.astype("float16")
        master_param_out = p_new
        self.outputs = {
            'ParamOut': param_out,
            'VelocityOut': velocity_out,
            'MasterParamOut': master_param_out
            'ParamOut': param_outs,
            'VelocityOut': velocity_outs,
            'MasterParamOut': master_param_outs
        }

    def test_check_output(self):
...
...
@@ -191,46 +211,65 @@ class TestLarsMomentumOpWithMP(OpTest):
        if core.is_float16_supported(place):
            self.check_output_with_place(place)

    def config(self):
        self.params_num = 1


class TestLarsMomentumOp(OpTest):
    def setUp(self):
        self.config()
        self.op_type = "lars_momentum"
        param = np.random.random((123, 321)).astype("float32")
        grad = np.random.random((123, 321)).astype("float32")
        velocity = np.zeros((123, 321)).astype("float32")
        learning_rate = np.array([0.001]).astype("float32")
        mu = 0.0001
        lars_coeff = 0.001
        lars_weight_decay = 0.0005

        params = []
        grads = []
        velocitys = []
        param_outs = []
        velocity_outs = []
        learning_rates = []
        for i in range(self.params_num):
            param = np.random.random((123, 321)).astype("float32")
            grad = np.random.random((123, 321)).astype("float32")
            velocity = np.zeros((123, 321)).astype("float32")
            learning_rate = np.array([0.001]).astype("float32")
            pnorm = np.sqrt(np.square(param).sum())
            gnorm = np.sqrt(np.square(grad).sum())
            local_lr = learning_rate * lars_coeff * pnorm / (
                gnorm + lars_weight_decay * param)
            velocity_out = mu * velocity + local_lr * (
                grad + lars_weight_decay * param)
            param_out = param - velocity_out

            params.append(("SubParam_" + str(i), param))
            grads.append(("SubGrad_" + str(i), grad))
            velocitys.append(("SubVelocity_" + str(i), velocity))
            learning_rates.append(("SubLearning_rate_" + str(i), learning_rate))
            velocity_outs.append(("SubVelocity_out_" + str(i), velocity_out))
            param_outs.append(("SubParam_out_" + str(i), param_out))

        self.inputs = {
            'Param': param,
            'Grad': grad,
            'Velocity': velocity,
            'LearningRate': learning_rate
            'Param': params,
            'Grad': grads,
            'Velocity': velocitys,
            'LearningRate': learning_rates
        }

        self.attrs = {
            'mu': mu,
            'lars_coeff': lars_coeff,
            'lars_weight_decay': lars_weight_decay
            'lars_weight_decay': [lars_weight_decay]
        }

        pnorm = np.sqrt(np.square(param).sum())
        gnorm = np.sqrt(np.square(grad).sum())
        local_lr = learning_rate * lars_coeff * pnorm / (
            gnorm + lars_weight_decay * param)
        velocity_out = mu * velocity + local_lr * (grad + lars_weight_decay * param)
        param_out = param - velocity_out

        self.outputs = {'ParamOut': param_out, 'VelocityOut': velocity_out}
        self.outputs = {'ParamOut': param_outs, 'VelocityOut': velocity_outs}

    def test_check_output(self):
        paddle.enable_static()
        self.check_output()

    def config(self):
        self.params_num = 1


class TestSparseMomentumOp(unittest.TestCase):
    def setUp(self):
...
...
python/paddle/fluid/tests/unittests/test_pow2_decay_with_linear_warmup_op.py
0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.fluid.contrib.layers.nn import pow2_decay_with_linear_warmup
from paddle.optimizer.lr import LinearWarmup
from paddle.optimizer.lr import PolynomialDecay
import unittest


def gen_pow2_warmup_op_lr(warmup_steps, total_steps, base_lr, end_lr, place):
    main = paddle.static.Program()
    startup = paddle.static.Program()
    with paddle.static.program_guard(main, startup):
        lr = pow2_decay_with_linear_warmup(warmup_steps, total_steps, base_lr,
                                           end_lr)
    exe = paddle.static.Executor(place)
    with paddle.static.scope_guard(paddle.static.Scope()):
        exe.run(startup)
        while True:
            lr_np = exe.run(main, fetch_list=[lr])[0]
            yield lr_np[0]


class Pow2Warmup(LinearWarmup):
    def __init__(self, warmup_steps, total_steps, base_lr, end_lr):
        assert total_steps > warmup_steps
        lr_sch = PolynomialDecay(
            learning_rate=base_lr,
            decay_steps=total_steps - warmup_steps,
            end_lr=end_lr,
            power=2)

        super(Pow2Warmup, self).__init__(
            learning_rate=lr_sch,
            warmup_steps=warmup_steps,
            start_lr=0.0,
            end_lr=base_lr)


def gen_pow2_warmup_py_lr(warmup_steps, total_steps, base_lr, end_lr, place):
    lr_sch = Pow2Warmup(warmup_steps, total_steps, base_lr, end_lr)
    lr_sch.step()
    while True:
        yield lr_sch()
        lr_sch.step()


class TestPow2WarmupLRScheduler(unittest.TestCase):
    def setUp(self):
        paddle.enable_static()
        self.params = {
            'warmup_steps': 30,
            'total_steps': 100,
            'base_lr': 0.02,
            'end_lr': 0.001,
        }
        self.step_num = 1000

    def check_with_place(self, place):
        kwargs = dict(self.params)
        kwargs['place'] = place
        lr_sch_op = gen_pow2_warmup_op_lr(**kwargs)
        lr_sch_py = gen_pow2_warmup_py_lr(**kwargs)
        for i, (lr_op, lr_py) in enumerate(zip(lr_sch_op, lr_sch_py)):
            self.assertLess(abs(lr_op - lr_py), 1e-6)
            if i > self.step_num:
                break

    def test_main(self):
        self.check_with_place(paddle.CPUPlace())
        if paddle.is_compiled_with_cuda():
            self.check_with_place(paddle.CUDAPlace(0))


if __name__ == "__main__":
    unittest.main()
python/paddle/fluid/tests/unittests/test_tensor_copy_from.py
0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import unittest
import numpy as np
from paddle.fluid.core import LoDTensor as Tensor


class TestTensorCopyFrom(unittest.TestCase):
    def test_main(self):
        place = paddle.CPUPlace()
        np_value = np.random.random(size=[10, 30]).astype('float32')
        t_src = Tensor()
        t_src.set(np_value, place)
        self.assertTrue(np.array_equal(np_value, t_src))

        t_dst1 = Tensor()
        t_dst1._copy_from(t_src, place)
        self.assertTrue(np.array_equal(np_value, t_dst1))

        t_dst2 = Tensor()
        t_dst2._copy_from(t_src, place, 5)
        self.assertTrue(np.array_equal(np.array(np_value[0:5]), t_dst2))


if __name__ == "__main__":
    unittest.main()
python/paddle/incubate/operators/__init__.py
...
...
@@ -14,3 +14,4 @@
from .softmax_mask_fuse_upper_triangle import softmax_mask_fuse_upper_triangle  # noqa: F401
from .softmax_mask_fuse import softmax_mask_fuse  # noqa: F401
from .resnet_unit import ResNetUnit  # noqa: F401
python/paddle/incubate/operators/resnet_unit.py
0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import collections
import itertools
import six
import math
import sys
import warnings
from functools import partial, reduce

import numpy as np
import paddle
import paddle.fluid as fluid
from paddle import framework
from paddle.device import get_device, get_cudnn_version
from paddle.nn import initializer as I
from paddle.nn import Layer, LayerList
from paddle.fluid.layers import utils
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.layers.utils import map_structure, flatten, pack_sequence_as
from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.param_attr import ParamAttr
from paddle import _C_ops

__all__ = ['resnet_unit', 'ResNetUnit']


def resnet_unit(x, filter_x, scale_x, bias_x, mean_x, var_x, z, filter_z,
                scale_z, bias_z, mean_z, var_z, stride, stride_z, padding,
                dilation, groups, momentum, eps, data_format, fuse_add,
                has_shortcut, use_global_stats, is_test, act):

    helper = LayerHelper('resnet_unit', **locals())
    bn_param_dtype = fluid.core.VarDesc.VarType.FP32
    bit_mask_dtype = fluid.core.VarDesc.VarType.INT32
    out = helper.create_variable_for_type_inference(x.dtype)
    bit_mask = helper.create_variable_for_type_inference(
        dtype=bit_mask_dtype, stop_gradient=True)
    # intermediate_out for x
    conv_x = helper.create_variable_for_type_inference(
        dtype=x.dtype, stop_gradient=True)
    saved_mean_x = helper.create_variable_for_type_inference(
        dtype=bn_param_dtype, stop_gradient=True)
    saved_invstd_x = helper.create_variable_for_type_inference(
        dtype=bn_param_dtype, stop_gradient=True)
    running_mean_x = mean_x
    running_var_x = var_x
    # intermediate_out for z
    conv_z = helper.create_variable_for_type_inference(
        dtype=x.dtype, stop_gradient=True)
    saved_mean_z = helper.create_variable_for_type_inference(
        dtype=bn_param_dtype, stop_gradient=True)
    saved_invstd_z = helper.create_variable_for_type_inference(
        dtype=bn_param_dtype, stop_gradient=True)
    running_mean_z = helper.create_variable_for_type_inference(
        dtype=bn_param_dtype, stop_gradient=True) if mean_z is None else mean_z
    running_var_z = helper.create_variable_for_type_inference(
        dtype=bn_param_dtype, stop_gradient=True) if var_z is None else var_z

    inputs = {
        'X': x,
        'FilterX': filter_x,
        'ScaleX': scale_x,
        'BiasX': bias_x,
        'MeanX': mean_x,
        'VarX': var_x,
        'Z': z,
        'FilterZ': filter_z,
        'ScaleZ': scale_z,
        'BiasZ': bias_z,
        'MeanZ': mean_z,
        'VarZ': var_z
    }

    attrs = {
        'stride': stride,
        'stride_z': stride_z,
        'padding': padding,
        'dilation': dilation,
        'group': groups,
        'momentum': momentum,
        'epsilon': eps,
        'data_format': data_format,
        'fuse_add': fuse_add,
        'has_shortcut': has_shortcut,
        'use_global_stats': use_global_stats,
        'is_test': is_test,
        'act_type': act
    }

    outputs = {
        'Y': out,
        'BitMask': bit_mask,
        'ConvX': conv_x,
        'SavedMeanX': saved_mean_x,
        'SavedInvstdX': saved_invstd_x,
        'RunningMeanX': running_mean_x,
        'RunningVarX': running_var_x,
        'ConvZ': conv_z,
        'SavedMeanZ': saved_mean_z,
        'SavedInvstdZ': saved_invstd_z,
        'RunningMeanZ': running_mean_z,
        'RunningVarZ': running_var_z,
    }

    helper.append_op(type='resnet_unit', inputs=inputs, outputs=outputs, attrs=attrs)

    return out


class ResNetUnit(Layer):
    r"""
    ******Temporary version******.
    ResNetUnit is designed to optimize performance by using the cuDNN v8 API.
    """

    def __init__(self,
                 num_channels_x,
                 num_filters,
                 filter_size,
                 stride=1,
                 momentum=0.9,
                 eps=1e-5,
                 data_format='NHWC',
                 act='relu',
                 fuse_add=False,
                 has_shortcut=False,
                 use_global_stats=False,
                 is_test=False,
                 filter_x_attr=None,
                 scale_x_attr=None,
                 bias_x_attr=None,
                 moving_mean_x_name=None,
                 moving_var_x_name=None,
                 num_channels_z=1,
                 stride_z=1,
                 filter_z_attr=None,
                 scale_z_attr=None,
                 bias_z_attr=None,
                 moving_mean_z_name=None,
                 moving_var_z_name=None):
        super(ResNetUnit, self).__init__()
        self._stride = stride
        self._stride_z = stride_z
        self._dilation = 1
        self._kernel_size = utils.convert_to_list(filter_size, 2, 'kernel_size')
        self._padding = (filter_size - 1) // 2
        self._groups = 1
        self._momentum = momentum
        self._eps = eps
        self._data_format = data_format
        self._act = act
        self._fuse_add = fuse_add
        self._has_shortcut = has_shortcut
        self._use_global_stats = use_global_stats
        self._is_test = is_test

        # check format
        valid_format = {'NHWC'}
        if data_format not in valid_format:
            raise ValueError("conv_format must be one of {}, but got conv_format='{}'".
                             format(valid_format, data_format))

        def _get_default_param_initializer(channels):
            filter_elem_num = np.prod(self._kernel_size) * channels
            std = (2.0 / filter_elem_num)**0.5
            return I.Normal(0.0, std)

        # initial filter
        bn_param_dtype = fluid.core.VarDesc.VarType.FP32
        bn_param_shape = [1, 1, 1, num_filters]
        filter_x_shape = [num_filters, filter_size, filter_size, num_channels_x]
        filter_z_shape = [num_filters, filter_size, filter_size, num_channels_z]

        self.filter_x = self.create_parameter(
            shape=filter_x_shape,
            attr=filter_x_attr,
            default_initializer=_get_default_param_initializer(num_channels_x))
        self.scale_x = self.create_parameter(
            shape=bn_param_shape,
            attr=scale_x_attr,
            dtype=bn_param_dtype,
            default_initializer=I.Constant(1.0))
        self.bias_x = self.create_parameter(
            shape=bn_param_shape,
            attr=bias_x_attr,
            dtype=bn_param_dtype,
            is_bias=True)
        self.mean_x = self.create_parameter(
            attr=ParamAttr(
                name=moving_mean_x_name,
                initializer=I.Constant(0.0),
                trainable=False),
            shape=bn_param_shape,
            dtype=bn_param_dtype)
        self.mean_x.stop_gradient = True
        self.var_x = self.create_parameter(
            attr=ParamAttr(
                name=moving_var_x_name,
                initializer=I.Constant(1.0),
                trainable=False),
            shape=bn_param_shape,
            dtype=bn_param_dtype)
        self.var_x.stop_gradient = True
        if has_shortcut:
            self.filter_z = self.create_parameter(
                shape=filter_z_shape,
                attr=filter_z_attr,
                default_initializer=_get_default_param_initializer(num_channels_z))
            self.scale_z = self.create_parameter(
                shape=bn_param_shape,
                attr=scale_z_attr,
                dtype=bn_param_dtype,
                default_initializer=I.Constant(1.0))
            self.bias_z = self.create_parameter(
                shape=bn_param_shape,
                attr=bias_z_attr,
                dtype=bn_param_dtype,
                is_bias=True)
            self.mean_z = self.create_parameter(
                attr=ParamAttr(
                    name=moving_mean_z_name,
                    initializer=I.Constant(0.0),
                    trainable=False),
                shape=bn_param_shape,
                dtype=bn_param_dtype)
            self.mean_z.stop_gradient = True
            self.var_z = self.create_parameter(
                attr=ParamAttr(
                    name=moving_var_z_name,
                    initializer=I.Constant(1.0),
                    trainable=False),
                shape=bn_param_shape,
                dtype=bn_param_dtype)
            self.var_z.stop_gradient = True
        else:
            self.filter_z = None
            self.scale_z = None
            self.bias_z = None
            self.mean_z = None
            self.var_z = None

    def forward(self, x, z=None):
        if self._fuse_add and z is None:
            raise ValueError("z can not be None")

        out = resnet_unit(
            x, self.filter_x, self.scale_x, self.bias_x, self.mean_x,
            self.var_x, z, self.filter_z, self.scale_z, self.bias_z,
            self.mean_z, self.var_z, self._stride, self._stride_z,
            self._padding, self._dilation, self._groups, self._momentum,
            self._eps, self._data_format, self._fuse_add, self._has_shortcut,
            self._use_global_stats, self._is_test, self._act)
        return out
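For orientation, a minimal (hedged) construction of the layer above; the fused op targets NHWC fp16 tensors on GPUs with cuDNN v8, and in a real run the filter parameters are kept in fp16 as well (e.g. via pure-fp16 training), which this sketch does not show:

import paddle
from paddle.incubate.operators import ResNetUnit

# one fused residual-branch block: conv + BN (+ optional add) + ReLU
unit = ResNetUnit(
    num_channels_x=64,   # channels of the input feature map
    num_filters=64,      # output channels of the fused convolution
    filter_size=3,
    stride=1,
    data_format='NHWC',
    act='relu')

x = paddle.randn([8, 56, 56, 64]).astype('float16')  # NHWC input
y = unit(x)
print(y.shape)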