未验证 提交 910f377f 编写于 作者: F furnace 提交者: GitHub

Bugfix rocm (#31490)

* bugfix for test_cholesky_op

* bugfix for test_compare_op

* bugfix for lookup_table_op

* bugfix for affine_channel_op
上级 416e47ed
......@@ -71,7 +71,11 @@ class AffineChannelCUDAKernel : public framework::OpKernel<T> {
const T* bias_d = bias->data<T>();
T* y_d = y->data<T>();
#ifdef PADDLE_WITH_HIP
int block = 256;
#else
int block = 1024;
#endif // PADDLE_WITH_HIP
int grid = (num + block - 1) / block;
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
......@@ -153,7 +157,11 @@ class AffineChannelGradCUDAKernel : public framework::OpKernel<T> {
T* ds_d = dscale ? dscale->mutable_data<T>(ctx.GetPlace()) : nullptr;
T* db_d = dbias ? dbias->mutable_data<T>(ctx.GetPlace()) : nullptr;
#ifdef PADDLE_WITH_HIP
const int block = 256;
#else
const int block = 1024;
#endif // PADDLE_WITH_HIP
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
const int max_blocks = std::max(max_threads / block, 1);
int grid1 = (num + block - 1) / block;
......
......@@ -105,9 +105,24 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
auto *table = table_t->data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
#ifdef PADDLE_WITH_HIP
dim3 threads(64, 4);
#else
dim3 threads(128, 8);
#endif // PADDLE_WITH_HIP
dim3 grids(8, 1);
#ifdef PADDLE_WITH_HIP
if (padding_idx == -1)
LookupTable<
T, 64, 4, 8,
false><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids, N, K, D, padding_idx);
else
LookupTable<
T, 64, 4, 8,
true><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids, N, K, D, padding_idx);
#else
if (padding_idx == -1)
LookupTable<
T, 128, 8, 8,
......@@ -118,6 +133,7 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
T, 128, 8, 8,
true><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids, N, K, D, padding_idx);
#endif // PADDLE_WITH_HIP
}
};
......@@ -185,10 +201,20 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
auto t = framework::EigenVector<T>::Flatten(*d_table_t);
t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));
#ifdef PADDLE_WITH_HIP
dim3 threads(64, 4);
#else
dim3 threads(128, 8);
#endif // PADDLE_WITH_HIP
dim3 grids(8, 1);
#ifdef PADDLE_WITH_HIP
LookupTableGrad<T, 64, 4, 8><<<grids, threads, 0, dev_ctx.stream()>>>(
d_table, d_output, ids, N, K, D);
#else
LookupTableGrad<T, 128, 8, 8><<<grids, threads, 0, dev_ctx.stream()>>>(
d_table, d_output, ids, N, K, D);
#endif // PADDLE_WITH_HIP
}
}
};
......
......@@ -58,7 +58,7 @@ class TestCholeskyOp(OpTest):
def test_check_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
if core.is_compiled_with_cuda() and (not core.is_compiled_with_rocm()):
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
......@@ -92,7 +92,10 @@ class TestCholeskyOp2D(TestCholeskyOp):
class TestDygraph(unittest.TestCase):
def test_dygraph(self):
paddle.disable_static()
if core.is_compiled_with_rocm():
paddle.disable_static(place=fluid.CPUPlace())
else:
paddle.disable_static()
a = np.random.rand(3, 3)
a_t = np.transpose(a, [1, 0])
x_data = np.matmul(a, a_t) + 1e-03
......@@ -103,7 +106,7 @@ class TestDygraph(unittest.TestCase):
class TestCholeskySingularAPI(unittest.TestCase):
def setUp(self):
self.places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
if core.is_compiled_with_cuda() and (not core.is_compiled_with_rocm()):
self.places.append(fluid.CUDAPlace(0))
def check_static_result(self, place, with_out=False):
......
......@@ -61,6 +61,9 @@ def create_test_class(op_type, typename, callback):
for _type_name in {'float32', 'float64', 'int32', 'int64'}:
if _type_name == 'float64' and core.is_compiled_with_rocm():
_type_name = 'float32'
create_test_class('less_than', _type_name, lambda _a, _b: _a < _b)
create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b)
create_test_class('greater_than', _type_name, lambda _a, _b: _a > _b)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册