Commit e0866dc6 (unverified)

[hybrid] fused_feedforward op support tensor model parallel (#40160)

Authored by WangXi on Mar 09, 2022; committed via GitHub on Mar 09, 2022.
Parent: c1116b65
Showing 5 changed files with 476 additions and 5 deletions (+476 -5)
paddle/fluid/operators/fused/fused_feedforward_op.cc (+2 -0)
paddle/fluid/operators/fused/fused_feedforward_op.cu (+43 -5)
python/paddle/fluid/tests/unittests/CMakeLists.txt (+2 -0)
python/paddle/fluid/tests/unittests/static_model_parallel_fused_feedforward.py (+384 -0)
python/paddle/fluid/tests/unittests/test_static_model_parallel_fused_feedforward.py (+45 -0)
paddle/fluid/operators/fused/fused_feedforward_op.cc
...
@@ -195,6 +195,8 @@ class FusedFeedForwardOpMaker : public framework::OpProtoAndCheckerMaker {
         .SetDefault(false);
     AddAttr<int>("dropout1_seed", "Dropout1 random seed.").SetDefault(0);
     AddAttr<int>("dropout2_seed", "Dropout2 random seed.").SetDefault(0);
+    AddAttr<int>("ring_id", "ring id for tensor model parallel.")
+        .SetDefault(-1);
     AddComment(R"DOC(
   the function of fused_feedforward operator is the same as the following pseudo code:
   residual = src;
...
paddle/fluid/operators/fused/fused_feedforward_op.cu
...
@@ -21,11 +21,39 @@ limitations under the License. */
 #include "paddle/fluid/operators/fused/fused_dropout_helper.h"
 #include "paddle/fluid/operators/layer_norm_kernel.cu.h"
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+#include "paddle/fluid/platform/collective_helper.h"
+#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
+#endif

 namespace paddle {
 namespace operators {

 using Tensor = framework::Tensor;

+template <typename T>
+static void AllReduce(framework::Tensor& tensor,  // NOLINT
+                      const int ring_id,
+                      const platform::CUDADeviceContext& ctx) {
+  if (ring_id == -1) return;
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+  auto dtype = platform::ToNCCLDataType(
+      framework::TransToProtoVarType(tensor.dtype()));
+  int64_t numel = tensor.numel();
+  const void* sendbuff = tensor.data<T>();
+  auto place = ctx.GetPlace();
+  void* recvbuff = tensor.mutable_data<T>(place);
+  auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
+  auto stream = ctx.stream();
+  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
+      sendbuff, recvbuff, numel, dtype, ncclSum, comm->comm(), stream));
+#else
+  PADDLE_THROW(platform::errors::Unimplemented(
+      "PaddlePaddle should compile with NCCL or RCCL when used tensor model "
+      "parallel op."));
+#endif
+}
+
 template <typename DeviceContext, typename T>
 class FusedFeedForwardKernel : public framework::OpKernel<T> {
  public:
...
@@ -56,7 +84,7 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
            framework::Tensor* dropout1_out, framework::Tensor* dropout2_out,
            const int bsz_seq, const int d_model, const int dim_feedforward,
            const std::string& act_method, const bool pre_layer_norm,
-           const float epsilon1, const float epsilon2,
+           const float epsilon1, const float epsilon2, const int ring_id,
            const DropoutParam& dropout_param1,
            const DropoutParam& dropout_param2,
            const platform::CUDADeviceContext& ctx) const {
...
@@ -95,6 +123,10 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
     framework::Tensor linear2_out;
     linear2_out.mutable_data<T>({bsz_seq, d_model}, place);
     MatMul(ctx, *dropout1_out, linear2_weight, &linear2_out);
+    // tensor model parallel
+    AllReduce<T>(linear2_out, ring_id, ctx);
     if (!pre_layer_norm) {
       fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
           ctx, linear2_out.data<T>(), x.data<T>(), linear2_bias_ptr,
...
@@ -150,6 +182,7 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
     const float epsilon1 = context.Attr<float>("ln1_epsilon");
     const float epsilon2 = context.Attr<float>("ln2_epsilon");
+    const int ring_id = context.Attr<int>("ring_id");
     DropoutParam dropout_param1(context, 1);
     DropoutParam dropout_param2(context, 2);
...
@@ -186,7 +219,7 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
         dropout2_mask, ln1_mean, ln1_variance, ln2_mean, ln2_variance,
         linear1_out, ln1_out, dropout1_out, dropout2_out, bsz_seq, d_model,
         dim_feedforward, act_method, pre_layer_norm, epsilon1, epsilon2,
-        dropout_param1, dropout_param2, context.cuda_device_context());
+        ring_id, dropout_param1, dropout_param2, context.cuda_device_context());
   }
 };
...
@@ -231,7 +264,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
       const int dim_feedforward, const DropoutParam& dropout_param1,
       const DropoutParam& dropout_param2, const std::string& act_method,
       const bool pre_layer_norm, const float epsilon1, const float epsilon2,
-      const platform::CUDADeviceContext& ctx) const {
+      const int ring_id, const platform::CUDADeviceContext& ctx) const {
     FusedDropoutLayerNormHelper<T, uint8_t> pre_layernorm_helper(
         bsz_seq, d_model, epsilon1);
     FusedDropoutHelper<T, uint8_t> fused_act_dropout_helper(
...
@@ -295,13 +328,16 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
       d_ln1_out.mutable_data<T>({bsz_seq, d_model}, place);
       MatMulGrad(ctx, d_linear1_out, *ln1_out, linear1_weight, &d_ln1_out,
                  d_linear1_weight);
+      // tensor model parallel
+      AllReduce<T>(d_ln1_out, ring_id, ctx);
       pre_layernorm_helper.LayerNormGrad(
           ctx, d_ln1_out.data<T>(), x.data<T>(), ln1_gamma_ptr,
           ln1_mean->data<U>(), ln1_variance->data<U>(), d_x->data<T>(),
           d_ln1_gamma_ptr, d_ln1_beta_ptr);
     } else {
       MatMulGrad(ctx, d_linear1_out, x, linear1_weight, d_x, d_linear1_weight);
+      // tensor model parallel
+      AllReduce<T>(*d_x, ring_id, ctx);
     }
     std::vector<const Tensor*> ins(2);
     std::vector<Tensor*> outs(1);
...
@@ -376,6 +412,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
     const float epsilon1 = context.Attr<float>("ln1_epsilon");
     const float epsilon2 = context.Attr<float>("ln2_epsilon");
+    const int ring_id = context.Attr<int>("ring_id");
     const std::string act_method = context.Attr<std::string>("act_method");
     DropoutParam dropout_param1(context, 1);
     DropoutParam dropout_param2(context, 2);
...
@@ -419,7 +456,8 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
         d_linear1_bias, d_linear2_weight, d_linear2_bias, d_ln1_scale,
         d_ln1_bias, d_ln2_scale, d_ln2_bias, bsz_seq, d_model,
         dim_feedforward, dropout_param1, dropout_param2, act_method,
-        pre_layer_norm, epsilon1, epsilon2, context.cuda_device_context());
+        pre_layer_norm, epsilon1, epsilon2, ring_id,
+        context.cuda_device_context());
   }
 };

 }  // namespace operators
...
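Reader note: the kernel changes above insert a single all-reduce after the second GEMM in the forward pass and after the first GEMM's input gradient in the backward pass. Below is a minimal NumPy sketch (illustration only, not part of the commit) of the forward-pass identity this relies on, with dropout disabled and biases/layer norm omitted; the Python loop over ranks stands in for the NCCL ring identified by ring_id.

import numpy as np

rng = np.random.default_rng(0)
d_model, d_ffn, nranks = 4, 8, 2
x = rng.standard_normal((3, d_model))
w0 = rng.standard_normal((d_model, d_ffn))   # linear1 weight, split by columns
w1 = rng.standard_normal((d_ffn, d_model))   # linear2 weight, split by rows

# Single-device reference: linear1 -> relu -> linear2
ref = np.maximum(x @ w0, 0.0) @ w1

# Tensor-parallel version: each rank holds one column shard of w0 and the
# matching row shard of w1; the partial linear2 outputs are summed, which is
# what AllReduce(linear2_out, ring_id, ctx) performs across the ring.
shard = d_ffn // nranks
partials = [
    np.maximum(x @ w0[:, r * shard:(r + 1) * shard], 0.0)
    @ w1[r * shard:(r + 1) * shard, :]
    for r in range(nranks)
]
out = sum(partials)  # stands in for the ring all-reduce

assert np.allclose(out, ref)

The same argument applied to the backward pass motivates the two added AllReduce calls on d_ln1_out and *d_x: the gradient with respect to the layer input is a sum of per-rank partial products through the column-parallel linear1 weight.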
python/paddle/fluid/tests/unittests/CMakeLists.txt
...
@@ -23,6 +23,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_mnist)
 list(APPEND DIST_TEST_OPS test_pipeline)
 list(APPEND DIST_TEST_OPS test_ir_pass_pipeline)
 list(APPEND DIST_TEST_OPS test_static_model_parallel)
+list(APPEND DIST_TEST_OPS test_static_model_parallel_fused_feedforward)
 list(APPEND DIST_TEST_OPS test_parallel_dygraph_se_resnext)
 list(APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding)
 list(APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding_over_height)
...
@@ -1150,6 +1151,7 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32)
     set_tests_properties(test_pipeline PROPERTIES TIMEOUT 120)
     set_tests_properties(test_ir_pass_pipeline PROPERTIES TIMEOUT 120)
     set_tests_properties(test_static_model_parallel PROPERTIES TIMEOUT 240)
+    set_tests_properties(test_static_model_parallel_fused_feedforward PROPERTIES TIMEOUT 120)
     set_tests_properties(test_collective_split_embedding
                          test_collective_split_embedding_none_divisible
                          test_collective_split_row_linear
...
python/paddle/fluid/tests/unittests/static_model_parallel_fused_feedforward.py
0 → 100644
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import numpy as np
import paddle
import paddle.fluid as fluid
from test_dist_base import TestDistRunnerBase, runtime_main
import paddle.distributed.fleet as fleet

from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.fluid.dygraph.layers import Layer
from paddle.fluid.layer_helper import LayerHelper
from paddle.nn.initializer import Constant

paddle.enable_static()

DTYPE = "float32"
MODEL_PARALLEL_SIZE = 2
IN_SIZE = 2 * MODEL_PARALLEL_SIZE
OUT_SIZE = 2 * MODEL_PARALLEL_SIZE


def fused_feedforward(x,
                      linear1_weight,
                      linear2_weight,
                      linear1_bias=None,
                      linear2_bias=None,
                      ln1_scale=None,
                      ln1_bias=None,
                      ln2_scale=None,
                      ln2_bias=None,
                      dropout1_rate=0.5,
                      dropout2_rate=0.5,
                      activation="relu",
                      ln1_epsilon=1e-5,
                      ln2_epsilon=1e-5,
                      pre_layer_norm=False,
                      training=True,
                      mode='upscale_in_train',
                      ring_id=-1,
                      name=None):
    seed = None
    if mode not in ('downscale_in_infer', 'upscale_in_train'):
        raise ValueError(
            "mode argument should be 'downscale_in_infer' or 'upscale_in_train'")
    mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode  # semantic transfer

    helper = LayerHelper("fused_feedforward")
    dtype = x.dtype
    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                             'fused_feedforward')
    check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'],
                'fused_feedforward')

    out = helper.create_variable_for_type_inference(x.dtype)
    dropout1_mask = helper.create_variable_for_type_inference(
        'uint8', stop_gradient=True)
    dropout2_mask = helper.create_variable_for_type_inference(
        'uint8', stop_gradient=True)
    ln1_mean = helper.create_variable_for_type_inference(
        x.dtype, stop_gradient=True)
    ln1_variance = helper.create_variable_for_type_inference(
        x.dtype, stop_gradient=True)
    ln2_mean = helper.create_variable_for_type_inference(
        x.dtype, stop_gradient=True)
    ln2_variance = helper.create_variable_for_type_inference(
        x.dtype, stop_gradient=True)
    linear1_out = helper.create_variable_for_type_inference(
        x.dtype, stop_gradient=True)
    ln1_out = helper.create_variable_for_type_inference(
        x.dtype, stop_gradient=True)
    dropout1_out = helper.create_variable_for_type_inference(
        x.dtype, stop_gradient=True)
    dropout2_out = helper.create_variable_for_type_inference(
        x.dtype, stop_gradient=True)

    if (seed is None or seed == 0) and helper.main_program.random_seed != 0:
        seed = helper.main_program.random_seed

    helper.append_op(
        type='fused_feedforward',
        inputs={
            'X': x,
            'Linear1Weight': linear1_weight,
            'Linear1Bias': linear1_bias,
            'Linear2Weight': linear2_weight,
            'Linear2Bias': linear2_bias,
            'Ln1Scale': ln1_scale,
            'Ln1Bias': ln1_bias,
            'Ln2Scale': ln2_scale,
            'Ln2Bias': ln2_bias,
        },
        outputs={
            'Out': out,
            'Dropout1Mask': dropout1_mask,
            'Dropout2Mask': dropout2_mask,
            'Ln1Mean': ln1_mean,
            'Ln1Variance': ln1_variance,
            'Ln2Mean': ln2_mean,
            'Ln2Variance': ln2_variance,
            'Linear1Out': linear1_out,
            'Ln1Out': ln1_out,
            'Dropout1Out': dropout1_out,
            'Dropout2Out': dropout2_out,
        },
        attrs={
            'dropout1_rate': dropout1_rate,
            'dropout2_rate': dropout2_rate,
            'act_method': activation,
            'pre_layer_norm': pre_layer_norm,
            'ln1_epsilon': ln1_epsilon,
            'ln2_epsilon': ln2_epsilon,
            'dropout1_is_test': not training,
            'dropout2_is_test': not training,
            'dropout1_fix_seed': seed is not None,
            'dropout2_fix_seed': seed is not None,
            'dropout1_seed': seed if seed is not None else 0,
            'dropout2_seed': seed if seed is not None else 0,
            'dropout1_implementation': mode,
            'dropout2_implementation': mode,
            'ring_id': ring_id,
        })
    return out


def _set_var_distributed(var):
    if var is None:
        return

    var.is_distributed = True

    # NOTE: use current_block and find_var_recursive to support while_loop
    startup_block = paddle.static.default_startup_program().current_block()
    main_block = paddle.static.default_main_program().current_block()
    startup_block._find_var_recursive(var.name).is_distributed = True
    main_block._find_var_recursive(var.name).is_distributed = True


class ParallelFusedFeedForward(Layer):
    def __init__(self,
                 d_model,
                 dim_feedforward,
                 dropout_rate=0.1,
                 epsilon=1e-05,
                 activation="relu",
                 act_dropout_rate=None,
                 normalize_before=False,
                 linear1_weight_attr=None,
                 linear1_bias_attr=None,
                 linear2_weight_attr=None,
                 linear2_bias_attr=None,
                 ln1_scale_attr=None,
                 ln1_bias_attr=None,
                 ln2_scale_attr=None,
                 ln2_bias_attr=None,
                 nranks=1,
                 ring_id=-1,
                 name=None):
        super(ParallelFusedFeedForward, self).__init__()
        assert d_model > 0, (
            "Expected d_model to be greater than 0, but recieved {}".format(
                d_model))
        assert dim_feedforward > 0, (
            "Expected dim_feedforward to be greater than 0, but recieved {}".
            format(dim_feedforward))

        self._dtype = self._helper.get_default_dtype()
        self._d_model = d_model

        assert dim_feedforward % nranks == 0
        dim_feedforward = dim_feedforward // nranks
        self._dim_feedforward = dim_feedforward
        self._dropout_rate = dropout_rate
        self._act_dropout_rate = dropout_rate if act_dropout_rate is None else act_dropout_rate
        self._act_method = activation
        self._normalize_before = normalize_before
        self._epsilon = epsilon
        self._ring_id = ring_id

        self._linear1_weight = self.create_parameter(
            shape=[d_model, dim_feedforward],
            attr=linear1_weight_attr,
            dtype=self._dtype,
            is_bias=False)
        self._linear1_bias = self.create_parameter(
            shape=[dim_feedforward],
            attr=linear1_bias_attr,
            dtype=self._dtype,
            is_bias=True)

        self._linear2_weight = self.create_parameter(
            shape=[dim_feedforward, d_model],
            attr=linear2_weight_attr,
            dtype=self._dtype,
            is_bias=False)
        self._linear2_bias = self.create_parameter(
            shape=[d_model],
            attr=linear2_bias_attr,
            dtype=self._dtype,
            is_bias=True)

        if nranks > 1:
            assert ring_id != -1
            # column parallel
            _set_var_distributed(self._linear1_weight)
            _set_var_distributed(self._linear1_bias)
            _set_var_distributed(self._linear2_weight)

        if normalize_before:
            self._ln1_scale = self.create_parameter(
                shape=[d_model],
                attr=ln1_scale_attr,
                is_bias=False,
                default_initializer=Constant(1.0))
            self._ln1_bias = self.create_parameter(
                shape=[d_model], attr=ln1_bias_attr, is_bias=True)
            self._ln2_scale = None
            self._ln2_bias = None
        else:
            self._ln1_bias = None
            self._ln2_bias = None
            self._ln2_scale = self.create_parameter(
                shape=[d_model],
                attr=ln2_scale_attr,
                is_bias=False,
                default_initializer=Constant(1.0))
            self._ln2_bias = self.create_parameter(
                shape=[d_model], attr=ln2_bias_attr, is_bias=True)

        self.name = name

    def forward(self, src, cache=None):
        out = fused_feedforward(
            src,
            self._linear1_weight,
            self._linear2_weight,
            self._linear1_bias,
            self._linear2_bias,
            self._ln1_scale,
            self._ln1_bias,
            self._ln2_scale,
            self._ln2_bias,
            dropout1_rate=self._act_dropout_rate,
            dropout2_rate=self._dropout_rate,
            activation=self._act_method,
            ln1_epsilon=self._epsilon,
            ln2_epsilon=self._epsilon,
            pre_layer_norm=self._normalize_before,
            training=self.training,
            ring_id=self._ring_id,
            name=self.name)
        return out


def get_param_attr(weight, bias):
    weight_attr = paddle.ParamAttr(
        initializer=fluid.initializer.NumpyArrayInitializer(weight))
    bias_attr = paddle.ParamAttr(
        initializer=fluid.initializer.NumpyArrayInitializer(bias))
    return weight_attr, bias_attr


def create_model(data, rank):
    np.random.seed(2021)
    ln_w = np.random.uniform(-1, 1, size=(IN_SIZE, )).astype(DTYPE)
    ln_b = np.random.uniform(-1, 1, size=(IN_SIZE, )).astype(DTYPE)
    w0 = np.random.uniform(-1, 1, size=(IN_SIZE, OUT_SIZE)).astype(DTYPE)
    b0 = np.random.uniform(-1, 1, size=(OUT_SIZE, )).astype(DTYPE)
    w1 = np.random.uniform(-1, 1, size=(OUT_SIZE, IN_SIZE)).astype(DTYPE)
    b1 = np.random.uniform(-1, 1, size=(IN_SIZE, )).astype(DTYPE)
    data.stop_gradient = False
    if rank is not None:
        start = 0 if rank == 0 else OUT_SIZE // MODEL_PARALLEL_SIZE
        end = start + OUT_SIZE // MODEL_PARALLEL_SIZE
        col_w0 = w0[:, start:end]
        col_b0 = b0[start:end]
        row_w1 = w1[start:end, :]

        ln_w_attr, ln_b_attr = get_param_attr(ln_w, ln_b)
        w0_attr, b0_attr = get_param_attr(col_w0, col_b0)
        w1_attr, b1_attr = get_param_attr(row_w1, b1)

        ffn = ParallelFusedFeedForward(
            IN_SIZE,
            OUT_SIZE,
            dropout_rate=0.0,
            activation='gelu',
            normalize_before=True,
            linear1_weight_attr=w0_attr,
            linear1_bias_attr=b0_attr,
            linear2_weight_attr=w1_attr,
            linear2_bias_attr=b1_attr,
            ln1_scale_attr=ln_w_attr,
            ln1_bias_attr=ln_b_attr,
            nranks=MODEL_PARALLEL_SIZE,
            ring_id=0)
        # ffn.eval()
        result = ffn(data)
    else:
        ln_w_attr, ln_b_attr = get_param_attr(ln_w, ln_b)
        w0_attr, b0_attr = get_param_attr(w0, b0)
        w1_attr, b1_attr = get_param_attr(w1, b1)

        ffn = ParallelFusedFeedForward(
            IN_SIZE,
            OUT_SIZE,
            dropout_rate=0.0,
            activation='gelu',
            normalize_before=True,
            linear1_weight_attr=w0_attr,
            linear1_bias_attr=b0_attr,
            linear2_weight_attr=w1_attr,
            linear2_bias_attr=b1_attr,
            ln1_scale_attr=ln_w_attr,
            ln1_bias_attr=ln_b_attr)
        # ffn.eval()
        result = ffn(data)

    predict = paddle.sum(result)
    return predict


class TestModelParallel(TestDistRunnerBase):
    def get_model(self, batch_size=2, use_dgc=False, dist_strategy=None):
        # Input data
        seq_len = 2
        data_in = fluid.data(
            name='data_in', shape=[batch_size, seq_len, IN_SIZE], dtype=DTYPE)

        if dist_strategy:
            data_loader = fluid.io.DataLoader.from_generator(
                feed_list=[data_in],
                capacity=64,
                use_double_buffer=False,
                iterable=False)

        if dist_strategy:
            fleet.init(is_collective=True)
            strategy = fleet.DistributedStrategy()
            strategy.tensor_parallel = True
            strategy.tensor_parallel_configs = {'tensor_parallel_degree': 2}

        rank = fleet.worker_index() if dist_strategy else None
        avg_cost = create_model(data_in, rank)
        opt = fluid.optimizer.SGD(0.1)

        if dist_strategy:
            dist_opt = fleet.distributed_optimizer(
                optimizer=opt, strategy=strategy)
            dist_opt.minimize(avg_cost)
        else:
            opt.minimize(avg_cost)

        def gen_data():
            np.random.seed(2021)
            while True:
                data = [np.random.random([seq_len, IN_SIZE]).astype(DTYPE)]
                yield data

        train_reader = paddle.batch(gen_data, batch_size=batch_size)

        if dist_strategy:
            return None, avg_cost, train_reader, None, None, None, data_loader
        else:
            return None, avg_cost, train_reader, None, None, None


if __name__ == "__main__":
    runtime_main(TestModelParallel)
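Reader note: create_model above applies the usual column-parallel / row-parallel split. linear1's weight and bias are sharded by columns, linear2's weight by rows, while the linear2 bias b1 stays replicated on every rank, consistent with the kernel adding it after the all-reduce on linear2_out. A small NumPy sketch of the per-rank shard shapes follows (not part of the commit; sizes are illustrative, and the generalized rank * shard indexing is an assumption, since the test itself only handles two ranks):

import numpy as np

IN_SIZE, OUT_SIZE, MODEL_PARALLEL_SIZE = 4, 4, 2
w0 = np.zeros((IN_SIZE, OUT_SIZE))   # full linear1 weight
w1 = np.zeros((OUT_SIZE, IN_SIZE))   # full linear2 weight
shard = OUT_SIZE // MODEL_PARALLEL_SIZE

for rank in range(MODEL_PARALLEL_SIZE):
    start, end = rank * shard, (rank + 1) * shard
    col_w0 = w0[:, start:end]        # passed via linear1_weight_attr
    row_w1 = w1[start:end, :]        # passed via linear2_weight_attr
    assert col_w0.shape == (IN_SIZE, shard)
    assert row_w1.shape == (shard, IN_SIZE)

Because ParallelFusedFeedForward divides dim_feedforward by nranks in its constructor, these shard shapes are exactly what its create_parameter calls expect.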
python/paddle/fluid/tests/unittests/test_static_model_parallel_fused_feedforward.py
0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import unittest
from test_dist_base import TestDistBase

import os
import paddle

paddle.enable_static()
flag_name = os.path.splitext(__file__)[0]


class TestStaticModelParallel(TestDistBase):
    def _setup_config(self):
        self._sync_mode = True
        self._use_reduce = False
        self._use_reader_alloc = False
        self._nccl_comm_num = 1
        self._pipeline_mode = True

    def test_dist_static_model_parallel_fused_feedforward(self):
        import paddle.fluid as fluid
        if fluid.core.is_compiled_with_cuda():
            self.check_with_place(
                "static_model_parallel_fused_feedforward.py",
                delta=1e-5,
                check_error_log=True,
                log_name=flag_name)


if __name__ == '__main__':
    unittest.main()