Unverified commit 6024488d authored by Qi Li, committed by GitHub

[ROCM] fix RNN miopen as weights need to be permuted, test=develop (#33733)

* [ROCM] fix RNN miopen as weights need to be permuted, test=develop

* [ROCM] fix data sharing when is_test, test=develop

* update, test=develop
Parent b538c6d7
......@@ -135,6 +135,49 @@ Tensor Tensor::Slice(int64_t begin_idx, int64_t end_idx) const {
}
}
std::vector<Tensor> Tensor::Split(int64_t split_size, int64_t axis) const {
check_memory_size();
PADDLE_ENFORCE_GE(dims_.size(), 0,
platform::errors::OutOfRange(
"split expects at least a 1-dimensional tensor"));
PADDLE_ENFORCE_GE(
split_size, 0,
platform::errors::OutOfRange(
"split expects split_size be non-negative, but got split_size is %d",
split_size));
int64_t numel_size = dims_[axis];
int64_t num_splits = 1;
if (split_size != 0) {
num_splits =
std::max<int64_t>((numel_size + split_size - 1) / split_size, 1);
}
std::vector<Tensor> splits(num_splits);
int64_t last_split_size = split_size - (split_size * num_splits - numel_size);
for (int64_t i = 0; i < num_splits; ++i) {
int64_t length = i < num_splits - 1 ? split_size : last_split_size;
splits[i] = Slice(i * split_size, i * split_size + length);
}
return splits;
}
std::vector<Tensor> Tensor::Chunk(int64_t chunks, int64_t axis) const {
check_memory_size();
PADDLE_ENFORCE_GE(dims_.size(), 0,
platform::errors::OutOfRange(
"split expects at least a 1-dimensional tensor"));
PADDLE_ENFORCE_GE(
chunks, 0,
platform::errors::OutOfRange(
"chunks expects to be greater than 0, but got chunks is %d", chunks));
int64_t numel_size = dims_[axis];
int64_t split_size = (numel_size + chunks - 1) / chunks;
return Split(split_size, axis);
}
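For reference, a minimal usage sketch of the Split and Chunk helpers added above (a hypothetical standalone snippet, not part of the patch; the function name and header paths are assumptions):
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/place.h"
void SplitChunkSketch() {
  paddle::framework::Tensor t;
  t.mutable_data<int>(paddle::framework::make_ddim({6, 2}),
                      paddle::platform::CPUPlace());
  // Split by a fixed size of 2 rows per piece along axis 0
  // -> three tensors of shape {2, 2} that share the source allocation.
  std::vector<paddle::framework::Tensor> by_size = t.Split(2, 0);
  // Chunk into 3 pieces along axis 0: split_size = ceil(6 / 3) = 2, and
  // Chunk delegates to Split, yielding the same three {2, 2} views.
  std::vector<paddle::framework::Tensor> by_count = t.Chunk(3, 0);
}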
Tensor& Tensor::Resize(const DDim& dims) {
dims_ = dims;
return *this;
......
......@@ -187,6 +187,22 @@ class Tensor {
*/
Tensor Slice(int64_t begin_idx, int64_t end_idx) const;
/**
* @brief Return a tensor list of the given tensor.
*
* @param[in] split_size The size of tensor to be split along axis.
* @param[in] axis The axis along which to split.
*/
std::vector<Tensor> Split(int64_t split_size, int64_t axis) const;
/**
* @brief Return a tensor list of the given tensor.
*
* @param[in] chunks The number of tensor to be split along axis.
* @param[in] axis The axis along which to split.
*/
std::vector<Tensor> Chunk(int64_t chunks, int64_t axis) const;
const platform::Place& place() const {
PADDLE_ENFORCE_NOT_NULL(
holder_,
......
......@@ -337,3 +337,129 @@ TEST(Tensor, FP16) {
// Tensor holds the wrong type, it holds N6paddle8platform7float16E at
// [/paddle/Paddle/paddle/fluid/framework/tensor_impl.h:43]
}
TEST(Tensor, Split) {
{
framework::Tensor src_tensor;
src_tensor.mutable_data<int>(framework::make_ddim({6, 2}),
platform::CPUPlace());
std::vector<framework::Tensor> split_tensor_list = src_tensor.Split(2, 0);
ASSERT_EQ(split_tensor_list.size(), 3UL);
EXPECT_EQ(split_tensor_list[0].dims()[0], 2);
EXPECT_EQ(split_tensor_list[1].dims()[0], 2);
EXPECT_EQ(split_tensor_list[2].dims()[0], 2);
EXPECT_EQ(split_tensor_list[0].dims()[1], 2);
EXPECT_EQ(split_tensor_list[1].dims()[1], 2);
EXPECT_EQ(split_tensor_list[2].dims()[1], 2);
uintptr_t src_data_address =
reinterpret_cast<uintptr_t>(src_tensor.data<int>());
uintptr_t src_mutable_data_address = reinterpret_cast<uintptr_t>(
src_tensor.mutable_data<int>(src_tensor.dims(), platform::CPUPlace()));
EXPECT_EQ(src_data_address, src_mutable_data_address);
for (int i = 0; i < 3; ++i) {
uintptr_t split_data_address =
reinterpret_cast<uintptr_t>(split_tensor_list[i].data<int>());
uintptr_t split_mutable_data_address =
reinterpret_cast<uintptr_t>(split_tensor_list[i].mutable_data<int>(
split_tensor_list[i].dims(), platform::CPUPlace()));
EXPECT_EQ(split_data_address, split_mutable_data_address);
EXPECT_EQ(src_data_address + 2 * 2 * i * sizeof(int), split_data_address);
}
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
{
framework::Tensor src_tensor;
src_tensor.mutable_data<double>(framework::make_ddim({6, 4}),
platform::CUDAPlace(0));
std::vector<framework::Tensor> split_tensor_list = src_tensor.Split(2, 0);
ASSERT_EQ(split_tensor_list.size(), 3UL);
EXPECT_EQ(split_tensor_list[0].dims()[0], 2);
EXPECT_EQ(split_tensor_list[1].dims()[0], 2);
EXPECT_EQ(split_tensor_list[2].dims()[0], 2);
EXPECT_EQ(split_tensor_list[0].dims()[1], 4);
EXPECT_EQ(split_tensor_list[1].dims()[1], 4);
EXPECT_EQ(split_tensor_list[2].dims()[1], 4);
uintptr_t src_data_address =
reinterpret_cast<uintptr_t>(src_tensor.data<double>());
uintptr_t src_mutable_data_address =
reinterpret_cast<uintptr_t>(src_tensor.mutable_data<double>(
src_tensor.dims(), platform::CUDAPlace(0)));
EXPECT_EQ(src_data_address, src_mutable_data_address);
for (int i = 0; i < 3; ++i) {
uintptr_t split_data_address =
reinterpret_cast<uintptr_t>(split_tensor_list[i].data<double>());
uintptr_t split_mutable_data_address =
reinterpret_cast<uintptr_t>(split_tensor_list[i].mutable_data<double>(
split_tensor_list[i].dims(), platform::CUDAPlace(0)));
EXPECT_EQ(split_data_address, split_mutable_data_address);
EXPECT_EQ(src_data_address + 2 * 4 * i * sizeof(double),
split_data_address);
}
}
#endif
}
TEST(Tensor, Chunk) {
{
framework::Tensor src_tensor;
src_tensor.mutable_data<int>(framework::make_ddim({6, 2}),
platform::CPUPlace());
std::vector<framework::Tensor> split_tensor_list = src_tensor.Chunk(3, 0);
ASSERT_EQ(split_tensor_list.size(), 3UL);
EXPECT_EQ(split_tensor_list[0].dims()[0], 2);
EXPECT_EQ(split_tensor_list[1].dims()[0], 2);
EXPECT_EQ(split_tensor_list[2].dims()[0], 2);
EXPECT_EQ(split_tensor_list[0].dims()[1], 2);
EXPECT_EQ(split_tensor_list[1].dims()[1], 2);
EXPECT_EQ(split_tensor_list[2].dims()[1], 2);
uintptr_t src_data_address =
reinterpret_cast<uintptr_t>(src_tensor.data<int>());
uintptr_t src_mutable_data_address = reinterpret_cast<uintptr_t>(
src_tensor.mutable_data<int>(src_tensor.dims(), platform::CPUPlace()));
for (int i = 0; i < 3; ++i) {
uintptr_t split_data_address =
reinterpret_cast<uintptr_t>(split_tensor_list[i].data<int>());
uintptr_t split_mutable_data_address =
reinterpret_cast<uintptr_t>(split_tensor_list[i].mutable_data<int>(
split_tensor_list[i].dims(), platform::CPUPlace()));
EXPECT_EQ(src_data_address, src_mutable_data_address);
EXPECT_EQ(split_data_address, split_mutable_data_address);
EXPECT_EQ(src_data_address + 2 * 2 * i * sizeof(int), split_data_address);
}
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
{
framework::Tensor src_tensor;
src_tensor.mutable_data<double>(framework::make_ddim({6, 4}),
platform::CUDAPlace(0));
std::vector<framework::Tensor> split_tensor_list = src_tensor.Chunk(3, 0);
ASSERT_EQ(split_tensor_list.size(), 3UL);
EXPECT_EQ(split_tensor_list[0].dims()[0], 2);
EXPECT_EQ(split_tensor_list[1].dims()[0], 2);
EXPECT_EQ(split_tensor_list[2].dims()[0], 2);
EXPECT_EQ(split_tensor_list[0].dims()[1], 4);
EXPECT_EQ(split_tensor_list[1].dims()[1], 4);
EXPECT_EQ(split_tensor_list[2].dims()[1], 4);
uintptr_t src_data_address =
reinterpret_cast<uintptr_t>(src_tensor.data<double>());
uintptr_t src_mutable_data_address =
reinterpret_cast<uintptr_t>(src_tensor.mutable_data<double>(
src_tensor.dims(), platform::CUDAPlace(0)));
EXPECT_EQ(src_data_address, src_mutable_data_address);
for (int i = 0; i < 3; ++i) {
uintptr_t split_data_address =
reinterpret_cast<uintptr_t>(split_tensor_list[i].data<double>());
uintptr_t split_mutable_data_address =
reinterpret_cast<uintptr_t>(split_tensor_list[i].mutable_data<double>(
split_tensor_list[i].dims(), platform::CUDAPlace(0)));
EXPECT_EQ(split_data_address, split_mutable_data_address);
EXPECT_EQ(src_data_address + 2 * 4 * i * sizeof(double),
split_data_address);
}
}
#endif
}
......@@ -192,7 +192,7 @@ void* CUDAPinnedAllocator::Alloc(size_t* index, size_t size) {
void* p;
// PINNED memory is visible to all CUDA contexts.
#ifdef PADDLE_WITH_HIP
hipError_t result = hipHostMalloc(&p, size);
hipError_t result = hipHostMalloc(&p, size, hipHostMallocPortable);
#else
cudaError_t result = cudaHostAlloc(&p, size, cudaHostAllocPortable);
#endif
......
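As context for the hipHostMalloc change above, a hedged standalone sketch of allocating portable pinned host memory with HIP (AllocPortablePinned is a hypothetical helper, not part of Paddle):
#include <cstdio>
#include <hip/hip_runtime.h>
// Allocate page-locked host memory visible to all HIP contexts, mirroring
// cudaHostAlloc(..., cudaHostAllocPortable) on the CUDA path.
void* AllocPortablePinned(size_t size) {
  void* p = nullptr;
  hipError_t result = hipHostMalloc(&p, size, hipHostMallocPortable);
  if (result != hipSuccess) {
    std::fprintf(stderr, "hipHostMalloc failed: %s\n",
                 hipGetErrorString(result));
    return nullptr;
  }
  return p;  // Release with hipHostFree(p) when done.
}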
......@@ -29,15 +29,21 @@ namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
#ifdef PADDLE_WITH_HIP
using gpuRNNMode_t = miopenRNNMode_t;
using gpuDnnHandle_t = miopenHandle_t;
using gpuDnnDataType_t = miopenDataType_t;
#else
using gpuRNNMode_t = cudnnRNNMode_t;
using gpuDnnHandle_t = cudnnHandle_t;
using gpuDnnDataType_t = cudnnDataType_t;
#endif
class RNNDescriptors {
public:
RNNDescriptors(int seq_length, int batch_size, int input_size,
int hidden_size, int num_layers, float dropout_prob, int seed,
#ifdef PADDLE_WITH_HIP
int weight_numel, miopenRNNMode_t mode, bool is_bidirec,
#else
int weight_numel, cudnnRNNMode_t mode, bool is_bidirec,
#endif
int weight_numel, gpuRNNMode_t mode, bool is_bidirec,
bool is_test)
: seq_length_(seq_length),
batch_size_(batch_size),
......@@ -49,23 +55,14 @@ class RNNDescriptors {
weight_numel_(weight_numel),
mode_(mode),
is_bidirec_(is_bidirec),
is_test_(is_test) {
}
is_test_(is_test) {}
template <typename T>
#ifdef PADDLE_WITH_HIP
void Create(const miopenHandle_t &handle, const platform::Place &place,
#else
void Create(const cudnnHandle_t &handle, const platform::Place &place,
#endif
void Create(const gpuDnnHandle_t &handle, const platform::Place &place,
const std::vector<int> &sequence_length, size_t *workspace_size,
size_t *reserve_size, framework::Tensor *dropout_state) {
int numDirections = is_bidirec_ ? 2 : 1;
#ifdef PADDLE_WITH_HIP
miopenDataType_t cudnn_type = platform::CudnnDataType<T>::type;
#else
cudnnDataType_t cudnn_type = platform::CudnnDataType<T>::type;
#endif
gpuDnnDataType_t cudnn_type = platform::CudnnDataType<T>::type;
// ------------------- cudnn x, y descriptors ---------------------
std::vector<int> dims_x = {batch_size_, input_size_, 1};
std::vector<int> strides_x = {input_size_, 1, 1};
......@@ -215,11 +212,7 @@ class RNNDescriptors {
float dropout_prob_;
int seed_;
int weight_numel_;
#ifdef PADDLE_WITH_HIP
miopenRNNMode_t mode_;
#else
cudnnRNNMode_t mode_;
#endif
gpuRNNMode_t mode_;
bool is_bidirec_;
bool is_test_;
#ifdef PADDLE_WITH_HIP
......@@ -296,6 +289,105 @@ void weight_to_tensor_list(const platform::Place &place, gpuStream_t stream,
}
}
#ifdef PADDLE_WITH_HIP
template <typename T>
void weight_list_to_tensor(const platform::Place &place, gpuStream_t stream,
const std::vector<Tensor> &tensor_list,
Tensor *weight_whole, const size_t offset = 0UL) {
size_t weight_offset = offset;
auto weight_data = weight_whole->data<T>();
for (size_t i = 0; i < tensor_list.size(); ++i) {
const T *in_data = tensor_list[i].data<T>();
auto in_size = tensor_list[i].numel();
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, weight_whole->place()),
weight_data + weight_offset,
BOOST_GET_CONST(platform::CUDAPlace, tensor_list[i].place()),
in_data, in_size * sizeof(T), stream);
weight_offset += in_size;
}
}
template <typename T>
void weight_to_permuted_tensor(const platform::Place &place, gpuStream_t stream,
std::vector<const Tensor *> *weight_list,
Tensor *weight_whole,
const gpuRNNMode_t rnn_mode,
const bool is_bidirec) {
if (is_bidirec) {
for (size_t i = 0; i < weight_list->size(); i += 4) {
auto tmp = (*weight_list)[i + 1];
(*weight_list)[i + 1] = (*weight_list)[i + 2];
(*weight_list)[i + 2] = tmp;
}
}
size_t weight_offset = 0;
for (size_t i = 0; i < weight_list->size(); ++i) {
if (rnn_mode == miopenLSTM) {
std::vector<Tensor> split_tensor = (*weight_list)[i]->Chunk(4, 0);
weight_list_to_tensor<T>(
place, stream,
{split_tensor[0], split_tensor[1], split_tensor[3], split_tensor[2]},
weight_whole, weight_offset);
} else if (rnn_mode == miopenGRU) {
std::vector<Tensor> split_tensor = (*weight_list)[i]->Chunk(3, 0);
weight_list_to_tensor<T>(
place, stream, {split_tensor[1], split_tensor[0], split_tensor[2]},
weight_whole, weight_offset);
} else {
weight_list_to_tensor<T>(place, stream, {*(*weight_list)[i]},
weight_whole, weight_offset);
}
weight_offset += (*weight_list)[i]->numel();
}
}
template <typename T>
void tensor_to_permuted_weight(const platform::Place &place, gpuStream_t stream,
const Tensor &tensor,
std::vector<Tensor *> *weight_grad_list,
const gpuRNNMode_t rnn_mode,
const bool is_bidirec) {
if (is_bidirec) {
for (size_t i = 0; i < weight_grad_list->size(); i += 4) {
auto tmp = (*weight_grad_list)[i + 1];
(*weight_grad_list)[i + 1] = (*weight_grad_list)[i + 2];
(*weight_grad_list)[i + 2] = tmp;
}
}
size_t weight_offset = 0;
for (size_t i = 0; i < weight_grad_list->size(); ++i) {
auto numel_size = (*weight_grad_list)[i]->numel();
Tensor temp;
temp.mutable_data<T>({numel_size}, place);
temp.ShareDataWith(tensor.Slice(weight_offset, weight_offset + numel_size));
if (rnn_mode == miopenLSTM) {
std::vector<Tensor> split_tensor = temp.Chunk(4, 0);
weight_list_to_tensor<T>(
place, stream,
{split_tensor[0], split_tensor[1], split_tensor[3], split_tensor[2]},
(*weight_grad_list)[i]);
} else if (rnn_mode == miopenGRU) {
std::vector<Tensor> split_tensor = temp.Chunk(3, 0);
weight_list_to_tensor<T>(
place, stream, {split_tensor[1], split_tensor[0], split_tensor[2]},
(*weight_grad_list)[i]);
} else {
weight_list_to_tensor<T>(place, stream, {temp}, (*weight_grad_list)[i]);
}
weight_offset += numel_size;
}
if (is_bidirec) {
for (size_t i = 0; i < weight_grad_list->size(); i += 4) {
auto tmp = (*weight_grad_list)[i + 1];
(*weight_grad_list)[i + 1] = (*weight_grad_list)[i + 2];
(*weight_grad_list)[i + 2] = tmp;
}
}
}
#endif
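The two HIP-only helpers above exist because MIOpen packs per-gate weight blocks in a different order than the layout Paddle stores: for miopenLSTM the third and fourth chunks are swapped, for miopenGRU the first two. A minimal sketch of the per-tensor reorder using the new Chunk API (PermuteLstmGates is a hypothetical helper for illustration only):
// Hedged sketch: produce the chunk order that weight_to_permuted_tensor
// copies into the packed MIOpen weight buffer for one LSTM weight tensor.
std::vector<Tensor> PermuteLstmGates(const Tensor &w) {
  // Four per-gate blocks along axis 0; swap the last two ({0, 1, 3, 2}).
  std::vector<Tensor> gates = w.Chunk(4, 0);
  return {gates[0], gates[1], gates[3], gates[2]};
}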
template <typename T>
class RNNCudnnKernel : public framework::OpKernel<T> {
public:
......@@ -314,7 +406,7 @@ class RNNCudnnKernel : public framework::OpKernel<T> {
int num_layers = ctx.Attr<int>("num_layers");
auto mode = ctx.Attr<std::string>("mode");
#ifdef PADDLE_WITH_HIP
miopenRNNMode_t rnn_mode = miopenLSTM;
gpuRNNMode_t rnn_mode = miopenLSTM;
if (mode == "LSTM")
rnn_mode = miopenLSTM;
else if (mode == "GRU")
......@@ -324,7 +416,7 @@ class RNNCudnnKernel : public framework::OpKernel<T> {
else if (mode == "RNN_TANH")
rnn_mode = miopenRNNTANH;
#else
cudnnRNNMode_t rnn_mode = CUDNN_LSTM;
gpuRNNMode_t rnn_mode = CUDNN_LSTM;
if (mode == "LSTM")
rnn_mode = CUDNN_LSTM;
else if (mode == "GRU")
......@@ -373,6 +465,11 @@ class RNNCudnnKernel : public framework::OpKernel<T> {
}
bool has_seq_length = ctx.HasInput("SequenceLength");
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_EQ(has_seq_length, false,
platform::errors::InvalidArgument(
"ROCm do not support SequenceLength yet."));
#endif
std::vector<int> SequenceLength;
if (has_seq_length) {
auto *sequence_length = ctx.Input<Tensor>("SequenceLength");
......@@ -400,14 +497,26 @@ class RNNCudnnKernel : public framework::OpKernel<T> {
[](int64_t num, const Tensor *t) { return num + t->numel(); });
bool continuous =
is_continuous<T, std::vector<const Tensor *>>(weight_list);
#ifdef PADDLE_WITH_HIP
// Weights need to be permuted, so set continuous to false
continuous = false;
#endif
if (!continuous) {
LOG_FIRST_N(WARNING, 2)
<< "If the memory space of the Input WeightList is not continuous, "
"less efficient calculation will be called. Please call "
"flatten_parameters() to make the input memory continuous.";
weight_whole.mutable_data<T>({weight_numel}, place);
#ifdef PADDLE_WITH_HIP
// MIOPEN needs the weights permuted for miopenLSTM or miopenGRU
weight_to_permuted_tensor<T>(place, stream, &weight_list, &weight_whole,
rnn_mode, is_bidirec);
#else
weight_to_tensor<T>(place, stream, weight_list, &weight_whole);
#endif
w_data = weight_whole.data<T>();
#ifndef PADDLE_WITH_HIP
// MIOPEN needs the weights permuted, so do not share with weight_grad
if (is_test) { // maybe also reset small weights' ptr for training
int offset = 0;
for (size_t i = 0; i < weight_list.size(); ++i) {
......@@ -421,6 +530,7 @@ class RNNCudnnKernel : public framework::OpKernel<T> {
offset += len;
}
}
#endif
} else {
w_data = const_cast<T *>(weight_list[0]->data<T>());
}
......@@ -486,11 +596,7 @@ class RNNCudnnKernel : public framework::OpKernel<T> {
}
}
#ifdef PADDLE_WITH_HIP
void RNNInferece(const bool &has_seq_length, const miopenHandle_t &handle,
#else
void RNNInferece(const bool &has_seq_length, const cudnnHandle_t &handle,
#endif
void RNNInferece(const bool &has_seq_length, const gpuDnnHandle_t &handle,
const int &seq_length, RNNDescriptors *rnn, const T *x_data,
const T *init_h_data, const T *init_c_data, const T *w_data,
T *out_data, T *last_h_data, T *last_c_data,
......@@ -607,9 +713,20 @@ class RNNGradCudnnKernel : public framework::OpKernel<T> {
Tensor weight_whole;
T *weight_data = nullptr;
#ifdef PADDLE_WITH_HIP
// Weights need to be permuted, so set continuous to false
continuous = false;
#endif
if (!continuous) {
weight_whole.mutable_data<T>({weight_numel}, place);
#ifdef PADDLE_WITH_HIP
// MIOPEN needs the weights permuted for miopenLSTM or miopenGRU
weight_to_permuted_tensor<T>(place, stream, &weight_list, &weight_whole,
rnn_mode, is_bidirec);
#else
weight_to_tensor<T>(place, stream, weight_list, &weight_whole);
#endif
weight_data = weight_whole.data<T>();
} else {
weight_data = const_cast<T *>(weight_list[0]->data<T>());
......@@ -621,6 +738,13 @@ class RNNGradCudnnKernel : public framework::OpKernel<T> {
zero(dev_ctx, &weight_grad, static_cast<T>(0.0));
T *weight_grad_data = weight_grad.data<T>();
#ifdef PADDLE_WITH_HIP
// MIOPEN needs weight_grad_list permuted, so do not share its data with
// weight_grad
for (size_t i = 0; i < weight_grad_list.size(); ++i) {
weight_grad_list[i]->mutable_data<T>(ctx.GetPlace());
}
#else
int offset = 0;
for (size_t i = 0; i < weight_grad_list.size(); ++i) {
size_t len = weight_grad_list[i]->numel();
......@@ -631,6 +755,7 @@ class RNNGradCudnnKernel : public framework::OpKernel<T> {
.Resize(dim);
offset += len;
}
#endif
Tensor input_grad_value;
if (!in_grad) {
......@@ -672,6 +797,11 @@ class RNNGradCudnnKernel : public framework::OpKernel<T> {
}
bool has_seq_length = ctx.HasInput("SequenceLength");
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_EQ(has_seq_length, false,
platform::errors::InvalidArgument(
"ROCm do not support SequenceLength yet."));
#endif
std::vector<int> SequenceLength;
if (has_seq_length) {
auto *sequence_length = ctx.Input<Tensor>("SequenceLength");
......@@ -731,6 +861,9 @@ class RNNGradCudnnKernel : public framework::OpKernel<T> {
rnn.weight_desc(), weight_grad_data,
workspace_data_.data<uint8_t>(), workspace_size,
const_cast<uint8_t *>(reserve_data), reserve_size));
// Permute the weight grad list back from the packed weight grad tensor
tensor_to_permuted_weight<T>(place, stream, weight_grad,
&weight_grad_list, rnn_mode, is_bidirec);
#else
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardWeights(
handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), input->data<T>(),
......
......@@ -92,16 +92,6 @@ class TestGRUOp(OpTest):
self._get_places = rocm_rnn_get_place
if self.is_bidirec:
for i in range(0, len(flat_w), 4):
flat_w[i + 1], flat_w[i + 2] = flat_w[i + 2], flat_w[i + 1]
for i in range(len(flat_w)):
w = np.split(flat_w[i][1], 3, 0)
w = [w[1], w[0], w[2]]
w = np.concatenate(w)
flat_w[i] = (flat_w[i][0], w)
init_h = np.zeros((self.num_layers * self.direction_num, batch_size,
self.hidden_size)).astype(self.dtype)
......
......@@ -95,16 +95,6 @@ class TestRNNOp(OpTest):
self._get_places = rocm_rnn_get_place
if self.is_bidirec:
for i in range(0, len(flat_w), 4):
flat_w[i + 1], flat_w[i + 2] = flat_w[i + 2], flat_w[i + 1]
for i in range(len(flat_w)):
w = np.split(flat_w[i][1], 4, 0)
w = [w[0], w[1], w[3], w[2]]
w = np.concatenate(w)
flat_w[i] = (flat_w[i][0], w)
init_h = np.zeros((self.num_layers * self.direction_num, batch_size,
hidden_size)).astype(self.dtype)
init_c = np.zeros((self.num_layers * self.direction_num, batch_size,
......
......@@ -19,6 +19,7 @@ import math
from op_test import OpTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.layers as layers
import random
import sys
......@@ -44,8 +45,10 @@ class TestSimpleRNNOp(OpTest):
def setUp(self):
self.op_type = "rnn"
self.dtype = np.float64
self.sequence_length = np.array([12, 11, 10, 9, 8], dtype=np.int32)
self.dtype = "float32" if core.is_compiled_with_rocm() else "float64"
self.sequence_length = None if core.is_compiled_with_rocm() else np.array(
    [12, 11, 10, 9, 8], dtype=np.int32)
self.num_layers = 1
self.is_bidirec = False
self.is_test = False
......@@ -76,7 +79,8 @@ class TestSimpleRNNOp(OpTest):
time_major=True,
direction=direction,
dropout=self.dropout,
nonlinearity=self.mode)
nonlinearity=self.mode,
dtype=self.dtype)
flat_w = get_params_for_net(rnn1)
......