Unverified commit bd2a4e23, authored by ronnywang, committed by GitHub

[ROCM] fix some unittests (#32129)

* [ROCM] fix test_gru_rnn_op

* [ROCM] fix test_expand_op

* [ROCM] fix test_cross_entropy_loss

* [ROCM] fix test_conv_nn_grad

* [ROCM] fix test_bilinear_tensor_product_op

* [ROCM] fix elementwise_op_function

* [ROCM] fix test_lstm_cudnn_op

* [ROCM] fix test_gpu_package_without_gpu_device

* [ROCM] fix test_gru_unit_op

* [ROCM] fix test_imperative_optimizer

* [ROCM] fix rnn

* [ROCM] fix group_norm_op

* [ROCM] fix test_pool3d_api

* [ROCM] fix test_pool3d_op
Parent d8afe407
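Most of the Python test changes below follow one recurring pattern: on a ROCm build the reference data is generated in float32 and the comparison tolerances are loosened, since the MIOpen kernels exercised here run in lower precision than the float64 CUDA path. A minimal standalone sketch of that pattern, assuming a PaddlePaddle build (the helper name is illustrative, not a Paddle API):

# Sketch of the ROCm-aware test pattern used throughout this commit.
# `make_test_data` is a hypothetical helper, not part of Paddle.
import numpy as np
from paddle.fluid import core

def make_test_data(shape):
    # Reference data (and hence the expected precision) is switched per backend.
    dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
    return np.random.random(shape).astype(dtype)

# Comparison tolerances are loosened in the same way on ROCm
# (e.g. atol=1e-3 or max_relative_error=1e-2 instead of the defaults).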
@@ -39,7 +39,11 @@ limitations under the License. */
 #include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
+#ifdef __HIPCC__
+constexpr int ELEMWISE_MAX_BLOCK_DIM = 256;
+#else
 constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
+#endif
 #define BLOCK_X 32
 #define BLOCK_Y 32
 #endif
......
@@ -174,7 +174,11 @@ class GroupNormKernel<platform::CUDADeviceContext, T>
     int imsize = (data_layout == DataLayout::kNCHW ? x_dims[2] * x_dims[3]
                                                     : x_dims[1] * x_dims[2]);
+#ifdef __HIPCC__
+    int block_size = std::max(std::min(256, imsize), 64);
+#else
     int block_size = std::min(1024, imsize);
+#endif
     dim3 grid(group_size, groups, x_dims[0]);
     dim3 threads(block_size, 1, 1);
     GroupNormForwardGetMeanAndVar<T><<<grid, threads, 0, dev_ctx.stream()>>>(
@@ -348,7 +352,11 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T>
     int imsize = (data_layout == DataLayout::kNCHW ? x_dims[2] * x_dims[3]
                                                     : x_dims[1] * x_dims[2]);
+#ifdef __HIPCC__
+    int block_size = std::max(std::min(256, imsize), 64);
+#else
     int block_size = std::min(1024, imsize);
+#endif
     dim3 grid(group_size, groups, x_dims[0]);
     dim3 threads(block_size, 1, 1);
     int flags =
......
@@ -75,10 +75,11 @@ class ScopedRNNBase {
         dropout_state, seed_, state_size);
     // ------------------- miopen rnn descriptors ---------------------
-    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetRNNDescriptor(
-        rnn_desc_.desc(), hidden_size_, num_layers_, miopenRNNlinear,
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetRNNDescriptor_V2(
+        rnn_desc_.desc(), hidden_size_, num_layers_, dropout_desc_.desc(),
+        miopenRNNlinear,
         is_bidirec_ ? miopenRNNbidirection : miopenRNNunidirection, miopenLSTM,
-        miopenRNNNoBias, miopenRNNdefault, miopen_type));
+        miopenRNNwithBias, miopenRNNdefault, miopen_type));
     // ------------------- miopen weights_size ---------------------
     size_t weights_size_;
......
@@ -434,9 +434,10 @@ class ScopedPoolingDescriptor {
             "The size of kernel and strides should be equal. But "
             "received size of kernel is %d, size of strides is %d.",
             kernel.size(), strides.size()));
-    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenSet2dPoolingDescriptor(
-        desc_, GetPoolingMode(mode), kernel[0], kernel[1], pads[0], pads[1],
-        strides[0], strides[1]));
+    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenSetNdPoolingDescriptor(
+        desc_, GetPoolingMode(mode), kernel.size(),
+        const_cast<int*>(kernel.data()), const_cast<int*>(pads.data()),
+        const_cast<int*>(strides.data())));
     return desc_;
   }
......
@@ -42,11 +42,12 @@ class TestBilinearTensorProductOp(OpTest):
         size0 = 5
         size1 = 4
         size2 = 5
-        a = np.random.random((batch_size, size0)).astype("float64")
-        b = np.random.random((batch_size, size1)).astype("float64")
-        w = np.random.random((size2, size0, size1)).astype("float64")
-        bias = np.random.random((1, size2)).astype("float64")
-        output = np.zeros((batch_size, size2)).astype("float64")
+        dtype = "float32" if fluid.core.is_compiled_with_rocm() else "float64"
+        a = np.random.random((batch_size, size0)).astype(dtype)
+        b = np.random.random((batch_size, size1)).astype(dtype)
+        w = np.random.random((size2, size0, size1)).astype(dtype)
+        bias = np.random.random((1, size2)).astype(dtype)
+        output = np.zeros((batch_size, size2)).astype(dtype)
         for i in range(size2):
             w_i = w[i, :, :]
             output[:, i] = np.sum(np.matmul(a, w_i) * b, axis=1)
......
@@ -30,7 +30,7 @@ class TestConvDoubleGradCheck(unittest.TestCase):
     def func(self, place):
         shape = [2, 4, 3, 3]
         eps = 0.005
-        dtype = np.float64
+        dtype = np.float32 if fluid.core.is_compiled_with_rocm() else np.float64
         x = layers.data('x', shape, False, dtype)
         y = layers.conv2d(x, 2, 1, groups=1, bias_attr=False)
         x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
@@ -57,7 +57,7 @@ class TestConvDoubleGradCheck(unittest.TestCase):
     def func(self, place):
         shape = [2, 4, 3, 3]
         eps = 0.005
-        dtype = np.float64
+        dtype = np.float32 if fluid.core.is_compiled_with_rocm() else np.float64
         x = layers.data('x', shape, False, dtype)
         y = layers.conv2d(x, 2, 1, bias_attr=False)
         x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
@@ -82,7 +82,7 @@ class TestConvDoubleGradCheckTest1(unittest.TestCase):
     def func(self, place):
         shape = [2, 3, 3, 3]
         eps = 0.005
-        dtype = np.float64
+        dtype = np.float32 if fluid.core.is_compiled_with_rocm() else np.float64
         x = layers.data('x', shape, False, dtype)
         y = layers.conv2d(x, 2, 1, padding=1, bias_attr=False)
         x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
@@ -107,7 +107,7 @@ class TestConv3DDoubleGradCheck(unittest.TestCase):
     def func(self, place):
         shape = [2, 4, 3, 4, 2]
         eps = 0.005
-        dtype = np.float64
+        dtype = np.float32 if fluid.core.is_compiled_with_rocm() else np.float64
         x = layers.data('x', shape, False, dtype)
         y = layers.conv3d(x, 2, 1, bias_attr=False)
         x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
@@ -132,7 +132,7 @@ class TestConv3DDoubleGradCheckTest1(unittest.TestCase):
     def func(self, place):
         shape = [2, 4, 5, 3, 2]
         eps = 0.005
-        dtype = np.float64
+        dtype = np.float32 if fluid.core.is_compiled_with_rocm() else np.float64
         x = layers.data('x', shape, False, dtype)
         y = layers.conv3d(x, 2, 1, padding=1, bias_attr=False)
         x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
@@ -157,7 +157,7 @@ class TestConv2DoubleGradCheck_AsyPadding(unittest.TestCase):
     def func(self, place):
         shape = [2, 2, 3, 3]
         eps = 0.005
-        dtype = np.float64
+        dtype = np.float32 if fluid.core.is_compiled_with_rocm() else np.float64
         x = layers.data('x', shape, False, dtype)
         y = layers.conv2d(
             input=x,
@@ -188,7 +188,7 @@ class TestConv2DoubleGradCheck_PaddingSAME(unittest.TestCase):
     def func(self, place):
         shape = [2, 2, 3, 3]
         eps = 0.005
-        dtype = np.float64
+        dtype = np.float32 if fluid.core.is_compiled_with_rocm() else np.float64
         x = layers.data('x', shape, False, dtype)
         y = layers.conv2d(
             input=x,
@@ -219,7 +219,7 @@ class TestConv2DoubleGradCheck_PaddingVALID(unittest.TestCase):
     def func(self, place):
         shape = [2, 2, 3, 3]
         eps = 0.005
-        dtype = np.float64
+        dtype = np.float32 if fluid.core.is_compiled_with_rocm() else np.float64
         x = layers.data('x', shape, False, dtype)
         y = layers.conv2d(
             input=x,
@@ -250,7 +250,7 @@ class TestConv2DoubleGradCheck_ChannelLast(unittest.TestCase):
     def func(self, place):
         shape = [2, 2, 3, 3]
         eps = 0.005
-        dtype = np.float64
+        dtype = np.float32 if fluid.core.is_compiled_with_rocm() else np.float64
         x = layers.data('x', shape, False, dtype)
         y = layers.conv2d(
             input=x,
@@ -283,7 +283,7 @@ class TestConv2DoubleGradCheck_ChannelLast_AsyPadding(unittest.TestCase):
     def func(self, place):
         shape = [2, 2, 3, 3]
         eps = 0.005
-        dtype = np.float64
+        dtype = np.float32 if fluid.core.is_compiled_with_rocm() else np.float64
         x = layers.data('x', shape, False, dtype)
         y = layers.conv2d(
             input=x,
@@ -316,7 +316,7 @@ class TestConv3DDoubleGradCheck_AsyPadding(unittest.TestCase):
     def func(self, place):
         shape = [2, 2, 2, 2, 2]
         eps = 0.005
-        dtype = np.float64
+        dtype = np.float32 if fluid.core.is_compiled_with_rocm() else np.float64
         x = layers.data('x', shape, False, dtype)
         y = layers.conv3d(
             input=x,
@@ -347,7 +347,7 @@ class TestConv3DoubleGradCheck_PaddingSAME(unittest.TestCase):
     def func(self, place):
         shape = [2, 2, 2, 2, 2]
         eps = 0.005
-        dtype = np.float64
+        dtype = np.float32 if fluid.core.is_compiled_with_rocm() else np.float64
         x = layers.data('x', shape, False, dtype)
         y = layers.conv3d(
             input=x,
@@ -379,7 +379,7 @@ class TestConv3DoubleGradCheck_PaddingVALID(unittest.TestCase):
     def func(self, place):
         shape = [2, 2, 3, 3, 2]
         eps = 0.005
-        dtype = np.float64
+        dtype = np.float32 if fluid.core.is_compiled_with_rocm() else np.float64
         x = layers.data('x', shape, False, dtype)
         y = layers.conv3d(
             input=x,
@@ -410,7 +410,7 @@ class TestConv3DDoubleGradCheck_ChannelLast(unittest.TestCase):
     def func(self, place):
         shape = [2, 2, 2, 2, 3]
         eps = 0.005
-        dtype = np.float64
+        dtype = np.float32 if fluid.core.is_compiled_with_rocm() else np.float64
         x = layers.data('x', shape, False, dtype)
         y = layers.conv3d(
             input=x,
@@ -443,7 +443,7 @@ class TestConv3DDoubleGradCheck_ChannelLast_AsyPadding(unittest.TestCase):
     def func(self, place):
         shape = [2, 2, 2, 2, 3]
         eps = 0.005
-        dtype = np.float64
+        dtype = np.float32 if fluid.core.is_compiled_with_rocm() else np.float64
         x = layers.data('x', shape, False, dtype)
         y = layers.conv3d(
             input=x,
@@ -476,7 +476,7 @@ class TestDepthWiseConvDoubleGradCheck(unittest.TestCase):
     def func(self, place):
         shape = [2, 4, 3, 3]
         eps = 0.005
-        dtype = np.float64
+        dtype = np.float32 if fluid.core.is_compiled_with_rocm() else np.float64
         x = layers.data('x', shape, False, dtype)
         # condition of depthwise conv:
......
@@ -27,8 +27,10 @@ class TestExpandOpRank1(OpTest):
     def setUp(self):
         self.op_type = "expand"
         self.init_data()
+        self.dtype = "float32" if fluid.core.is_compiled_with_rocm(
+        ) else "float64"
-        self.inputs = {'X': np.random.random(self.ori_shape).astype("float64")}
+        self.inputs = {'X': np.random.random(self.ori_shape).astype(self.dtype)}
         self.attrs = {'expand_times': self.expand_times}
         output = np.tile(self.inputs['X'], self.expand_times)
         self.outputs = {'Out': output}
@@ -79,13 +81,16 @@ class TestExpandOpRank1_tensor_attr(OpTest):
     def setUp(self):
         self.op_type = "expand"
         self.init_data()
+        self.dtype = "float32" if fluid.core.is_compiled_with_rocm(
+        ) else "float64"
         expand_times_tensor = []
         for index, ele in enumerate(self.expand_times):
             expand_times_tensor.append(("x" + str(index), np.ones(
                 (1)).astype('int32') * ele))
         self.inputs = {
-            'X': np.random.random(self.ori_shape).astype("float64"),
+            'X': np.random.random(self.ori_shape).astype(self.dtype),
             'expand_times_tensor': expand_times_tensor,
         }
         self.attrs = {"expand_times": self.infer_expand_times}
@@ -123,9 +128,11 @@ class TestExpandOpRank1_tensor(OpTest):
     def setUp(self):
         self.op_type = "expand"
         self.init_data()
+        self.dtype = "float32" if fluid.core.is_compiled_with_rocm(
+        ) else "float64"
         self.inputs = {
-            'X': np.random.random(self.ori_shape).astype("float64"),
+            'X': np.random.random(self.ori_shape).astype(self.dtype),
            'ExpandTimes': np.array(self.expand_times).astype("int32"),
         }
         self.attrs = {}
......
@@ -26,7 +26,10 @@ from paddle.fluid import core
 class TestGPUPackagePaddle(unittest.TestCase):
     def test_import_paddle(self):
         if core.is_compiled_with_cuda():
-            os.environ['CUDA_VISIBLE_DEVICES'] = ''
+            if core.is_compiled_with_rocm():
+                os.environ['HIP_VISIBLE_DEVICES'] = ''
+            else:
+                os.environ['CUDA_VISIBLE_DEVICES'] = ''
             test_file = 'test_no_gpu_run_rand.py'
             with open(test_file, 'w') as wb:
                 cmd_test = """
......
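The device-hiding step above differs per backend; a small sketch of the same logic as a reusable helper (the function name is hypothetical, the environment variables are the ones the test actually sets):

import os
from paddle.fluid import core

def hide_gpus():
    # ROCm builds honor HIP_VISIBLE_DEVICES, CUDA builds honor CUDA_VISIBLE_DEVICES.
    key = 'HIP_VISIBLE_DEVICES' if core.is_compiled_with_rocm() else 'CUDA_VISIBLE_DEVICES'
    os.environ[key] = ''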
@@ -44,8 +44,9 @@ class TestGRUOp(OpTest):
     def setUp(self):
         self.op_type = "rnn"
-        self.dtype = "float64"
-        self.sequence_length = np.array(
-            [12, 11, 10, 9, 8, 7, 6, 5], dtype=np.int32)
+        self.dtype = "float32" if core.is_compiled_with_rocm() else "float64"
+        self.sequence_length = None if core.is_compiled_with_rocm(
+        ) else np.array(
+            [12, 11, 10, 9, 8, 7, 6, 5], dtype=np.int32)
         self.num_layers = 1
         self.is_bidirec = False
@@ -83,6 +84,24 @@ class TestGRUOp(OpTest):
         output, last_hidden = rnn1(input, sequence_length=self.sequence_length)
+        if core.is_compiled_with_rocm():
+            def rocm_rnn_get_place():
+                places = [core.CUDAPlace(0)]
+                return places
+            self._get_places = rocm_rnn_get_place
+            if self.is_bidirec:
+                for i in range(0, len(flat_w), 4):
+                    flat_w[i + 1], flat_w[i + 2] = flat_w[i + 2], flat_w[i + 1]
+            for i in range(len(flat_w)):
+                w = np.split(flat_w[i][1], 3, 0)
+                w = [w[1], w[0], w[2]]
+                w = np.concatenate(w)
+                flat_w[i] = (flat_w[i][0], w)
         init_h = np.zeros((self.num_layers * self.direction_num, batch_size,
                            self.hidden_size)).astype(self.dtype)
......
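In the rnn/GRU test above, each reference parameter is rebuilt for ROCm by splitting it into its three stacked gate blocks and swapping the first two before it is fed to the kernel; the implication is that MIOpen expects those two gates in the opposite order from the layout used to build the reference weights, though the exact gate names are not asserted here. A standalone numpy sketch of that reordering:

# Illustrative only: mirrors the ROCm branch of the test above.
import numpy as np

def reorder_gru_param(param):
    # Split into three per-gate blocks along axis 0 and swap the first two.
    g0, g1, g2 = np.split(param, 3, axis=0)
    return np.concatenate([g1, g0, g2], axis=0)

hidden = 4
w = np.arange(3 * hidden * hidden, dtype=np.float32).reshape(3 * hidden, hidden)
w_rocm = reorder_gru_param(w)  # shape unchanged, only the gate order moves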
@@ -121,12 +121,12 @@ class TestGRUUnitOp(OpTest):
         self.op_type = 'gru_unit'
         self.inputs = {
             'Input': np.random.uniform(
-                -0.1, 0.1, (batch_size, frame_size * 3)).astype('float64'),
+                -0.1, 0.1, (batch_size, frame_size * 3)).astype(self.dtype),
             'HiddenPrev': np.random.uniform(
-                -0.1, 0.1, (batch_size, frame_size)).astype('float64'),
+                -0.1, 0.1, (batch_size, frame_size)).astype(self.dtype),
             'Weight': np.random.uniform(
                 -1. / math.sqrt(frame_size), 1. / math.sqrt(frame_size),
-                (frame_size, frame_size * 3)).astype('float64'),
+                (frame_size, frame_size * 3)).astype(self.dtype),
         }
         self.attrs = {
             'activation': GRUActivationType.tanh,
@@ -161,12 +161,14 @@ class TestGRUUnitOp(OpTest):
         else:
             h = u * c + (1 - u) * h_p
         self.outputs = {
-            'Gate': g.astype('float64'),
-            'ResetHiddenPrev': r_h_p.astype('float64'),
-            'Hidden': h.astype('float64')
+            'Gate': g.astype(self.dtype),
+            'ResetHiddenPrev': r_h_p.astype(self.dtype),
+            'Hidden': h.astype(self.dtype)
         }
     def setUp(self):
+        self.dtype = 'float32' if fluid.core.is_compiled_with_rocm(
+        ) else 'float64'
         self.set_inputs()
         self.set_outputs()
@@ -179,6 +181,8 @@ class TestGRUUnitOp(OpTest):
 class TestGRUUnitOpOriginMode(TestGRUUnitOp):
     def setUp(self):
+        self.dtype = 'float32' if fluid.core.is_compiled_with_rocm(
+        ) else 'float64'
         self.set_inputs(origin_mode=True)
         self.set_outputs(origin_mode=True)
@@ -189,7 +193,7 @@ class TestGRUUnitOpWithBias(TestGRUUnitOp):
         frame_size = self.frame_size
         super(TestGRUUnitOpWithBias, self).set_inputs()
         self.inputs['Bias'] = np.random.uniform(
-            -0.1, 0.1, (1, frame_size * 3)).astype('float64')
+            -0.1, 0.1, (1, frame_size * 3)).astype(self.dtype)
         self.attrs = {
             'activation': GRUActivationType.identity,
             'gate_activation': GRUActivationType.sigmoid,
@@ -207,6 +211,8 @@ class TestGRUUnitOpWithBias(TestGRUUnitOp):
 class TestGRUUnitOpWithBiasOriginMode(TestGRUUnitOpWithBias):
     def setUp(self):
+        self.dtype = 'float32' if fluid.core.is_compiled_with_rocm(
+        ) else 'float64'
         self.set_inputs(origin_mode=True)
         self.set_outputs(origin_mode=True)
......
@@ -190,10 +190,18 @@ class TestImperativeOptimizerBase(unittest.TestCase):
         for key, value in six.iteritems(static_param_init_value):
             self.assertTrue(np.allclose(value, dy_param_init_value[key]))
-        self.assertTrue(np.allclose(static_out, dy_out))
+        if core.is_compiled_with_rocm():
+            self.assertTrue(np.allclose(static_out, dy_out, atol=1e-3))
+        else:
+            self.assertTrue(np.allclose(static_out, dy_out))
         for key, value in six.iteritems(static_param_value):
-            self.assertTrue(np.allclose(value, dy_param_value[key]))
+            if core.is_compiled_with_rocm():
+                self.assertTrue(
+                    np.allclose(
+                        value, dy_param_value[key], atol=1e-3))
+            else:
+                self.assertTrue(np.allclose(value, dy_param_value[key]))
 class TestImperativeOptimizerPiecewiseDecay(TestImperativeOptimizerBase):
......
@@ -207,10 +207,18 @@ class TestImperativeOptimizerBase(unittest.TestCase):
         for key, value in six.iteritems(static_param_init_value):
             self.assertTrue(np.allclose(value, dy_param_init_value[key]))
-        self.assertTrue(np.allclose(static_out, dy_out))
+        if core.is_compiled_with_rocm():
+            self.assertTrue(np.allclose(static_out, dy_out, atol=1e-3))
+        else:
+            self.assertTrue(np.allclose(static_out, dy_out))
         for key, value in six.iteritems(static_param_value):
-            self.assertTrue(np.allclose(value, dy_param_value[key]))
+            if core.is_compiled_with_rocm():
+                self.assertTrue(
+                    np.allclose(
+                        value, dy_param_value[key], atol=1e-3))
+            else:
+                self.assertTrue(np.allclose(value, dy_param_value[key]))
 class TestImperativeOptimizerPiecewiseDecay(TestImperativeOptimizerBase):
......
@@ -390,8 +390,10 @@ class TestCUDNNLstmOp(OpTest):
     def setUp(self):
         self.op_type = "cudnn_lstm"
-        self.dtype = np.float64
-        self.sequence_length = np.array([12, 11, 10, 9, 8], dtype=np.int32)
+        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
+        self.sequence_length = None if core.is_compiled_with_rocm(
+        ) else np.array(
+            [12, 11, 10, 9, 8], dtype=np.int32)
         self.num_layers = 1
         self.set_attrs()
@@ -447,6 +449,13 @@ class TestCUDNNLstmOp(OpTest):
                            hidden_size)).astype(self.dtype)
         state_out = np.ndarray((300)).astype("uint8")
+        if core.is_compiled_with_rocm():
+            for i in range(len(flat_w)):
+                w = np.split(flat_w[i][1], 4, 0)
+                w = [w[0], w[1], w[3], w[2]]
+                w = np.concatenate(w)
+                flat_w[i] = (flat_w[i][0], w)
         self.inputs = {
             'Input': input,
             'WeightList': flat_w,
@@ -454,6 +463,13 @@ class TestCUDNNLstmOp(OpTest):
             'InitC': init_c,
             'SequenceLength': self.sequence_length
         }
+        if self.sequence_length is None:
+            self.inputs = {
+                'Input': input,
+                'WeightList': flat_w,
+                'InitH': init_h,
+                'InitC': init_c,
+            }
         self.attrs = {
             'dropout_prob': 0.0,
             'is_bidirec': False,
@@ -474,8 +490,12 @@ class TestCUDNNLstmOp(OpTest):
     def test_output_with_place(self):
         place = core.CUDAPlace(0)
-        self.check_output_with_place(
-            place, no_check_set=['Reserve', 'StateOut'])
+        if core.is_compiled_with_rocm():
+            self.check_output_with_place(
+                place, atol=1e-5, no_check_set=['Reserve', 'StateOut'])
+        else:
+            self.check_output_with_place(
+                place, no_check_set=['Reserve', 'StateOut'])
     def test_grad_with_place(self):
         place = core.CUDAPlace(0)
@@ -496,14 +516,13 @@ class TestCUDNNlstmAPI(unittest.TestCase):
         hidden_size = 20
         dropout_prob = 0.0
         num_layers = 1
+        dtype = 'float32' if core.is_compiled_with_rocm() else 'float64'
         input = fluid.data(
-            name='input',
-            shape=[seq_len, batch_size, hidden_size],
-            dtype='float64')
+            name='input', shape=[seq_len, batch_size, hidden_size], dtype=dtype)
         init_h = layers.fill_constant([num_layers, batch_size, hidden_size],
-                                      'float64', 0.0)
+                                      dtype, 0.0)
         init_c = layers.fill_constant([num_layers, batch_size, hidden_size],
-                                      'float64', 0.0)
+                                      dtype, 0.0)
         rnn_out, last_h, last_c = layers.lstm(input, init_h, init_c, seq_len,
                                               hidden_size, num_layers,
                                               dropout_prob, False)
@@ -526,14 +545,13 @@ class TestCUDNNlstmAPI(unittest.TestCase):
         hidden_size = 20
         dropout_prob = 0.0
         num_layers = 2
+        dtype = 'float32' if core.is_compiled_with_rocm() else 'float64'
         input = fluid.data(
-            name='input',
-            shape=[seq_len, batch_size, hidden_size],
-            dtype='float64')
+            name='input', shape=[seq_len, batch_size, hidden_size], dtype=dtype)
         init_h = layers.fill_constant([num_layers, batch_size, hidden_size],
-                                      'float64', 0.0)
+                                      dtype, 0.0)
         init_c = layers.fill_constant([num_layers, batch_size, hidden_size],
-                                      'float64', 0.0)
+                                      dtype, 0.0)
         rnn_out, last_h, last_c = layers.lstm(input, init_h, init_c, seq_len,
                                               hidden_size, num_layers,
                                               dropout_prob, False, True)
@@ -541,7 +559,7 @@ class TestCUDNNlstmAPI(unittest.TestCase):
         exe.run(fluid.default_startup_program())
         input_i = np.random.uniform(
             low=-0.1, high=0.1, size=(seq_len, batch_size,
-                                      hidden_size)).astype("float64")
+                                      hidden_size)).astype(dtype)
         out = exe.run(fluid.default_main_program(),
                       feed={'input': input_i},
                       fetch_list=[rnn_out, last_h, last_c, 'cudnn_lstm_0.w_0'])
......
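The LSTM test applies the analogous transformation with four gate blocks, swapping the last two before handing the weights to the ROCm kernel; which gates these blocks correspond to is not asserted here. A compact numpy sketch:

# Illustrative only: mirrors the ROCm branch of the LSTM test above.
import numpy as np

def reorder_lstm_param(param):
    # Split a stacked LSTM parameter into its four gate blocks and swap the last two.
    b0, b1, b2, b3 = np.split(param, 4, axis=0)
    return np.concatenate([b0, b1, b3, b2], axis=0)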
@@ -224,7 +224,7 @@ class TestPool3D_Op(OpTest):
     def setUp(self):
         self.op_type = "pool3d"
         self.init_kernel_type()
-        self.dtype = np.float64
+        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
         self.init_test_case()
         self.padding_algorithm = "EXPLICIT"
         self.init_paddings()
@@ -277,9 +277,16 @@ class TestPool3D_Op(OpTest):
             return
         if self.has_cudnn() and self.pool_type != "max":
             place = core.CUDAPlace(0)
-            self.check_grad_with_place(place, set(['X']), 'Out')
+            if core.is_compiled_with_rocm():
+                self.check_grad_with_place(
+                    place, set(['X']), 'Out', max_relative_error=1e-2)
+            else:
+                self.check_grad_with_place(place, set(['X']), 'Out')
         elif self.pool_type != "max":
-            self.check_grad(set(['X']), 'Out')
+            if core.is_compiled_with_rocm():
+                self.check_grad(set(['X']), 'Out', max_relative_error=1e-2)
+            else:
+                self.check_grad(set(['X']), 'Out')
     def init_data_format(self):
         self.data_format = "NCDHW"
@@ -400,7 +407,10 @@ def create_test_cudnn_fp16_class(parent):
             if core.is_compiled_with_cuda():
                 place = core.CUDAPlace(0)
                 if core.is_float16_supported(place):
-                    self.check_output_with_place(place, atol=1e-3)
+                    if core.is_compiled_with_rocm():
+                        self.check_output_with_place(place, atol=1e-2)
+                    else:
+                        self.check_output_with_place(place, atol=1e-3)
         cls_name = "{0}_{1}".format(parent.__name__, "CUDNNFp16Op")
         TestCUDNNFp16Case.__name__ = cls_name
......
@@ -1053,7 +1053,8 @@ class RNNBase(LayerList):
             initial_states,
             paddle.fluid.framework.Variable) else initial_states
-        if self.could_use_cudnn:
+        if self.could_use_cudnn and (not fluid.core.is_compiled_with_rocm() or
+                                     sequence_length is None):
             # Add CPU kernel and dispatch in backend later
             return self._cudnn_impl(inputs, initial_states, sequence_length)
......
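The RNNBase change above means that on a ROCm build, any cuDNN-capable RNN layer that is given a sequence_length now falls back to the composed-op implementation instead of the MIOpen kernel; without sequence_length the fast path is still eligible. A hedged usage sketch, assuming a Paddle 2.x GPU build:

import numpy as np
import paddle

lstm = paddle.nn.LSTM(input_size=16, hidden_size=32)
x = paddle.randn([4, 10, 16])                       # [batch, time, features]
seq_len = paddle.to_tensor(np.array([10, 9, 8, 7], dtype="int64"))
out, _ = lstm(x, sequence_length=seq_len)           # composed-op path on ROCm
out_fast, _ = lstm(x)                                # eligible for the MIOpen path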