Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
3a59ede9
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3a59ede9
编写于
7月 01, 2022
作者:
C
Chenxiao Niu
提交者:
GitHub
7月 01, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[MLU] add rnn backward kernel. (#43969)
上级
9f397f16
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
522 addition
and
36 deletion
+522
-36
paddle/fluid/operators/mlu/mlu_baseop.cc
paddle/fluid/operators/mlu/mlu_baseop.cc
+82
-0
paddle/fluid/operators/mlu/mlu_baseop.h
paddle/fluid/operators/mlu/mlu_baseop.h
+24
-0
paddle/fluid/operators/rnn_op_mlu.cc
paddle/fluid/operators/rnn_op_mlu.cc
+385
-12
python/paddle/fluid/tests/unittests/mlu/test_rnn_op_mlu.py
python/paddle/fluid/tests/unittests/mlu/test_rnn_op_mlu.py
+31
-24
未找到文件。
paddle/fluid/operators/mlu/mlu_baseop.cc
浏览文件 @
3a59ede9
...
...
@@ -4616,6 +4616,88 @@ MLURNNDesc::~MLURNNDesc() {
reservespace_size
));
}
/* static */
void
MLUCnnl
::
RNNBackward
(
const
ExecutionContext
&
ctx
,
const
cnnlRNNDescriptor_t
rnn_desc
,
cnnlWgradMode_t
add_grad
,
const
int
dev_seq_lengths
[],
const
void
*
weight_param_ptr
,
void
*
dweight_param_ptr
,
size_t
weightspace_size
,
const
cnnlSeqDataDescriptor_t
x_desc
,
const
void
*
x
,
void
*
dx
,
const
cnnlSeqDataDescriptor_t
y_desc
,
const
void
*
y
,
const
void
*
dy
,
const
cnnlTensorDescriptor_t
hx_desc
,
const
void
*
hx
,
const
void
*
dhy
,
void
*
dhx
,
const
cnnlTensorDescriptor_t
cx_desc
,
const
void
*
cx
,
const
void
*
dcy
,
void
*
dcx
,
void
*
reservespace_ptr
,
size_t
reservespace_size
)
{
cnnlHandle_t
handle
=
GetHandleFromCTX
(
ctx
);
PADDLE_ENFORCE_NOT_NULL
(
rnn_desc
,
paddle
::
platform
::
errors
::
Fatal
(
"MLU RNNForward failed. rnn_desc initializing failed."
));
PADDLE_ENFORCE_NOT_NULL
(
x_desc
,
paddle
::
platform
::
errors
::
Fatal
(
"MLU RNNForward failed. x_desc initializing failed."
));
auto
&
dev_ctx
=
GetDevCtxFromCTX
(
ctx
);
size_t
workspace_size
;
Tensor
workspace
;
PADDLE_ENFORCE_MLU_SUCCESS
(
cnnlGetRNNTempSizes
(
handle
,
rnn_desc
,
x_desc
,
&
workspace_size
,
&
reservespace_size
));
workspace
=
ctx
.
AllocateTmpTensor
<
int8_t
,
MLUDeviceContext
>
(
{
static_cast
<
int64_t
>
(
workspace_size
)},
dev_ctx
);
void
*
workspace_ptr
=
workspace
.
mutable_data
(
ctx
.
GetPlace
());
PADDLE_ENFORCE_MLU_SUCCESS
(
cnnlRNNBackwardData
(
handle
,
rnn_desc
,
dev_seq_lengths
,
y_desc
,
y
,
dy
,
x_desc
,
dx
,
hx_desc
,
hx
,
dhy
,
dhx
,
cx_desc
,
cx
,
dcy
,
dcx
,
weight_param_ptr
,
weightspace_size
,
workspace_ptr
,
workspace_size
,
reservespace_ptr
,
reservespace_size
));
PADDLE_ENFORCE_MLU_SUCCESS
(
cnnlRNNBackwardWeights
(
handle
,
rnn_desc
,
add_grad
,
dev_seq_lengths
,
x_desc
,
x
,
hx_desc
,
hx
,
y_desc
,
y
,
dweight_param_ptr
,
weightspace_size
,
workspace_ptr
,
workspace_size
,
reservespace_ptr
,
reservespace_size
));
}
/* static */
void
MLUCnnl
::
Mask
(
const
ExecutionContext
&
ctx
,
cnnlMaskedOp_t
masked_mode
,
const
cnnlTensorDescriptor_t
input_desc
,
...
...
paddle/fluid/operators/mlu/mlu_baseop.h
浏览文件 @
3a59ede9
...
...
@@ -1924,6 +1924,30 @@ class MLUCnnl {
void
*
cy
,
void
*
reservespace_ptr
);
static
void
RNNBackward
(
const
ExecutionContext
&
ctx
,
const
cnnlRNNDescriptor_t
rnn_desc
,
cnnlWgradMode_t
add_grad
,
const
int
dev_seq_lengths
[],
const
void
*
weight_param_ptr
,
void
*
dweight_param_ptr
,
size_t
weightspace_size
,
const
cnnlSeqDataDescriptor_t
x_desc
,
const
void
*
x
,
void
*
dx
,
const
cnnlSeqDataDescriptor_t
y_desc
,
const
void
*
y
,
const
void
*
dy
,
const
cnnlTensorDescriptor_t
hx_desc
,
const
void
*
hx
,
const
void
*
dhy
,
void
*
dhx
,
const
cnnlTensorDescriptor_t
cx_desc
,
const
void
*
cx
,
const
void
*
dcy
,
void
*
dcx
,
void
*
reservespace_ptr
,
size_t
reservespace_size
);
static
void
Mask
(
const
ExecutionContext
&
ctx
,
cnnlMaskedOp_t
masked_mode
,
const
cnnlTensorDescriptor_t
input_desc
,
...
...
paddle/fluid/operators/rnn_op_mlu.cc
浏览文件 @
3a59ede9
...
...
@@ -28,7 +28,7 @@ void reset_parameter_vector(
const
std
::
vector
<
TensorType
>&
raw_params_vec
,
const
int
&
num_layers
,
const
bool
&
is_bidirec
,
std
::
vector
<
std
::
vector
<
std
::
pair
<
const
T
*
,
size_t
>>>*
params_vec
)
{
std
::
vector
<
std
::
vector
<
std
::
pair
<
T
*
,
size_t
>>>*
params_vec
)
{
// the parameter raw seuquence is [FWhi, FWhh, BWhi, BWhh] * num_layers
// + [FBhi, FBhh, BBhi, BBhh] * num_layers, we will reset the parameter to
// ([FWhi, FWhh, FBhi, FBhh] + [BWhi, BWhh, BBhi, BBhh]) * num_layers
...
...
@@ -47,7 +47,8 @@ void reset_parameter_vector(
}
using
remove_cv_t
=
typename
std
::
remove_cv
<
T
>::
type
;
params_vec
->
at
(
i
)[
j
]
=
std
::
make_pair
(
raw_params_vec
[
tensor_idx
]
->
template
data
<
remove_cv_t
>(),
const_cast
<
T
*>
(
raw_params_vec
[
tensor_idx
]
->
template
data
<
remove_cv_t
>()),
raw_params_vec
[
tensor_idx
]
->
numel
()
*
sizeof
(
T
));
}
}
...
...
@@ -66,7 +67,6 @@ class RNNMLUKernel : public framework::OpKernel<T> {
// Output
auto
state
=
ctx
.
MultiOutput
<
Tensor
>
(
"State"
);
auto
*
output
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
// auto* dropout_mask = ctx.Output<Tensor>("DropoutState");
auto
*
reserve_data
=
ctx
.
Output
<
Tensor
>
(
"Reserve"
);
// Attributes
const
int
&
num_layers
=
ctx
.
Attr
<
int
>
(
"num_layers"
);
...
...
@@ -79,14 +79,6 @@ class RNNMLUKernel : public framework::OpKernel<T> {
sequence_length
=
ctx
.
Input
<
Tensor
>
(
"SequenceLength"
);
}
// if (dropout_mask->IsInitialized()) {
// if (dropout_mask->numel() != output->numel()) dropout_mask->clear();
// }
// dropout_mask->mutable_data<uint8_t>(output->dims(), ctx.GetPlace());
// auto& dev_ctx = ctx.template device_context<DeviceContext>();
// phi::funcs::SetConstant<platform::XPUDeviceContext, uint8_t> ones;
// ones(dev_ctx, dropout_mask, static_cast<uint8_t>(1));
auto
init_h
=
pre_state
[
0
];
// -> hx
auto
init_c
=
pre_state
[
1
];
// -> cx
auto
last_h
=
state
[
0
];
...
...
@@ -143,7 +135,7 @@ class RNNMLUKernel : public framework::OpKernel<T> {
init_c
->
dims
()[
0
]));
// weightlist
std
::
vector
<
std
::
vector
<
std
::
pair
<
const
T
*
,
size_t
>>>
parameter_lists
;
std
::
vector
<
std
::
vector
<
std
::
pair
<
T
*
,
size_t
>>>
parameter_lists
;
parameter_lists
.
resize
(
num_layers
);
reset_parameter_vector
(
weight_list
,
num_layers
,
is_bidirec
,
&
parameter_lists
);
...
...
@@ -363,9 +355,390 @@ class RNNMLUKernel : public framework::OpKernel<T> {
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
RNNMLUGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
auto
stream
=
ctx
.
template
device_context
<
MLUDeviceContext
>().
stream
();
// get the tensor pointer for the input
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"Input"
);
auto
pre_state
=
ctx
.
MultiInput
<
Tensor
>
(
"PreState"
);
auto
weight_list
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"WeightList"
);
auto
*
output
=
ctx
.
Input
<
Tensor
>
(
"Out"
);
auto
*
reserve_data
=
ctx
.
Input
<
Tensor
>
(
"Reserve"
);
const
int
&
num_layers
=
ctx
.
Attr
<
int
>
(
"num_layers"
);
const
bool
&
is_bidirec
=
ctx
.
Attr
<
bool
>
(
"is_bidirec"
);
const
int
&
hidden_size
=
ctx
.
Attr
<
int
>
(
"hidden_size"
);
const
std
::
string
&
mode
=
ctx
.
Attr
<
std
::
string
>
(
"mode"
);
bool
has_seq_length
=
ctx
.
HasInput
(
"SequenceLength"
);
const
Tensor
*
sequence_length
=
nullptr
;
if
(
has_seq_length
)
{
sequence_length
=
ctx
.
Input
<
Tensor
>
(
"SequenceLength"
);
}
PADDLE_ENFORCE_EQ
(
mode
,
"LSTM"
,
platform
::
errors
::
InvalidArgument
(
"XPU only support LSTM mode now, current mode is %s"
,
mode
));
auto
init_h
=
pre_state
[
0
];
// -> hx
auto
init_c
=
pre_state
[
1
];
// -> cx
auto
output_grad
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
state_grad
=
ctx
.
MultiInput
<
Tensor
>
(
framework
::
GradVarName
(
"State"
));
auto
last_h_grad
=
state_grad
[
0
];
// -> dhy
auto
last_c_grad
=
state_grad
[
1
];
// -> dcy
// get the tensor pointer for the output
auto
*
input_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Input"
));
auto
weight_grad_list
=
ctx
.
MultiOutput
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"WeightList"
));
auto
pre_state_grad
=
ctx
.
MultiOutput
<
Tensor
>
(
framework
::
GradVarName
(
"PreState"
));
Tensor
*
init_h_grad
=
nullptr
;
Tensor
*
init_c_grad
=
nullptr
;
if
(
pre_state_grad
.
size
()
>
0
)
{
// has gradient
init_h_grad
=
pre_state_grad
[
0
];
// -> dhx
init_c_grad
=
pre_state_grad
[
1
];
// -> dcx
}
// check shape
const
int
in_out_dim_num
=
input
->
dims
().
size
();
const
int
&
seq_len
=
input
->
dims
()[
0
];
const
int
&
batch_size
=
input
->
dims
()[
1
];
const
int
&
input_dim
=
input
->
dims
()[
2
];
const
int
&
direction_num
=
is_bidirec
?
2
:
1
;
int
in_dim_arr
[
in_out_dim_num
]
=
{
seq_len
,
batch_size
,
input_dim
};
int
out_dim_arr
[
in_out_dim_num
]
=
{
seq_len
,
batch_size
,
direction_num
*
hidden_size
};
int
proj_size
=
hidden_size
;
PADDLE_ENFORCE_EQ
(
num_layers
,
1
,
platform
::
errors
::
InvalidArgument
(
"MLU only support 1 num_layers, current num_layers is %s"
,
num_layers
));
PADDLE_ENFORCE_EQ
(
init_h
->
dims
()[
0
],
num_layers
*
direction_num
,
platform
::
errors
::
InvalidArgument
(
"The num_layers of in RNN layer must"
" be the same as first dim of init"
"hidden, but received num_layers:%d,"
" dim:%d"
,
num_layers
,
init_h
->
dims
()[
0
]));
PADDLE_ENFORCE_EQ
(
init_c
->
dims
()[
0
],
num_layers
*
direction_num
,
platform
::
errors
::
InvalidArgument
(
"The num_layers of in RNN layer must"
" be the same as first dim of cell state hidden, but received"
" num_layers:%d, dim:%d"
,
num_layers
,
init_c
->
dims
()[
0
]));
std
::
vector
<
std
::
vector
<
std
::
pair
<
T
*
,
size_t
>>>
parameter_lists
;
parameter_lists
.
resize
(
num_layers
);
reset_parameter_vector
(
weight_list
,
num_layers
,
is_bidirec
,
&
parameter_lists
);
for
(
unsigned
int
i
=
0
;
i
<
weight_grad_list
.
size
();
++
i
)
{
weight_grad_list
[
i
]
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
}
std
::
vector
<
std
::
vector
<
std
::
pair
<
T
*
,
size_t
>>>
parameter_lists_grad
;
parameter_lists_grad
.
resize
(
num_layers
);
reset_parameter_vector
(
weight_grad_list
,
num_layers
,
is_bidirec
,
&
parameter_lists_grad
);
// allocate the memory and initization the input_grad
input_grad
->
mutable_data
<
T
>
(
input
->
dims
(),
ctx
.
GetPlace
());
FillMLUTensorWithHostValue
(
ctx
,
static_cast
<
T
>
(
0.0
),
input_grad
);
Tensor
a
,
b
;
Tensor
*
dynamic_grad_pre_h
=
&
a
;
Tensor
*
dynamic_grad_pre_c
=
&
b
;
if
(
init_h_grad
)
{
init_h_grad
->
mutable_data
<
T
>
(
last_h_grad
->
dims
(),
ctx
.
GetPlace
());
FillMLUTensorWithHostValue
(
ctx
,
static_cast
<
T
>
(
0.0
),
init_h_grad
);
}
else
{
dynamic_grad_pre_h
->
Resize
(
last_h_grad
->
dims
());
dynamic_grad_pre_h
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
FillMLUTensorWithHostValue
(
ctx
,
static_cast
<
T
>
(
0.0
),
dynamic_grad_pre_h
);
init_h_grad
=
dynamic_grad_pre_h
;
}
if
(
init_c_grad
)
{
init_c_grad
->
mutable_data
<
T
>
(
last_c_grad
->
dims
(),
ctx
.
GetPlace
());
}
else
{
dynamic_grad_pre_c
->
Resize
(
last_h_grad
->
dims
());
dynamic_grad_pre_c
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
init_c_grad
=
dynamic_grad_pre_c
;
}
std
::
vector
<
int
>
seq_len_vec
(
batch_size
,
seq_len
);
if
(
has_seq_length
)
{
seq_len_vec
=
operators
::
GetDataFromTensor
(
sequence_length
);
}
cnnlDirectionMode_t
direction
=
is_bidirec
?
CNNL_RNN_BIDIRECTIONAL
:
CNNL_RNN_UNIDIRECTIONAL
;
MLUSeqDataDesc
input_seq_data_desc
(
CNNL_SEQDATA_TNC
,
ToCnnlDataType
(
input
->
dtype
()),
in_out_dim_num
,
in_dim_arr
,
static_cast
<
int
>
(
seq_len_vec
.
size
()),
seq_len_vec
.
data
(),
nullptr
);
MLUSeqDataDesc
out_seq_data_desc
(
CNNL_SEQDATA_TNC
,
ToCnnlDataType
(
input
->
dtype
()),
in_out_dim_num
,
out_dim_arr
,
static_cast
<
int
>
(
seq_len_vec
.
size
()),
seq_len_vec
.
data
(),
nullptr
);
MLUCnnlTensorDesc
hx_desc
(
*
init_h
);
MLUCnnlTensorDesc
cx_desc
(
*
init_c
);
MLURNNDesc
rnn_desc
(
CNNL_LSTM
,
CNNL_RNN_DOUBLE_BIAS
,
direction
,
CNNL_RNN_LINEAR_INPUT
,
ToCnnlDataType
(
input
->
dtype
()),
ToCnnlDataType
(
input
->
dtype
()),
input_dim
,
hidden_size
,
/*projection*/
proj_size
,
num_layers
,
nullptr
,
CNNL_RNN_PADDED_IO_DISABLED
);
rnn_desc
.
SetRNNMaskMode
(
CNNL_LSTM_MASK_ENABLED
);
// copy weight
size_t
weightspace_size
;
framework
::
Tensor
weightspace
,
dweightspace
;
PADDLE_ENFORCE_MLU_SUCCESS
(
cnnlGetRNNWeightSpaceSize
(
GetHandleFromCTX
(
ctx
),
rnn_desc
.
get
(),
&
weightspace_size
));
weightspace
=
ctx
.
AllocateTmpTensor
<
T
,
DeviceContext
>
(
{
static_cast
<
int64_t
>
(
weightspace_size
)},
dev_ctx
);
dweightspace
=
ctx
.
AllocateTmpTensor
<
T
,
DeviceContext
>
(
{
static_cast
<
int64_t
>
(
weightspace_size
)},
dev_ctx
);
void
*
weightspace_ptr
=
weightspace
.
mutable_data
(
ctx
.
GetPlace
());
auto
w_x
=
parameter_lists
[
0
][
0
];
auto
w_h
=
parameter_lists
[
0
][
1
];
auto
b_x
=
parameter_lists
[
0
][
2
];
auto
b_h
=
parameter_lists
[
0
][
3
];
auto
actual_total_w_size
=
w_x
.
second
+
w_h
.
second
+
b_x
.
second
+
b_h
.
second
;
void
*
w_x_ptr
=
weightspace_ptr
;
void
*
w_h_ptr
=
static_cast
<
char
*>
(
weightspace_ptr
)
+
w_x
.
second
;
void
*
b_x_ptr
=
static_cast
<
char
*>
(
weightspace_ptr
)
+
w_x
.
second
+
w_h
.
second
;
void
*
b_h_ptr
=
static_cast
<
char
*>
(
weightspace_ptr
)
+
w_x
.
second
+
w_h
.
second
+
b_x
.
second
;
memory
::
Copy
(
weightspace
.
place
(),
w_x_ptr
,
weightspace
.
place
(),
w_x
.
first
,
w_x
.
second
,
stream
);
memory
::
Copy
(
weightspace
.
place
(),
w_h_ptr
,
weightspace
.
place
(),
w_h
.
first
,
w_h
.
second
,
stream
);
memory
::
Copy
(
weightspace
.
place
(),
b_x_ptr
,
weightspace
.
place
(),
b_x
.
first
,
b_x
.
second
,
stream
);
memory
::
Copy
(
weightspace
.
place
(),
b_h_ptr
,
weightspace
.
place
(),
b_h
.
first
,
b_h
.
second
,
stream
);
if
(
is_bidirec
)
{
auto
bw_x
=
parameter_lists
[
0
][
4
];
auto
bw_h
=
parameter_lists
[
0
][
5
];
auto
bb_x
=
parameter_lists
[
0
][
6
];
auto
bb_h
=
parameter_lists
[
0
][
7
];
void
*
bw_x_ptr
=
static_cast
<
char
*>
(
weightspace_ptr
)
+
actual_total_w_size
;
void
*
bw_h_ptr
=
static_cast
<
char
*>
(
weightspace_ptr
)
+
actual_total_w_size
+
bw_x
.
second
;
void
*
bb_x_ptr
=
static_cast
<
char
*>
(
weightspace_ptr
)
+
actual_total_w_size
+
bw_x
.
second
+
bw_h
.
second
;
void
*
bb_h_ptr
=
static_cast
<
char
*>
(
weightspace_ptr
)
+
actual_total_w_size
+
bw_x
.
second
+
bw_h
.
second
+
bb_x
.
second
;
actual_total_w_size
+=
bw_x
.
second
+
bw_h
.
second
+
bb_x
.
second
+
bb_h
.
second
;
memory
::
Copy
(
weightspace
.
place
(),
bw_x_ptr
,
weightspace
.
place
(),
bw_x
.
first
,
bw_x
.
second
,
stream
);
memory
::
Copy
(
weightspace
.
place
(),
bw_h_ptr
,
weightspace
.
place
(),
bw_h
.
first
,
bw_h
.
second
,
stream
);
memory
::
Copy
(
weightspace
.
place
(),
bb_x_ptr
,
weightspace
.
place
(),
bb_x
.
first
,
bb_x
.
second
,
stream
);
memory
::
Copy
(
weightspace
.
place
(),
bb_h_ptr
,
weightspace
.
place
(),
bb_h
.
first
,
bb_h
.
second
,
stream
);
}
dev_ctx
.
Wait
();
PADDLE_ENFORCE_EQ
(
weightspace_size
,
actual_total_w_size
,
platform
::
errors
::
InvalidArgument
(
"The weightsize doesn't match"
" weightspace_size:%d, actual_total_w_size:%d"
,
weightspace_size
,
actual_total_w_size
));
MLUCnnl
::
RNNBackward
(
ctx
,
rnn_desc
.
get
(),
CNNL_WGRAD_MODE_SET
,
seq_len_vec
.
data
(),
GetBasePtr
(
&
weightspace
),
GetBasePtr
(
&
dweightspace
),
weightspace
.
numel
()
*
sizeof
(
T
),
input_seq_data_desc
.
get
(),
GetBasePtr
(
input
),
GetBasePtr
(
input_grad
),
out_seq_data_desc
.
get
(),
GetBasePtr
(
output
),
GetBasePtr
(
output_grad
),
hx_desc
.
get
(),
GetBasePtr
(
init_h
),
GetBasePtr
(
last_h_grad
),
GetBasePtr
(
init_h_grad
),
cx_desc
.
get
(),
GetBasePtr
(
init_c
),
GetBasePtr
(
last_c_grad
),
GetBasePtr
(
init_c_grad
),
const_cast
<
void
*>
(
GetBasePtr
(
reserve_data
)),
reserve_data
->
numel
()
*
sizeof
(
T
));
void
*
dweightspace_ptr
=
dweightspace
.
mutable_data
(
ctx
.
GetPlace
());
auto
dw_x
=
parameter_lists_grad
[
0
][
0
];
auto
dw_h
=
parameter_lists_grad
[
0
][
1
];
auto
db_x
=
parameter_lists_grad
[
0
][
2
];
auto
db_h
=
parameter_lists_grad
[
0
][
3
];
auto
dactual_total_w_size
=
dw_x
.
second
+
dw_h
.
second
+
db_x
.
second
+
db_h
.
second
;
void
*
dw_x_ptr
=
dweightspace_ptr
;
void
*
dw_h_ptr
=
static_cast
<
char
*>
(
dweightspace_ptr
)
+
dw_x
.
second
;
void
*
db_x_ptr
=
static_cast
<
char
*>
(
dweightspace_ptr
)
+
dw_x
.
second
+
dw_h
.
second
;
void
*
db_h_ptr
=
static_cast
<
char
*>
(
dweightspace_ptr
)
+
dw_x
.
second
+
dw_h
.
second
+
db_x
.
second
;
memory
::
Copy
(
weightspace
.
place
(),
dw_x
.
first
,
weightspace
.
place
(),
dw_x_ptr
,
dw_x
.
second
,
stream
);
memory
::
Copy
(
weightspace
.
place
(),
dw_h
.
first
,
weightspace
.
place
(),
dw_h_ptr
,
dw_h
.
second
,
stream
);
memory
::
Copy
(
weightspace
.
place
(),
db_x
.
first
,
weightspace
.
place
(),
db_x_ptr
,
db_x
.
second
,
stream
);
memory
::
Copy
(
weightspace
.
place
(),
db_h
.
first
,
weightspace
.
place
(),
db_h_ptr
,
db_h
.
second
,
stream
);
if
(
is_bidirec
)
{
auto
dbw_x
=
parameter_lists_grad
[
0
][
4
];
auto
dbw_h
=
parameter_lists_grad
[
0
][
5
];
auto
dbb_x
=
parameter_lists_grad
[
0
][
6
];
auto
dbb_h
=
parameter_lists_grad
[
0
][
7
];
void
*
dbw_x_ptr
=
static_cast
<
char
*>
(
dweightspace_ptr
)
+
dactual_total_w_size
;
void
*
dbw_h_ptr
=
static_cast
<
char
*>
(
dweightspace_ptr
)
+
dactual_total_w_size
+
dbw_x
.
second
;
void
*
dbb_x_ptr
=
static_cast
<
char
*>
(
dweightspace_ptr
)
+
dactual_total_w_size
+
dbw_x
.
second
+
dbw_h
.
second
;
void
*
dbb_h_ptr
=
static_cast
<
char
*>
(
dweightspace_ptr
)
+
dactual_total_w_size
+
dbw_x
.
second
+
dbw_h
.
second
+
dbb_x
.
second
;
dactual_total_w_size
+=
dbw_x
.
second
+
dbw_h
.
second
+
dbb_x
.
second
+
dbb_h
.
second
;
memory
::
Copy
(
weightspace
.
place
(),
dbw_x
.
first
,
weightspace
.
place
(),
dbw_x_ptr
,
dbw_x
.
second
,
stream
);
memory
::
Copy
(
weightspace
.
place
(),
dbw_h
.
first
,
weightspace
.
place
(),
dbw_h_ptr
,
dbw_h
.
second
,
stream
);
memory
::
Copy
(
weightspace
.
place
(),
dbb_x
.
first
,
weightspace
.
place
(),
dbb_x_ptr
,
dbb_x
.
second
,
stream
);
memory
::
Copy
(
weightspace
.
place
(),
dbb_h
.
first
,
weightspace
.
place
(),
dbb_h_ptr
,
dbb_h
.
second
,
stream
);
}
dev_ctx
.
Wait
();
PADDLE_ENFORCE_EQ
(
weightspace_size
,
dactual_total_w_size
,
platform
::
errors
::
InvalidArgument
(
"The weightsize doesn't match"
" weightspace_size:%d, dactual_total_w_size:%d"
,
weightspace_size
,
dactual_total_w_size
));
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_MLU_KERNEL
(
rnn
,
ops
::
RNNMLUKernel
<
paddle
::
platform
::
MLUDeviceContext
,
float
>
);
REGISTER_OP_MLU_KERNEL
(
rnn_grad
,
ops
::
RNNMLUGradKernel
<
paddle
::
platform
::
MLUDeviceContext
,
float
>
);
python/paddle/fluid/tests/unittests/mlu/test_rnn_op_mlu.py
浏览文件 @
3a59ede9
...
...
@@ -135,43 +135,50 @@ class TestRNNOp(OpTest):
def
test_output
(
self
):
self
.
check_output_with_place
(
self
.
place
,
no_check_set
=
[
'Reserve'
,
'DropoutState'
,
'State'
])
self
.
place
,
atol
=
1e-4
,
no_check_set
=
[
'Reserve'
,
'DropoutState'
,
'State'
])
def
set_attrs
(
self
):
pass
# def test_grad(self):
# if not self.is_test:
# var_name_list = self.get_weight_names()
# grad_check_list = ['Input', 'init_h', 'init_c']
# grad_check_list.extend(var_name_list)
# self.check_grad_with_place(self.place, set(grad_check_list),
# ['Out', 'last_hidden', 'last_cell'])
def
test_grad
(
self
):
if
not
self
.
is_test
and
self
.
sequence_length
is
None
:
# if not self.is_test:
var_name_list
=
self
.
get_weight_names
()
grad_check_list
=
[
'Input'
,
'init_h'
,
'init_c'
]
grad_check_list
.
extend
(
var_name_list
)
self
.
check_grad_with_place
(
self
.
place
,
set
(
grad_check_list
),
[
'Out'
,
'last_hidden'
,
'last_cell'
])
#
class TestRNNOp1(TestRNNOp):
class
TestRNNOp1
(
TestRNNOp
):
#
def set_attrs(self):
#
self.sequence_length = None
def
set_attrs
(
self
):
self
.
sequence_length
=
None
# class TestRNNOp2(TestRNNOp):
# def set_attrs(self):
# self.sequence_length = None
# self.is_bidirec = True
class
TestRNNOp2
(
TestRNNOp
):
# class TestRNNOp3(TestRNNOp):
def
set_attrs
(
self
):
self
.
sequence_length
=
None
self
.
is_bidirec
=
True
# def set_attrs(self):
# self.is_test = True
# self.sequence_length = None
# class TestRNNOp4(TestRNNOp):
class
TestRNNOp3
(
TestRNNOp
):
def
set_attrs
(
self
):
self
.
is_test
=
True
self
.
sequence_length
=
None
class
TestRNNOp4
(
TestRNNOp
):
def
set_attrs
(
self
):
self
.
is_test
=
True
self
.
sequence_length
=
None
self
.
is_bidirec
=
True
# def set_attrs(self):
# self.is_test = True
# self.sequence_length = None
# self.is_bidirec = True
#TODO(chenxiao): cnnl doesn't support num_layers > 1 case
# class TestRNNOp5(TestRNNOp):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录