From 6024488d3ae939fe11e43abc9921da042c708256 Mon Sep 17 00:00:00 2001
From: Qi Li
Date: Mon, 28 Jun 2021 14:16:30 +0800
Subject: [PATCH] [ROCM] fix RNN miopen as weight need to permuted, test=develop (#33733)

* [ROCM] fix RNN miopen as weight need to permuted, test=develop

* [ROCM] fix data share when is_test, test=develop

* update, test=develop
---
 paddle/fluid/framework/tensor.cc              |  43 ++++
 paddle/fluid/framework/tensor.h               |  16 ++
 paddle/fluid/framework/tensor_test.cc         | 126 ++++++++++++
 .../fluid/memory/detail/system_allocator.cc   |   2 +-
 paddle/fluid/operators/rnn_op.cu.cc           | 191 +++++++++++++++---
 .../fluid/tests/unittests/test_gru_rnn_op.py  |  10 -
 .../fluid/tests/unittests/test_rnn_op.py      |  10 -
 .../tests/unittests/test_simple_rnn_op.py     |  10 +-
 8 files changed, 355 insertions(+), 53 deletions(-)

diff --git a/paddle/fluid/framework/tensor.cc b/paddle/fluid/framework/tensor.cc
index b304a45be3c..4f6eb803d1c 100644
--- a/paddle/fluid/framework/tensor.cc
+++ b/paddle/fluid/framework/tensor.cc
@@ -135,6 +135,49 @@ Tensor Tensor::Slice(int64_t begin_idx, int64_t end_idx) const {
   }
 }
 
+std::vector<Tensor> Tensor::Split(int64_t split_size, int64_t axis) const {
+  check_memory_size();
+  PADDLE_ENFORCE_GE(dims_.size(), 0,
+                    platform::errors::OutOfRange(
+                        "split expects at least a 1-dimensional tensor"));
+  PADDLE_ENFORCE_GE(
+      split_size, 0,
+      platform::errors::OutOfRange(
+          "split expects split_size be non-negative, but got split_size is %d",
+          split_size));
+  int64_t numel_size = dims_[axis];
+
+  int64_t num_splits = 1;
+  if (split_size != 0) {
+    num_splits =
+        std::max<int64_t>((numel_size + split_size - 1) / split_size, 1);
+  }
+  std::vector<Tensor> splits(num_splits);
+  int64_t last_split_size =
+      split_size - (split_size * num_splits - numel_size);
+
+  for (int64_t i = 0; i < num_splits; ++i) {
+    int64_t length = i < num_splits - 1 ? split_size : last_split_size;
+    splits[i] = Slice(i * split_size, i * split_size + length);
+  }
+  return splits;
+}
+
+std::vector<Tensor> Tensor::Chunk(int64_t chunks, int64_t axis) const {
+  check_memory_size();
+  PADDLE_ENFORCE_GE(dims_.size(), 0,
+                    platform::errors::OutOfRange(
+                        "split expects at least a 1-dimensional tensor"));
+  PADDLE_ENFORCE_GE(
+      chunks, 0,
+      platform::errors::OutOfRange(
+          "chunks expects to be greater than 0, but got chunks is %d", chunks));
+
+  int64_t numel_size = dims_[axis];
+  int64_t split_size = (numel_size + chunks - 1) / chunks;
+  return Split(split_size, axis);
+}
+
 Tensor& Tensor::Resize(const DDim& dims) {
   dims_ = dims;
   return *this;
diff --git a/paddle/fluid/framework/tensor.h b/paddle/fluid/framework/tensor.h
index 0747321bcfa..539859c45c9 100644
--- a/paddle/fluid/framework/tensor.h
+++ b/paddle/fluid/framework/tensor.h
@@ -187,6 +187,22 @@ class Tensor {
    */
   Tensor Slice(int64_t begin_idx, int64_t end_idx) const;
 
+  /**
+   * @brief Split the tensor into a list of sub-tensors along the given axis.
+   *
+   * @param[in] split_size The size of each sub-tensor along the axis.
+   * @param[in] axis       The axis along which to split.
+   */
+  std::vector<Tensor> Split(int64_t split_size, int64_t axis) const;
+
+  /**
+   * @brief Split the tensor into the given number of chunks along the axis.
+   *
+   * @param[in] chunks The number of sub-tensors to split into along the axis.
+   * @param[in] axis   The axis along which to split.
+   */
+  std::vector<Tensor> Chunk(int64_t chunks, int64_t axis) const;
+
   const platform::Place& place() const {
     PADDLE_ENFORCE_NOT_NULL(
         holder_,
diff --git a/paddle/fluid/framework/tensor_test.cc b/paddle/fluid/framework/tensor_test.cc
index 101463756c0..71ff50c92ca 100644
--- a/paddle/fluid/framework/tensor_test.cc
+++ b/paddle/fluid/framework/tensor_test.cc
@@ -337,3 +337,129 @@ TEST(Tensor, FP16) {
   // Tensor holds the wrong type, it holds N6paddle8platform7float16E at
   // [/paddle/Paddle/paddle/fluid/framework/tensor_impl.h:43]
 }
+
+TEST(Tensor, Split) {
+  {
+    framework::Tensor src_tensor;
+    src_tensor.mutable_data<int>(framework::make_ddim({6, 2}),
+                                 platform::CPUPlace());
+    std::vector<framework::Tensor> split_tensor_list = src_tensor.Split(2, 0);
+    ASSERT_EQ(split_tensor_list.size(), 3UL);
+    EXPECT_EQ(split_tensor_list[0].dims()[0], 2);
+    EXPECT_EQ(split_tensor_list[1].dims()[0], 2);
+    EXPECT_EQ(split_tensor_list[2].dims()[0], 2);
+    EXPECT_EQ(split_tensor_list[0].dims()[1], 2);
+    EXPECT_EQ(split_tensor_list[1].dims()[1], 2);
+    EXPECT_EQ(split_tensor_list[2].dims()[1], 2);
+
+    uintptr_t src_data_address =
+        reinterpret_cast<uintptr_t>(src_tensor.data<int>());
+    uintptr_t src_mutable_data_address = reinterpret_cast<uintptr_t>(
+        src_tensor.mutable_data<int>(src_tensor.dims(), platform::CPUPlace()));
+    EXPECT_EQ(src_data_address, src_mutable_data_address);
+    for (int i = 0; i < 3; ++i) {
+      uintptr_t split_data_address =
+          reinterpret_cast<uintptr_t>(split_tensor_list[i].data<int>());
+      uintptr_t split_mutable_data_address =
+          reinterpret_cast<uintptr_t>(split_tensor_list[i].mutable_data<int>(
+              split_tensor_list[i].dims(), platform::CPUPlace()));
+      EXPECT_EQ(split_data_address, split_mutable_data_address);
+      EXPECT_EQ(src_data_address + 2 * 2 * i * sizeof(int), split_data_address);
+    }
+  }
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+  {
+    framework::Tensor src_tensor;
+    src_tensor.mutable_data<double>(framework::make_ddim({6, 4}),
+                                    platform::CUDAPlace(0));
+    std::vector<framework::Tensor> split_tensor_list = src_tensor.Split(2, 0);
+    ASSERT_EQ(split_tensor_list.size(), 3UL);
+    EXPECT_EQ(split_tensor_list[0].dims()[0], 2);
+    EXPECT_EQ(split_tensor_list[1].dims()[0], 2);
+    EXPECT_EQ(split_tensor_list[2].dims()[0], 2);
+    EXPECT_EQ(split_tensor_list[0].dims()[1], 4);
+    EXPECT_EQ(split_tensor_list[1].dims()[1], 4);
+    EXPECT_EQ(split_tensor_list[2].dims()[1], 4);
+
+    uintptr_t src_data_address =
+        reinterpret_cast<uintptr_t>(src_tensor.data<double>());
+    uintptr_t src_mutable_data_address =
+        reinterpret_cast<uintptr_t>(src_tensor.mutable_data<double>(
+            src_tensor.dims(), platform::CUDAPlace(0)));
+    EXPECT_EQ(src_data_address, src_mutable_data_address);
+    for (int i = 0; i < 3; ++i) {
+      uintptr_t split_data_address =
+          reinterpret_cast<uintptr_t>(split_tensor_list[i].data<double>());
+      uintptr_t split_mutable_data_address =
+          reinterpret_cast<uintptr_t>(split_tensor_list[i].mutable_data<double>(
+              split_tensor_list[i].dims(), platform::CUDAPlace(0)));
+      EXPECT_EQ(split_data_address, split_mutable_data_address);
+      EXPECT_EQ(src_data_address + 2 * 4 * i * sizeof(double),
+                split_data_address);
+    }
+  }
+#endif
+}
+
+TEST(Tensor, Chunk) {
+  {
+    framework::Tensor src_tensor;
+    src_tensor.mutable_data<int>(framework::make_ddim({6, 2}),
+                                 platform::CPUPlace());
+    std::vector<framework::Tensor> split_tensor_list = src_tensor.Chunk(3, 0);
+    ASSERT_EQ(split_tensor_list.size(), 3UL);
+    EXPECT_EQ(split_tensor_list[0].dims()[0], 2);
+    EXPECT_EQ(split_tensor_list[1].dims()[0], 2);
+    EXPECT_EQ(split_tensor_list[2].dims()[0], 2);
+    EXPECT_EQ(split_tensor_list[0].dims()[1], 2);
+    EXPECT_EQ(split_tensor_list[1].dims()[1], 2);
+    EXPECT_EQ(split_tensor_list[2].dims()[1], 2);
+
+    uintptr_t src_data_address =
+        reinterpret_cast<uintptr_t>(src_tensor.data<int>());
+    uintptr_t src_mutable_data_address = reinterpret_cast<uintptr_t>(
+        src_tensor.mutable_data<int>(src_tensor.dims(), platform::CPUPlace()));
+    for (int i = 0; i < 3; ++i) {
+      uintptr_t split_data_address =
+          reinterpret_cast<uintptr_t>(split_tensor_list[i].data<int>());
+      uintptr_t split_mutable_data_address =
+          reinterpret_cast<uintptr_t>(split_tensor_list[i].mutable_data<int>(
+              split_tensor_list[i].dims(), platform::CPUPlace()));
+      EXPECT_EQ(src_data_address, src_mutable_data_address);
+      EXPECT_EQ(split_data_address, split_mutable_data_address);
+      EXPECT_EQ(src_data_address + 2 * 2 * i * sizeof(int), split_data_address);
+    }
+  }
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+  {
+    framework::Tensor src_tensor;
+    src_tensor.mutable_data<double>(framework::make_ddim({6, 4}),
+                                    platform::CUDAPlace(0));
+    std::vector<framework::Tensor> split_tensor_list = src_tensor.Chunk(3, 0);
+    ASSERT_EQ(split_tensor_list.size(), 3UL);
+    EXPECT_EQ(split_tensor_list[0].dims()[0], 2);
+    EXPECT_EQ(split_tensor_list[1].dims()[0], 2);
+    EXPECT_EQ(split_tensor_list[2].dims()[0], 2);
+    EXPECT_EQ(split_tensor_list[0].dims()[1], 4);
+    EXPECT_EQ(split_tensor_list[1].dims()[1], 4);
+    EXPECT_EQ(split_tensor_list[2].dims()[1], 4);
+
+    uintptr_t src_data_address =
+        reinterpret_cast<uintptr_t>(src_tensor.data<double>());
+    uintptr_t src_mutable_data_address =
+        reinterpret_cast<uintptr_t>(src_tensor.mutable_data<double>(
+            src_tensor.dims(), platform::CUDAPlace(0)));
+    EXPECT_EQ(src_data_address, src_mutable_data_address);
+    for (int i = 0; i < 3; ++i) {
+      uintptr_t split_data_address =
+          reinterpret_cast<uintptr_t>(split_tensor_list[i].data<double>());
+      uintptr_t split_mutable_data_address =
+          reinterpret_cast<uintptr_t>(split_tensor_list[i].mutable_data<double>(
+              split_tensor_list[i].dims(), platform::CUDAPlace(0)));
+      EXPECT_EQ(split_data_address, split_mutable_data_address);
+      EXPECT_EQ(src_data_address + 2 * 4 * i * sizeof(double),
+                split_data_address);
+    }
+  }
+#endif
+}
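A note on the API exercised by the tests above: Tensor::Split and Tensor::Chunk return views built on Slice, so no data is copied and each piece aliases the source allocation at a fixed offset, which is what the address checks in tensor_test.cc verify. Below is a minimal usage sketch; it is illustrative only, assumes a Paddle source tree with these headers and a CPU build, and the function name ChunkSketch is made up for the example.

#include <vector>

#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/place.h"

void ChunkSketch() {
  paddle::framework::Tensor t;
  t.mutable_data<int>(paddle::framework::make_ddim({6, 2}),
                      paddle::platform::CPUPlace());

  // Chunk(4, 0) computes split_size = (6 + 4 - 1) / 4 = 2 and forwards to
  // Split(2, 0), so it yields three {2, 2} views rather than four pieces.
  std::vector<paddle::framework::Tensor> parts = t.Chunk(4, 0);

  // Each part is a Slice of t: parts[i].data<int>() points into t's buffer
  // at row offset 2 * i, so writing through a part writes the source tensor.
  for (auto &part : parts) {
    part.data<int>()[0] = 1;
  }
}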
diff --git a/paddle/fluid/memory/detail/system_allocator.cc b/paddle/fluid/memory/detail/system_allocator.cc
index d6dc303ebc7..9f39c3a823f 100644
--- a/paddle/fluid/memory/detail/system_allocator.cc
+++ b/paddle/fluid/memory/detail/system_allocator.cc
@@ -192,7 +192,7 @@ void* CUDAPinnedAllocator::Alloc(size_t* index, size_t size) {
   void* p;
   // PINNED memory is visible to all CUDA contexts.
 #ifdef PADDLE_WITH_HIP
-  hipError_t result = hipHostMalloc(&p, size);
+  hipError_t result = hipHostMalloc(&p, size, hipHostMallocPortable);
 #else
   cudaError_t result = cudaHostAlloc(&p, size, cudaHostAllocPortable);
 #endif
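For context on the one-line allocator change above: hipHostMallocPortable mirrors cudaHostAllocPortable, so the pinned block is visible to every device context instead of only the allocating one. A standalone sketch of the pattern follows; it is not Paddle code, just the same guarded allocation reduced to a helper (names are illustrative), and the caller releases the block with hipHostFree/cudaFreeHost.

#include <cstddef>

#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#else
#include <cuda_runtime.h>
#endif

// Returns nullptr on failure; free with hipHostFree() / cudaFreeHost().
void* AllocPortablePinned(size_t size) {
  void* p = nullptr;
#ifdef PADDLE_WITH_HIP
  // hipHostMallocPortable makes the pinned memory usable from all HIP
  // contexts, matching cudaHostAllocPortable on the CUDA path below.
  if (hipHostMalloc(&p, size, hipHostMallocPortable) != hipSuccess) {
    return nullptr;
  }
#else
  if (cudaHostAlloc(&p, size, cudaHostAllocPortable) != cudaSuccess) {
    return nullptr;
  }
#endif
  return p;
}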
diff --git a/paddle/fluid/operators/rnn_op.cu.cc b/paddle/fluid/operators/rnn_op.cu.cc
index 2be59c62044..07329a9175e 100644
--- a/paddle/fluid/operators/rnn_op.cu.cc
+++ b/paddle/fluid/operators/rnn_op.cu.cc
@@ -29,15 +29,21 @@ namespace operators {
 using LoDTensor = framework::LoDTensor;
 using Tensor = framework::Tensor;
 
+#ifdef PADDLE_WITH_HIP
+using gpuRNNMode_t = miopenRNNMode_t;
+using gpuDnnHandle_t = miopenHandle_t;
+using gpuDnnDataType_t = miopenDataType_t;
+#else
+using gpuRNNMode_t = cudnnRNNMode_t;
+using gpuDnnHandle_t = cudnnHandle_t;
+using gpuDnnDataType_t = cudnnDataType_t;
+#endif
+
 class RNNDescriptors {
  public:
   RNNDescriptors(int seq_length, int batch_size, int input_size,
                  int hidden_size, int num_layers, float dropout_prob, int seed,
-#ifdef PADDLE_WITH_HIP
-                 int weight_numel, miopenRNNMode_t mode, bool is_bidirec,
-#else
-                 int weight_numel, cudnnRNNMode_t mode, bool is_bidirec,
-#endif
+                 int weight_numel, gpuRNNMode_t mode, bool is_bidirec,
                  bool is_test)
       : seq_length_(seq_length),
         batch_size_(batch_size),
@@ -49,23 +55,14 @@ class RNNDescriptors {
         weight_numel_(weight_numel),
         mode_(mode),
         is_bidirec_(is_bidirec),
-        is_test_(is_test) {
-  }
+        is_test_(is_test) {}
 
   template <typename T>
-#ifdef PADDLE_WITH_HIP
-  void Create(const miopenHandle_t &handle, const platform::Place &place,
-#else
-  void Create(const cudnnHandle_t &handle, const platform::Place &place,
-#endif
+  void Create(const gpuDnnHandle_t &handle, const platform::Place &place,
               const std::vector<int> &sequence_length, size_t *workspace_size,
               size_t *reserve_size, framework::Tensor *dropout_state) {
     int numDirections = is_bidirec_ ? 2 : 1;
-#ifdef PADDLE_WITH_HIP
-    miopenDataType_t cudnn_type = platform::CudnnDataType<T>::type;
-#else
-    cudnnDataType_t cudnn_type = platform::CudnnDataType<T>::type;
-#endif
+    gpuDnnDataType_t cudnn_type = platform::CudnnDataType<T>::type;
     // ------------------- cudnn x, y descriptors ---------------------
     std::vector<int> dims_x = {batch_size_, input_size_, 1};
     std::vector<int> strides_x = {input_size_, 1, 1};
@@ -215,11 +212,7 @@ class RNNDescriptors {
   float dropout_prob_;
   int seed_;
   int weight_numel_;
-#ifdef PADDLE_WITH_HIP
-  miopenRNNMode_t mode_;
-#else
-  cudnnRNNMode_t mode_;
-#endif
+  gpuRNNMode_t mode_;
   bool is_bidirec_;
   bool is_test_;
 #ifdef PADDLE_WITH_HIP
@@ -296,6 +289,105 @@ void weight_to_tensor_list(const platform::Place &place, gpuStream_t stream,
   }
 }
 
+#ifdef PADDLE_WITH_HIP
+template <typename T>
+void weight_list_to_tensor(const platform::Place &place, gpuStream_t stream,
+                           const std::vector<Tensor> &tensor_list,
+                           Tensor *weight_whole, const size_t offset = 0UL) {
+  size_t weight_offset = offset;
+  auto weight_data = weight_whole->data<T>();
+
+  for (size_t i = 0; i < tensor_list.size(); ++i) {
+    const T *in_data = tensor_list[i].data<T>();
+    auto in_size = tensor_list[i].numel();
+    memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, weight_whole->place()),
+                 weight_data + weight_offset,
+                 BOOST_GET_CONST(platform::CUDAPlace, tensor_list[i].place()),
+                 in_data, in_size * sizeof(T), stream);
+    weight_offset += in_size;
+  }
+}
+
+template <typename T>
+void weight_to_permuted_tensor(const platform::Place &place, gpuStream_t stream,
+                               std::vector<const Tensor *> *weight_list,
+                               Tensor *weight_whole,
+                               const gpuRNNMode_t rnn_mode,
+                               const bool is_bidirec) {
+  if (is_bidirec) {
+    for (size_t i = 0; i < weight_list->size(); i += 4) {
+      auto tmp = (*weight_list)[i + 1];
+      (*weight_list)[i + 1] = (*weight_list)[i + 2];
+      (*weight_list)[i + 2] = tmp;
+    }
+  }
+  size_t weight_offset = 0;
+  for (size_t i = 0; i < weight_list->size(); ++i) {
+    if (rnn_mode == miopenLSTM) {
+      std::vector<Tensor> split_tensor = (*weight_list)[i]->Chunk(4, 0);
+      weight_list_to_tensor<T>(
+          place, stream,
+          {split_tensor[0], split_tensor[1], split_tensor[3], split_tensor[2]},
+          weight_whole, weight_offset);
+    } else if (rnn_mode == miopenGRU) {
+      std::vector<Tensor> split_tensor = (*weight_list)[i]->Chunk(3, 0);
+      weight_list_to_tensor<T>(
+          place, stream, {split_tensor[1], split_tensor[0], split_tensor[2]},
+          weight_whole, weight_offset);
+    } else {
+      weight_list_to_tensor<T>(place, stream, {*(*weight_list)[i]},
+                               weight_whole, weight_offset);
+    }
+    weight_offset += (*weight_list)[i]->numel();
+  }
+}
+
+template <typename T>
+void tensor_to_permuted_weight(const platform::Place &place, gpuStream_t stream,
+                               const Tensor &tensor,
+                               std::vector<Tensor *> *weight_grad_list,
+                               const gpuRNNMode_t rnn_mode,
+                               const bool is_bidirec) {
+  if (is_bidirec) {
+    for (size_t i = 0; i < weight_grad_list->size(); i += 4) {
+      auto tmp = (*weight_grad_list)[i + 1];
+      (*weight_grad_list)[i + 1] = (*weight_grad_list)[i + 2];
+      (*weight_grad_list)[i + 2] = tmp;
+    }
+  }
+  size_t weight_offset = 0;
+  for (size_t i = 0; i < weight_grad_list->size(); ++i) {
+    auto numel_size = (*weight_grad_list)[i]->numel();
+    Tensor temp;
+    temp.mutable_data<T>({numel_size}, place);
+    temp.ShareDataWith(tensor.Slice(weight_offset, weight_offset + numel_size));
+
+    if (rnn_mode == miopenLSTM) {
+      std::vector<Tensor> split_tensor = temp.Chunk(4, 0);
+      weight_list_to_tensor<T>(
+          place, stream,
+          {split_tensor[0], split_tensor[1], split_tensor[3], split_tensor[2]},
+          (*weight_grad_list)[i]);
+    } else if (rnn_mode == miopenGRU) {
+      std::vector<Tensor> split_tensor = temp.Chunk(3, 0);
+
weight_list_to_tensor( + place, stream, {split_tensor[1], split_tensor[0], split_tensor[2]}, + (*weight_grad_list)[i]); + } else { + weight_list_to_tensor(place, stream, {temp}, (*weight_grad_list)[i]); + } + weight_offset += numel_size; + } + if (is_bidirec) { + for (size_t i = 0; i < weight_grad_list->size(); i += 4) { + auto tmp = (*weight_grad_list)[i + 1]; + (*weight_grad_list)[i + 1] = (*weight_grad_list)[i + 2]; + (*weight_grad_list)[i + 2] = tmp; + } + } +} +#endif + template class RNNCudnnKernel : public framework::OpKernel { public: @@ -314,7 +406,7 @@ class RNNCudnnKernel : public framework::OpKernel { int num_layers = ctx.Attr("num_layers"); auto mode = ctx.Attr("mode"); #ifdef PADDLE_WITH_HIP - miopenRNNMode_t rnn_mode = miopenLSTM; + gpuRNNMode_t rnn_mode = miopenLSTM; if (mode == "LSTM") rnn_mode = miopenLSTM; else if (mode == "GRU") @@ -324,7 +416,7 @@ class RNNCudnnKernel : public framework::OpKernel { else if (mode == "RNN_TANH") rnn_mode = miopenRNNTANH; #else - cudnnRNNMode_t rnn_mode = CUDNN_LSTM; + gpuRNNMode_t rnn_mode = CUDNN_LSTM; if (mode == "LSTM") rnn_mode = CUDNN_LSTM; else if (mode == "GRU") @@ -373,6 +465,11 @@ class RNNCudnnKernel : public framework::OpKernel { } bool has_seq_length = ctx.HasInput("SequenceLength"); +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_EQ(has_seq_length, false, + platform::errors::InvalidArgument( + "ROCm do not support SequenceLength yet.")); +#endif std::vector SequenceLength; if (has_seq_length) { auto *sequence_length = ctx.Input("SequenceLength"); @@ -400,14 +497,26 @@ class RNNCudnnKernel : public framework::OpKernel { [](int64_t num, const Tensor *t) { return num + t->numel(); }); bool continuous = is_continuous>(weight_list); +#ifdef PADDLE_WITH_HIP + // Need to permute weight, set continuous to false + continuous = false; +#endif if (!continuous) { LOG_FIRST_N(WARNING, 2) << "If the memory space of the Input WeightList is not continuous, " "less efficient calculation will be called. 
Please call " "flatten_parameters() to make the input memory continuous."; weight_whole.mutable_data({weight_numel}, place); +#ifdef PADDLE_WITH_HIP + // MIOPEN need to permute weight for miopenLSTM or miopenGRU + weight_to_permuted_tensor(place, stream, &weight_list, &weight_whole, + rnn_mode, is_bidirec); +#else weight_to_tensor(place, stream, weight_list, &weight_whole); +#endif w_data = weight_whole.data(); +#ifndef PADDLE_WITH_HIP + // MIOPEN need to permute weight, do not share with weight_grad if (is_test) { // maybe also reset small weights' ptr for training int offset = 0; for (size_t i = 0; i < weight_list.size(); ++i) { @@ -421,6 +530,7 @@ class RNNCudnnKernel : public framework::OpKernel { offset += len; } } +#endif } else { w_data = const_cast(weight_list[0]->data()); } @@ -486,11 +596,7 @@ class RNNCudnnKernel : public framework::OpKernel { } } -#ifdef PADDLE_WITH_HIP - void RNNInferece(const bool &has_seq_length, const miopenHandle_t &handle, -#else - void RNNInferece(const bool &has_seq_length, const cudnnHandle_t &handle, -#endif + void RNNInferece(const bool &has_seq_length, const gpuDnnHandle_t &handle, const int &seq_length, RNNDescriptors *rnn, const T *x_data, const T *init_h_data, const T *init_c_data, const T *w_data, T *out_data, T *last_h_data, T *last_c_data, @@ -607,9 +713,20 @@ class RNNGradCudnnKernel : public framework::OpKernel { Tensor weight_whole; T *weight_data = nullptr; +#ifdef PADDLE_WITH_HIP + // Need to permute weight, set continuous to false + continuous = false; +#endif + if (!continuous) { weight_whole.mutable_data({weight_numel}, place); +#ifdef PADDLE_WITH_HIP + // MIOPEN need to permute weight for miopenLSTM or miopenGRU + weight_to_permuted_tensor(place, stream, &weight_list, &weight_whole, + rnn_mode, is_bidirec); +#else weight_to_tensor(place, stream, weight_list, &weight_whole); +#endif weight_data = weight_whole.data(); } else { weight_data = const_cast(weight_list[0]->data()); @@ -621,6 +738,13 @@ class RNNGradCudnnKernel : public framework::OpKernel { zero(dev_ctx, &weight_grad, static_cast(0.0)); T *weight_grad_data = weight_grad.data(); +#ifdef PADDLE_WITH_HIP + // MIOPEN need to permute weight_grad_list, so do not share data with + // weight_grad + for (size_t i = 0; i < weight_grad_list.size(); ++i) { + weight_grad_list[i]->mutable_data(ctx.GetPlace()); + } +#else int offset = 0; for (size_t i = 0; i < weight_grad_list.size(); ++i) { size_t len = weight_grad_list[i]->numel(); @@ -631,6 +755,7 @@ class RNNGradCudnnKernel : public framework::OpKernel { .Resize(dim); offset += len; } +#endif Tensor input_grad_value; if (!in_grad) { @@ -672,6 +797,11 @@ class RNNGradCudnnKernel : public framework::OpKernel { } bool has_seq_length = ctx.HasInput("SequenceLength"); +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_EQ(has_seq_length, false, + platform::errors::InvalidArgument( + "ROCm do not support SequenceLength yet.")); +#endif std::vector SequenceLength; if (has_seq_length) { auto *sequence_length = ctx.Input("SequenceLength"); @@ -731,6 +861,9 @@ class RNNGradCudnnKernel : public framework::OpKernel { rnn.weight_desc(), weight_grad_data, workspace_data_.data(), workspace_size, const_cast(reserve_data), reserve_size)); + // permute weight grad list from weight grad tensor + tensor_to_permuted_weight(place, stream, weight_grad, + &weight_grad_list, rnn_mode, is_bidirec); #else PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardWeights( handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), input->data(), diff --git 
a/python/paddle/fluid/tests/unittests/test_gru_rnn_op.py b/python/paddle/fluid/tests/unittests/test_gru_rnn_op.py index 9f18ec9843d..77b88161d3a 100644 --- a/python/paddle/fluid/tests/unittests/test_gru_rnn_op.py +++ b/python/paddle/fluid/tests/unittests/test_gru_rnn_op.py @@ -92,16 +92,6 @@ class TestGRUOp(OpTest): self._get_places = rocm_rnn_get_place - if self.is_bidirec: - for i in range(0, len(flat_w), 4): - flat_w[i + 1], flat_w[i + 2] = flat_w[i + 2], flat_w[i + 1] - - for i in range(len(flat_w)): - w = np.split(flat_w[i][1], 3, 0) - w = [w[1], w[0], w[2]] - w = np.concatenate(w) - flat_w[i] = (flat_w[i][0], w) - init_h = np.zeros((self.num_layers * self.direction_num, batch_size, self.hidden_size)).astype(self.dtype) diff --git a/python/paddle/fluid/tests/unittests/test_rnn_op.py b/python/paddle/fluid/tests/unittests/test_rnn_op.py index 22e07b0bc48..763ec3e7038 100644 --- a/python/paddle/fluid/tests/unittests/test_rnn_op.py +++ b/python/paddle/fluid/tests/unittests/test_rnn_op.py @@ -95,16 +95,6 @@ class TestRNNOp(OpTest): self._get_places = rocm_rnn_get_place - if self.is_bidirec: - for i in range(0, len(flat_w), 4): - flat_w[i + 1], flat_w[i + 2] = flat_w[i + 2], flat_w[i + 1] - - for i in range(len(flat_w)): - w = np.split(flat_w[i][1], 4, 0) - w = [w[0], w[1], w[3], w[2]] - w = np.concatenate(w) - flat_w[i] = (flat_w[i][0], w) - init_h = np.zeros((self.num_layers * self.direction_num, batch_size, hidden_size)).astype(self.dtype) init_c = np.zeros((self.num_layers * self.direction_num, batch_size, diff --git a/python/paddle/fluid/tests/unittests/test_simple_rnn_op.py b/python/paddle/fluid/tests/unittests/test_simple_rnn_op.py index 63688cbce24..d7e24b6308e 100644 --- a/python/paddle/fluid/tests/unittests/test_simple_rnn_op.py +++ b/python/paddle/fluid/tests/unittests/test_simple_rnn_op.py @@ -19,6 +19,7 @@ import math from op_test import OpTest import paddle import paddle.fluid as fluid +import paddle.fluid.core as core import paddle.fluid.layers as layers import random import sys @@ -44,8 +45,10 @@ class TestSimpleRNNOp(OpTest): def setUp(self): self.op_type = "rnn" - self.dtype = np.float64 - self.sequence_length = np.array([12, 11, 10, 9, 8], dtype=np.int32) + self.dtype = "float32" if core.is_compiled_with_rocm() else "float64" + self.sequence_length = None if core.is_compiled_with_rocm( + ) else np.array( + [12, 11, 10, 9, 8], dtype=np.int32) self.num_layers = 1 self.is_bidirec = False self.is_test = False @@ -76,7 +79,8 @@ class TestSimpleRNNOp(OpTest): time_major=True, direction=direction, dropout=self.dropout, - nonlinearity=self.mode) + nonlinearity=self.mode, + dtype=self.dtype) flat_w = get_params_for_net(rnn1) -- GitLab
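Postscript for anyone tracing the NumPy permutation removed from the Python tests above back to the new C++ path: weight_to_permuted_tensor emits each parameter's gate blocks in the order {0, 1, 3, 2} for miopenLSTM and {1, 0, 2} for miopenGRU, and for bidirectional networks it first swaps the second and third tensors in every group of four, which is the same reordering the tests used to apply by hand. A standalone sketch of that block reordering follows (plain C++, no MIOpen or Paddle dependency; the toy values only label which gate a block came from).

#include <cstdio>
#include <vector>

// Reorder the equal-sized gate blocks of one flat parameter buffer the same
// way the MIOpen path lays them out before handing the buffer to the RNN.
std::vector<float> ReorderGates(const std::vector<float>& w,
                                const std::vector<int>& order) {
  const size_t block = w.size() / order.size();
  std::vector<float> out;
  out.reserve(w.size());
  for (int g : order) {
    out.insert(out.end(), w.begin() + g * block, w.begin() + (g + 1) * block);
  }
  return out;
}

int main() {
  // Two toy values per gate: four gates for LSTM, three for GRU.
  std::vector<float> lstm_w = {0, 0, 1, 1, 2, 2, 3, 3};
  std::vector<float> gru_w = {0, 0, 1, 1, 2, 2};

  // The patch emits LSTM chunks as {0, 1, 3, 2} and GRU chunks as {1, 0, 2}.
  std::vector<float> lstm_m = ReorderGates(lstm_w, {0, 1, 3, 2});
  std::vector<float> gru_m = ReorderGates(gru_w, {1, 0, 2});

  for (float v : lstm_m) std::printf("%g ", v);  // 0 0 1 1 3 3 2 2
  std::printf("\n");
  for (float v : gru_m) std::printf("%g ", v);   // 1 1 0 0 2 2
  std::printf("\n");
  return 0;
}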