Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
3a59ede9
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3a59ede9
编写于
7月 01, 2022
作者:
C
Chenxiao Niu
提交者:
GitHub
7月 01, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[MLU] add rnn backward kernel. (#43969)
上级
9f397f16
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
522 addition
and
36 deletion
+522
-36
paddle/fluid/operators/mlu/mlu_baseop.cc
paddle/fluid/operators/mlu/mlu_baseop.cc
+82
-0
paddle/fluid/operators/mlu/mlu_baseop.h
paddle/fluid/operators/mlu/mlu_baseop.h
+24
-0
paddle/fluid/operators/rnn_op_mlu.cc
paddle/fluid/operators/rnn_op_mlu.cc
+385
-12
python/paddle/fluid/tests/unittests/mlu/test_rnn_op_mlu.py
python/paddle/fluid/tests/unittests/mlu/test_rnn_op_mlu.py
+31
-24
未找到文件。
paddle/fluid/operators/mlu/mlu_baseop.cc
浏览文件 @
3a59ede9
...
...
@@ -4616,6 +4616,88 @@ MLURNNDesc::~MLURNNDesc() {
reservespace_size
));
}
/* static */
void
MLUCnnl
::
RNNBackward
(
const
ExecutionContext
&
ctx
,
const
cnnlRNNDescriptor_t
rnn_desc
,
cnnlWgradMode_t
add_grad
,
const
int
dev_seq_lengths
[],
const
void
*
weight_param_ptr
,
void
*
dweight_param_ptr
,
size_t
weightspace_size
,
const
cnnlSeqDataDescriptor_t
x_desc
,
const
void
*
x
,
void
*
dx
,
const
cnnlSeqDataDescriptor_t
y_desc
,
const
void
*
y
,
const
void
*
dy
,
const
cnnlTensorDescriptor_t
hx_desc
,
const
void
*
hx
,
const
void
*
dhy
,
void
*
dhx
,
const
cnnlTensorDescriptor_t
cx_desc
,
const
void
*
cx
,
const
void
*
dcy
,
void
*
dcx
,
void
*
reservespace_ptr
,
size_t
reservespace_size
)
{
cnnlHandle_t
handle
=
GetHandleFromCTX
(
ctx
);
PADDLE_ENFORCE_NOT_NULL
(
rnn_desc
,
paddle
::
platform
::
errors
::
Fatal
(
"MLU RNNForward failed. rnn_desc initializing failed."
));
PADDLE_ENFORCE_NOT_NULL
(
x_desc
,
paddle
::
platform
::
errors
::
Fatal
(
"MLU RNNForward failed. x_desc initializing failed."
));
auto
&
dev_ctx
=
GetDevCtxFromCTX
(
ctx
);
size_t
workspace_size
;
Tensor
workspace
;
PADDLE_ENFORCE_MLU_SUCCESS
(
cnnlGetRNNTempSizes
(
handle
,
rnn_desc
,
x_desc
,
&
workspace_size
,
&
reservespace_size
));
workspace
=
ctx
.
AllocateTmpTensor
<
int8_t
,
MLUDeviceContext
>
(
{
static_cast
<
int64_t
>
(
workspace_size
)},
dev_ctx
);
void
*
workspace_ptr
=
workspace
.
mutable_data
(
ctx
.
GetPlace
());
PADDLE_ENFORCE_MLU_SUCCESS
(
cnnlRNNBackwardData
(
handle
,
rnn_desc
,
dev_seq_lengths
,
y_desc
,
y
,
dy
,
x_desc
,
dx
,
hx_desc
,
hx
,
dhy
,
dhx
,
cx_desc
,
cx
,
dcy
,
dcx
,
weight_param_ptr
,
weightspace_size
,
workspace_ptr
,
workspace_size
,
reservespace_ptr
,
reservespace_size
));
PADDLE_ENFORCE_MLU_SUCCESS
(
cnnlRNNBackwardWeights
(
handle
,
rnn_desc
,
add_grad
,
dev_seq_lengths
,
x_desc
,
x
,
hx_desc
,
hx
,
y_desc
,
y
,
dweight_param_ptr
,
weightspace_size
,
workspace_ptr
,
workspace_size
,
reservespace_ptr
,
reservespace_size
));
}
/* static */
void
MLUCnnl
::
Mask
(
const
ExecutionContext
&
ctx
,
cnnlMaskedOp_t
masked_mode
,
const
cnnlTensorDescriptor_t
input_desc
,
...
...
paddle/fluid/operators/mlu/mlu_baseop.h
浏览文件 @
3a59ede9
...
...
@@ -1924,6 +1924,30 @@ class MLUCnnl {
void
*
cy
,
void
*
reservespace_ptr
);
static
void
RNNBackward
(
const
ExecutionContext
&
ctx
,
const
cnnlRNNDescriptor_t
rnn_desc
,
cnnlWgradMode_t
add_grad
,
const
int
dev_seq_lengths
[],
const
void
*
weight_param_ptr
,
void
*
dweight_param_ptr
,
size_t
weightspace_size
,
const
cnnlSeqDataDescriptor_t
x_desc
,
const
void
*
x
,
void
*
dx
,
const
cnnlSeqDataDescriptor_t
y_desc
,
const
void
*
y
,
const
void
*
dy
,
const
cnnlTensorDescriptor_t
hx_desc
,
const
void
*
hx
,
const
void
*
dhy
,
void
*
dhx
,
const
cnnlTensorDescriptor_t
cx_desc
,
const
void
*
cx
,
const
void
*
dcy
,
void
*
dcx
,
void
*
reservespace_ptr
,
size_t
reservespace_size
);
static
void
Mask
(
const
ExecutionContext
&
ctx
,
cnnlMaskedOp_t
masked_mode
,
const
cnnlTensorDescriptor_t
input_desc
,
...
...
paddle/fluid/operators/rnn_op_mlu.cc
浏览文件 @
3a59ede9
...
...
@@ -28,7 +28,7 @@ void reset_parameter_vector(
const
std
::
vector
<
TensorType
>&
raw_params_vec
,
const
int
&
num_layers
,
const
bool
&
is_bidirec
,
std
::
vector
<
std
::
vector
<
std
::
pair
<
const
T
*
,
size_t
>>>*
params_vec
)
{
std
::
vector
<
std
::
vector
<
std
::
pair
<
T
*
,
size_t
>>>*
params_vec
)
{
// the parameter raw seuquence is [FWhi, FWhh, BWhi, BWhh] * num_layers
// + [FBhi, FBhh, BBhi, BBhh] * num_layers, we will reset the parameter to
// ([FWhi, FWhh, FBhi, FBhh] + [BWhi, BWhh, BBhi, BBhh]) * num_layers
...
...
@@ -47,7 +47,8 @@ void reset_parameter_vector(
}
using
remove_cv_t
=
typename
std
::
remove_cv
<
T
>::
type
;
params_vec
->
at
(
i
)[
j
]
=
std
::
make_pair
(
raw_params_vec
[
tensor_idx
]
->
template
data
<
remove_cv_t
>(),
const_cast
<
T
*>
(
raw_params_vec
[
tensor_idx
]
->
template
data
<
remove_cv_t
>()),
raw_params_vec
[
tensor_idx
]
->
numel
()
*
sizeof
(
T
));
}
}
...
...
@@ -66,7 +67,6 @@ class RNNMLUKernel : public framework::OpKernel<T> {
// Output
auto
state
=
ctx
.
MultiOutput
<
Tensor
>
(
"State"
);
auto
*
output
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
// auto* dropout_mask = ctx.Output<Tensor>("DropoutState");
auto
*
reserve_data
=
ctx
.
Output
<
Tensor
>
(
"Reserve"
);
// Attributes
const
int
&
num_layers
=
ctx
.
Attr
<
int
>
(
"num_layers"
);
...
...
@@ -79,14 +79,6 @@ class RNNMLUKernel : public framework::OpKernel<T> {
sequence_length
=
ctx
.
Input
<
Tensor
>
(
"SequenceLength"
);
}
// if (dropout_mask->IsInitialized()) {
// if (dropout_mask->numel() != output->numel()) dropout_mask->clear();
// }
// dropout_mask->mutable_data<uint8_t>(output->dims(), ctx.GetPlace());
// auto& dev_ctx = ctx.template device_context<DeviceContext>();
// phi::funcs::SetConstant<platform::XPUDeviceContext, uint8_t> ones;
// ones(dev_ctx, dropout_mask, static_cast<uint8_t>(1));
auto
init_h
=
pre_state
[
0
];
// -> hx
auto
init_c
=
pre_state
[
1
];
// -> cx
auto
last_h
=
state
[
0
];
...
...
@@ -143,7 +135,7 @@ class RNNMLUKernel : public framework::OpKernel<T> {
init_c
->
dims
()[
0
]));
// weightlist
std
::
vector
<
std
::
vector
<
std
::
pair
<
const
T
*
,
size_t
>>>
parameter_lists
;
std
::
vector
<
std
::
vector
<
std
::
pair
<
T
*
,
size_t
>>>
parameter_lists
;
parameter_lists
.
resize
(
num_layers
);
reset_parameter_vector
(
weight_list
,
num_layers
,
is_bidirec
,
&
parameter_lists
);
...
...
@@ -363,9 +355,390 @@ class RNNMLUKernel : public framework::OpKernel<T> {
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
RNNMLUGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
auto
stream
=
ctx
.
template
device_context
<
MLUDeviceContext
>().
stream
();
// get the tensor pointer for the input
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"Input"
);
auto
pre_state
=
ctx
.
MultiInput
<
Tensor
>
(
"PreState"
);
auto
weight_list
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"WeightList"
);
auto
*
output
=
ctx
.
Input
<
Tensor
>
(
"Out"
);
auto
*
reserve_data
=
ctx
.
Input
<
Tensor
>
(
"Reserve"
);
const
int
&
num_layers
=
ctx
.
Attr
<
int
>
(
"num_layers"
);
const
bool
&
is_bidirec
=
ctx
.
Attr
<
bool
>
(
"is_bidirec"
);
const
int
&
hidden_size
=
ctx
.
Attr
<
int
>
(
"hidden_size"
);
const
std
::
string
&
mode
=
ctx
.
Attr
<
std
::
string
>
(
"mode"
);
bool
has_seq_length
=
ctx
.
HasInput
(
"SequenceLength"
);
const
Tensor
*
sequence_length
=
nullptr
;
if
(
has_seq_length
)
{
sequence_length
=
ctx
.
Input
<
Tensor
>
(
"SequenceLength"
);
}
PADDLE_ENFORCE_EQ
(
mode
,
"LSTM"
,
platform
::
errors
::
InvalidArgument
(
"XPU only support LSTM mode now, current mode is %s"
,
mode
));
auto
init_h
=
pre_state
[
0
];
// -> hx
auto
init_c
=
pre_state
[
1
];
// -> cx
auto
output_grad
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
state_grad
=
ctx
.
MultiInput
<
Tensor
>
(
framework
::
GradVarName
(
"State"
));
auto
last_h_grad
=
state_grad
[
0
];
// -> dhy
auto
last_c_grad
=
state_grad
[
1
];
// -> dcy
// get the tensor pointer for the output
auto
*
input_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Input"
));
auto
weight_grad_list
=
ctx
.
MultiOutput
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"WeightList"
));
auto
pre_state_grad
=
ctx
.
MultiOutput
<
Tensor
>
(
framework
::
GradVarName
(
"PreState"
));
Tensor
*
init_h_grad
=
nullptr
;
Tensor
*
init_c_grad
=
nullptr
;
if
(
pre_state_grad
.
size
()
>
0
)
{
// has gradient
init_h_grad
=
pre_state_grad
[
0
];
// -> dhx
init_c_grad
=
pre_state_grad
[
1
];
// -> dcx
}
// check shape
const
int
in_out_dim_num
=
input
->
dims
().
size
();
const
int
&
seq_len
=
input
->
dims
()[
0
];
const
int
&
batch_size
=
input
->
dims
()[
1
];
const
int
&
input_dim
=
input
->
dims
()[
2
];
const
int
&
direction_num
=
is_bidirec
?
2
:
1
;
int
in_dim_arr
[
in_out_dim_num
]
=
{
seq_len
,
batch_size
,
input_dim
};
int
out_dim_arr
[
in_out_dim_num
]
=
{
seq_len
,
batch_size
,
direction_num
*
hidden_size
};
int
proj_size
=
hidden_size
;
PADDLE_ENFORCE_EQ
(
num_layers
,
1
,
platform
::
errors
::
InvalidArgument
(
"MLU only support 1 num_layers, current num_layers is %s"
,
num_layers
));
PADDLE_ENFORCE_EQ
(
init_h
->
dims
()[
0
],
num_layers
*
direction_num
,
platform
::
errors
::
InvalidArgument
(
"The num_layers of in RNN layer must"
" be the same as first dim of init"
"hidden, but received num_layers:%d,"
" dim:%d"
,
num_layers
,
init_h
->
dims
()[
0
]));
PADDLE_ENFORCE_EQ
(
init_c
->
dims
()[
0
],
num_layers
*
direction_num
,
platform
::
errors
::
InvalidArgument
(
"The num_layers of in RNN layer must"
" be the same as first dim of cell state hidden, but received"
" num_layers:%d, dim:%d"
,
num_layers
,
init_c
->
dims
()[
0
]));
std
::
vector
<
std
::
vector
<
std
::
pair
<
T
*
,
size_t
>>>
parameter_lists
;
parameter_lists
.
resize
(
num_layers
);
reset_parameter_vector
(
weight_list
,
num_layers
,
is_bidirec
,
&
parameter_lists
);
for
(
unsigned
int
i
=
0
;
i
<
weight_grad_list
.
size
();
++
i
)
{
weight_grad_list
[
i
]
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
}
std
::
vector
<
std
::
vector
<
std
::
pair
<
T
*
,
size_t
>>>
parameter_lists_grad
;
parameter_lists_grad
.
resize
(
num_layers
);
reset_parameter_vector
(
weight_grad_list
,
num_layers
,
is_bidirec
,
&
parameter_lists_grad
);
// allocate the memory and initization the input_grad
input_grad
->
mutable_data
<
T
>
(
input
->
dims
(),
ctx
.
GetPlace
());
FillMLUTensorWithHostValue
(
ctx
,
static_cast
<
T
>
(
0.0
),
input_grad
);
Tensor
a
,
b
;
Tensor
*
dynamic_grad_pre_h
=
&
a
;
Tensor
*
dynamic_grad_pre_c
=
&
b
;
if
(
init_h_grad
)
{
init_h_grad
->
mutable_data
<
T
>
(
last_h_grad
->
dims
(),
ctx
.
GetPlace
());
FillMLUTensorWithHostValue
(
ctx
,
static_cast
<
T
>
(
0.0
),
init_h_grad
);
}
else
{
dynamic_grad_pre_h
->
Resize
(
last_h_grad
->
dims
());
dynamic_grad_pre_h
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
FillMLUTensorWithHostValue
(
ctx
,
static_cast
<
T
>
(
0.0
),
dynamic_grad_pre_h
);
init_h_grad
=
dynamic_grad_pre_h
;
}
if
(
init_c_grad
)
{
init_c_grad
->
mutable_data
<
T
>
(
last_c_grad
->
dims
(),
ctx
.
GetPlace
());
}
else
{
dynamic_grad_pre_c
->
Resize
(
last_h_grad
->
dims
());
dynamic_grad_pre_c
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
init_c_grad
=
dynamic_grad_pre_c
;
}
std
::
vector
<
int
>
seq_len_vec
(
batch_size
,
seq_len
);
if
(
has_seq_length
)
{
seq_len_vec
=
operators
::
GetDataFromTensor
(
sequence_length
);
}
cnnlDirectionMode_t
direction
=
is_bidirec
?
CNNL_RNN_BIDIRECTIONAL
:
CNNL_RNN_UNIDIRECTIONAL
;
MLUSeqDataDesc
input_seq_data_desc
(
CNNL_SEQDATA_TNC
,
ToCnnlDataType
(
input
->
dtype
()),
in_out_dim_num
,
in_dim_arr
,
static_cast
<
int
>
(
seq_len_vec
.
size
()),
seq_len_vec
.
data
(),
nullptr
);
MLUSeqDataDesc
out_seq_data_desc
(
CNNL_SEQDATA_TNC
,
ToCnnlDataType
(
input
->
dtype
()),
in_out_dim_num
,
out_dim_arr
,
static_cast
<
int
>
(
seq_len_vec
.
size
()),
seq_len_vec
.
data
(),
nullptr
);
MLUCnnlTensorDesc
hx_desc
(
*
init_h
);
MLUCnnlTensorDesc
cx_desc
(
*
init_c
);
MLURNNDesc
rnn_desc
(
CNNL_LSTM
,
CNNL_RNN_DOUBLE_BIAS
,
direction
,
CNNL_RNN_LINEAR_INPUT
,
ToCnnlDataType
(
input
->
dtype
()),
ToCnnlDataType
(
input
->
dtype
()),
input_dim
,
hidden_size
,
/*projection*/
proj_size
,
num_layers
,
nullptr
,
CNNL_RNN_PADDED_IO_DISABLED
);
rnn_desc
.
SetRNNMaskMode
(
CNNL_LSTM_MASK_ENABLED
);
// copy weight
size_t
weightspace_size
;
framework
::
Tensor
weightspace
,
dweightspace
;
PADDLE_ENFORCE_MLU_SUCCESS
(
cnnlGetRNNWeightSpaceSize
(
GetHandleFromCTX
(
ctx
),
rnn_desc
.
get
(),
&
weightspace_size
));
weightspace
=
ctx
.
AllocateTmpTensor
<
T
,
DeviceContext
>
(
{
static_cast
<
int64_t
>
(
weightspace_size
)},
dev_ctx
);
dweightspace
=
ctx
.
AllocateTmpTensor
<
T
,
DeviceContext
>
(
{
static_cast
<
int64_t
>
(
weightspace_size
)},
dev_ctx
);
void
*
weightspace_ptr
=
weightspace
.
mutable_data
(
ctx
.
GetPlace
());
auto
w_x
=
parameter_lists
[
0
][
0
];
auto
w_h
=
parameter_lists
[
0
][
1
];
auto
b_x
=
parameter_lists
[
0
][
2
];
auto
b_h
=
parameter_lists
[
0
][
3
];
auto
actual_total_w_size
=
w_x
.
second
+
w_h
.
second
+
b_x
.
second
+
b_h
.
second
;
void
*
w_x_ptr
=
weightspace_ptr
;
void
*
w_h_ptr
=
static_cast
<
char
*>
(
weightspace_ptr
)
+
w_x
.
second
;
void
*
b_x_ptr
=
static_cast
<
char
*>
(
weightspace_ptr
)
+
w_x
.
second
+
w_h
.
second
;
void
*
b_h_ptr
=
static_cast
<
char
*>
(
weightspace_ptr
)
+
w_x
.
second
+
w_h
.
second
+
b_x
.
second
;
memory
::
Copy
(
weightspace
.
place
(),
w_x_ptr
,
weightspace
.
place
(),
w_x
.
first
,
w_x
.
second
,
stream
);
memory
::
Copy
(
weightspace
.
place
(),
w_h_ptr
,
weightspace
.
place
(),
w_h
.
first
,
w_h
.
second
,
stream
);
memory
::
Copy
(
weightspace
.
place
(),
b_x_ptr
,
weightspace
.
place
(),
b_x
.
first
,
b_x
.
second
,
stream
);
memory
::
Copy
(
weightspace
.
place
(),
b_h_ptr
,
weightspace
.
place
(),
b_h
.
first
,
b_h
.
second
,
stream
);
if
(
is_bidirec
)
{
auto
bw_x
=
parameter_lists
[
0
][
4
];
auto
bw_h
=
parameter_lists
[
0
][
5
];
auto
bb_x
=
parameter_lists
[
0
][
6
];
auto
bb_h
=
parameter_lists
[
0
][
7
];
void
*
bw_x_ptr
=
static_cast
<
char
*>
(
weightspace_ptr
)
+
actual_total_w_size
;
void
*
bw_h_ptr
=
static_cast
<
char
*>
(
weightspace_ptr
)
+
actual_total_w_size
+
bw_x
.
second
;
void
*
bb_x_ptr
=
static_cast
<
char
*>
(
weightspace_ptr
)
+
actual_total_w_size
+
bw_x
.
second
+
bw_h
.
second
;
void
*
bb_h_ptr
=
static_cast
<
char
*>
(
weightspace_ptr
)
+
actual_total_w_size
+
bw_x
.
second
+
bw_h
.
second
+
bb_x
.
second
;
actual_total_w_size
+=
bw_x
.
second
+
bw_h
.
second
+
bb_x
.
second
+
bb_h
.
second
;
memory
::
Copy
(
weightspace
.
place
(),
bw_x_ptr
,
weightspace
.
place
(),
bw_x
.
first
,
bw_x
.
second
,
stream
);
memory
::
Copy
(
weightspace
.
place
(),
bw_h_ptr
,
weightspace
.
place
(),
bw_h
.
first
,
bw_h
.
second
,
stream
);
memory
::
Copy
(
weightspace
.
place
(),
bb_x_ptr
,
weightspace
.
place
(),
bb_x
.
first
,
bb_x
.
second
,
stream
);
memory
::
Copy
(
weightspace
.
place
(),
bb_h_ptr
,
weightspace
.
place
(),
bb_h
.
first
,
bb_h
.
second
,
stream
);
}
dev_ctx
.
Wait
();
PADDLE_ENFORCE_EQ
(
weightspace_size
,
actual_total_w_size
,
platform
::
errors
::
InvalidArgument
(
"The weightsize doesn't match"
" weightspace_size:%d, actual_total_w_size:%d"
,
weightspace_size
,
actual_total_w_size
));
MLUCnnl
::
RNNBackward
(
ctx
,
rnn_desc
.
get
(),
CNNL_WGRAD_MODE_SET
,
seq_len_vec
.
data
(),
GetBasePtr
(
&
weightspace
),
GetBasePtr
(
&
dweightspace
),
weightspace
.
numel
()
*
sizeof
(
T
),
input_seq_data_desc
.
get
(),
GetBasePtr
(
input
),
GetBasePtr
(
input_grad
),
out_seq_data_desc
.
get
(),
GetBasePtr
(
output
),
GetBasePtr
(
output_grad
),
hx_desc
.
get
(),
GetBasePtr
(
init_h
),
GetBasePtr
(
last_h_grad
),
GetBasePtr
(
init_h_grad
),
cx_desc
.
get
(),
GetBasePtr
(
init_c
),
GetBasePtr
(
last_c_grad
),
GetBasePtr
(
init_c_grad
),
const_cast
<
void
*>
(
GetBasePtr
(
reserve_data
)),
reserve_data
->
numel
()
*
sizeof
(
T
));
void
*
dweightspace_ptr
=
dweightspace
.
mutable_data
(
ctx
.
GetPlace
());
auto
dw_x
=
parameter_lists_grad
[
0
][
0
];
auto
dw_h
=
parameter_lists_grad
[
0
][
1
];
auto
db_x
=
parameter_lists_grad
[
0
][
2
];
auto
db_h
=
parameter_lists_grad
[
0
][
3
];
auto
dactual_total_w_size
=
dw_x
.
second
+
dw_h
.
second
+
db_x
.
second
+
db_h
.
second
;
void
*
dw_x_ptr
=
dweightspace_ptr
;
void
*
dw_h_ptr
=
static_cast
<
char
*>
(
dweightspace_ptr
)
+
dw_x
.
second
;
void
*
db_x_ptr
=
static_cast
<
char
*>
(
dweightspace_ptr
)
+
dw_x
.
second
+
dw_h
.
second
;
void
*
db_h_ptr
=
static_cast
<
char
*>
(
dweightspace_ptr
)
+
dw_x
.
second
+
dw_h
.
second
+
db_x
.
second
;
memory
::
Copy
(
weightspace
.
place
(),
dw_x
.
first
,
weightspace
.
place
(),
dw_x_ptr
,
dw_x
.
second
,
stream
);
memory
::
Copy
(
weightspace
.
place
(),
dw_h
.
first
,
weightspace
.
place
(),
dw_h_ptr
,
dw_h
.
second
,
stream
);
memory
::
Copy
(
weightspace
.
place
(),
db_x
.
first
,
weightspace
.
place
(),
db_x_ptr
,
db_x
.
second
,
stream
);
memory
::
Copy
(
weightspace
.
place
(),
db_h
.
first
,
weightspace
.
place
(),
db_h_ptr
,
db_h
.
second
,
stream
);
if
(
is_bidirec
)
{
auto
dbw_x
=
parameter_lists_grad
[
0
][
4
];
auto
dbw_h
=
parameter_lists_grad
[
0
][
5
];
auto
dbb_x
=
parameter_lists_grad
[
0
][
6
];
auto
dbb_h
=
parameter_lists_grad
[
0
][
7
];
void
*
dbw_x_ptr
=
static_cast
<
char
*>
(
dweightspace_ptr
)
+
dactual_total_w_size
;
void
*
dbw_h_ptr
=
static_cast
<
char
*>
(
dweightspace_ptr
)
+
dactual_total_w_size
+
dbw_x
.
second
;
void
*
dbb_x_ptr
=
static_cast
<
char
*>
(
dweightspace_ptr
)
+
dactual_total_w_size
+
dbw_x
.
second
+
dbw_h
.
second
;
void
*
dbb_h_ptr
=
static_cast
<
char
*>
(
dweightspace_ptr
)
+
dactual_total_w_size
+
dbw_x
.
second
+
dbw_h
.
second
+
dbb_x
.
second
;
dactual_total_w_size
+=
dbw_x
.
second
+
dbw_h
.
second
+
dbb_x
.
second
+
dbb_h
.
second
;
memory
::
Copy
(
weightspace
.
place
(),
dbw_x
.
first
,
weightspace
.
place
(),
dbw_x_ptr
,
dbw_x
.
second
,
stream
);
memory
::
Copy
(
weightspace
.
place
(),
dbw_h
.
first
,
weightspace
.
place
(),
dbw_h_ptr
,
dbw_h
.
second
,
stream
);
memory
::
Copy
(
weightspace
.
place
(),
dbb_x
.
first
,
weightspace
.
place
(),
dbb_x_ptr
,
dbb_x
.
second
,
stream
);
memory
::
Copy
(
weightspace
.
place
(),
dbb_h
.
first
,
weightspace
.
place
(),
dbb_h_ptr
,
dbb_h
.
second
,
stream
);
}
dev_ctx
.
Wait
();
PADDLE_ENFORCE_EQ
(
weightspace_size
,
dactual_total_w_size
,
platform
::
errors
::
InvalidArgument
(
"The weightsize doesn't match"
" weightspace_size:%d, dactual_total_w_size:%d"
,
weightspace_size
,
dactual_total_w_size
));
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_MLU_KERNEL
(
rnn
,
ops
::
RNNMLUKernel
<
paddle
::
platform
::
MLUDeviceContext
,
float
>
);
REGISTER_OP_MLU_KERNEL
(
rnn_grad
,
ops
::
RNNMLUGradKernel
<
paddle
::
platform
::
MLUDeviceContext
,
float
>
);
python/paddle/fluid/tests/unittests/mlu/test_rnn_op_mlu.py
浏览文件 @
3a59ede9
...
...
@@ -135,43 +135,50 @@ class TestRNNOp(OpTest):
def
test_output
(
self
):
self
.
check_output_with_place
(
self
.
place
,
no_check_set
=
[
'Reserve'
,
'DropoutState'
,
'State'
])
self
.
place
,
atol
=
1e-4
,
no_check_set
=
[
'Reserve'
,
'DropoutState'
,
'State'
])
def
set_attrs
(
self
):
pass
# def test_grad(self):
# if not self.is_test:
# var_name_list = self.get_weight_names()
# grad_check_list = ['Input', 'init_h', 'init_c']
# grad_check_list.extend(var_name_list)
# self.check_grad_with_place(self.place, set(grad_check_list),
# ['Out', 'last_hidden', 'last_cell'])
def
test_grad
(
self
):
if
not
self
.
is_test
and
self
.
sequence_length
is
None
:
# if not self.is_test:
var_name_list
=
self
.
get_weight_names
()
grad_check_list
=
[
'Input'
,
'init_h'
,
'init_c'
]
grad_check_list
.
extend
(
var_name_list
)
self
.
check_grad_with_place
(
self
.
place
,
set
(
grad_check_list
),
[
'Out'
,
'last_hidden'
,
'last_cell'
])
#
class TestRNNOp1(TestRNNOp):
class
TestRNNOp1
(
TestRNNOp
):
#
def set_attrs(self):
#
self.sequence_length = None
def
set_attrs
(
self
):
self
.
sequence_length
=
None
# class TestRNNOp2(TestRNNOp):
# def set_attrs(self):
# self.sequence_length = None
# self.is_bidirec = True
class
TestRNNOp2
(
TestRNNOp
):
# class TestRNNOp3(TestRNNOp):
def
set_attrs
(
self
):
self
.
sequence_length
=
None
self
.
is_bidirec
=
True
# def set_attrs(self):
# self.is_test = True
# self.sequence_length = None
# class TestRNNOp4(TestRNNOp):
class
TestRNNOp3
(
TestRNNOp
):
def
set_attrs
(
self
):
self
.
is_test
=
True
self
.
sequence_length
=
None
class
TestRNNOp4
(
TestRNNOp
):
def
set_attrs
(
self
):
self
.
is_test
=
True
self
.
sequence_length
=
None
self
.
is_bidirec
=
True
# def set_attrs(self):
# self.is_test = True
# self.sequence_length = None
# self.is_bidirec = True
#TODO(chenxiao): cnnl doesn't support num_layers > 1 case
# class TestRNNOp5(TestRNNOp):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录