未验证 提交 c9e1d9dc 编写于 作者: R ronnywang 提交者: GitHub

[ROCM] fix test_rnn_op (#31735)

上级 1c67cf0c
...@@ -117,10 +117,11 @@ class RNNDescriptors { ...@@ -117,10 +117,11 @@ class RNNDescriptors {
// ------------------- cudnn rnn descriptors --------------------- // ------------------- cudnn rnn descriptors ---------------------
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetRNNDescriptor( PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetRNNDescriptor_V2(
rnn_desc_.desc(), hidden_size_, num_layers_, miopenRNNlinear, rnn_desc_.desc(), hidden_size_, num_layers_, dropout_desc_.desc(),
miopenRNNlinear,
is_bidirec_ ? miopenRNNbidirection : miopenRNNunidirection, mode_, is_bidirec_ ? miopenRNNbidirection : miopenRNNunidirection, mode_,
miopenRNNNoBias, miopenRNNdefault, cudnn_type)); miopenRNNwithBias, miopenRNNdefault, cudnn_type));
#elif CUDNN_VERSION >= 6000 #elif CUDNN_VERSION >= 6000
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetRNNDescriptor_v6( PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetRNNDescriptor_v6(
handle, rnn_desc_.desc(), hidden_size_, num_layers_, handle, rnn_desc_.desc(), hidden_size_, num_layers_,
......
...@@ -125,6 +125,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name); ...@@ -125,6 +125,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
__macro(miopenCreateRNNDescriptor); \ __macro(miopenCreateRNNDescriptor); \
__macro(miopenDestroyRNNDescriptor); \ __macro(miopenDestroyRNNDescriptor); \
__macro(miopenSetRNNDescriptor); \ __macro(miopenSetRNNDescriptor); \
__macro(miopenSetRNNDescriptor_V2); \
__macro(miopenGetRNNParamsSize); \ __macro(miopenGetRNNParamsSize); \
__macro(miopenGetRNNWorkspaceSize); \ __macro(miopenGetRNNWorkspaceSize); \
__macro(miopenGetRNNTrainingReserveSize); \ __macro(miopenGetRNNTrainingReserveSize); \
......
...@@ -47,8 +47,10 @@ class TestRNNOp(OpTest): ...@@ -47,8 +47,10 @@ class TestRNNOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "rnn" self.op_type = "rnn"
self.dtype = np.float64 self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.sequence_length = np.array([12, 11, 10, 9, 8], dtype=np.int32) self.sequence_length = None if core.is_compiled_with_rocm(
) else np.array(
[12, 11, 10, 9, 8], dtype=np.int32)
self.num_layers = 1 self.num_layers = 1
self.is_bidirec = False self.is_bidirec = False
self.mode = "LSTM" self.mode = "LSTM"
...@@ -78,12 +80,31 @@ class TestRNNOp(OpTest): ...@@ -78,12 +80,31 @@ class TestRNNOp(OpTest):
num_layers=self.num_layers, num_layers=self.num_layers,
time_major=True, time_major=True,
direction=direction, direction=direction,
dropout=self.dropout) dropout=self.dropout,
dtype=self.dtype)
flat_w = get_params_for_net(rnn1) flat_w = get_params_for_net(rnn1)
output, (last_hidden, last_cell) = rnn1( output, (last_hidden, last_cell) = rnn1(
input, sequence_length=self.sequence_length) input, sequence_length=self.sequence_length)
if core.is_compiled_with_rocm():
def rocm_rnn_get_place():
places = [core.CUDAPlace(0)]
return places
self._get_places = rocm_rnn_get_place
if self.is_bidirec:
for i in range(0, len(flat_w), 4):
flat_w[i + 1], flat_w[i + 2] = flat_w[i + 2], flat_w[i + 1]
for i in range(len(flat_w)):
w = np.split(flat_w[i][1], 4, 0)
w = [w[0], w[1], w[3], w[2]]
w = np.concatenate(w)
flat_w[i] = (flat_w[i][0], w)
init_h = np.zeros((self.num_layers * self.direction_num, batch_size, init_h = np.zeros((self.num_layers * self.direction_num, batch_size,
hidden_size)).astype(self.dtype) hidden_size)).astype(self.dtype)
init_c = np.zeros((self.num_layers * self.direction_num, batch_size, init_c = np.zeros((self.num_layers * self.direction_num, batch_size,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册