Unverified commit 910f377f, authored by furnace, committed by GitHub

Bugfix rocm (#31490)

* bugfix for test_cholesky_op

* bugfix for test_compare_op

* bugfix for lookup_table_op

* bugfix for affine_channel_op
Parent 416e47ed
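All four fixes share one pattern: HIP (ROCm) builds get a smaller kernel launch configuration than CUDA builds, or skip a GPU-only code path in the tests. The commit message does not say why; a plausible reason, stated here as an assumption, is that AMD GPUs execute 64-thread wavefronts rather than 32-thread warps and these kernels misbehave on ROCm at 1024 threads per block. A minimal sketch of the #ifdef PADDLE_WITH_HIP launch-config pattern follows; PADDLE_WITH_HIP is the macro Paddle's build system defines for ROCm builds, while ScaleKernel and LaunchScale are illustrative names, not code from this patch:

// Sketch of the platform-dependent launch configuration used in this patch.
__global__ void ScaleKernel(const float* x, float* y, int num) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < num) y[i] = 2.0f * x[i];
}

#ifdef PADDLE_WITH_HIP
constexpr int kBlock = 256;   // reduced block size on ROCm, as in this patch
#else
constexpr int kBlock = 1024;  // original CUDA value
#endif

void LaunchScale(const float* x, float* y, int num) {
  int grid = (num + kBlock - 1) / kBlock;  // ceiling division over elements
  ScaleKernel<<<grid, kBlock>>>(x, y, num);
}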
@@ -71,7 +71,11 @@ class AffineChannelCUDAKernel : public framework::OpKernel<T> {
     const T* bias_d = bias->data<T>();
     T* y_d = y->data<T>();
+#ifdef PADDLE_WITH_HIP
+    int block = 256;
+#else
     int block = 1024;
+#endif  // PADDLE_WITH_HIP
     int grid = (num + block - 1) / block;
     int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
@@ -153,7 +157,11 @@ class AffineChannelGradCUDAKernel : public framework::OpKernel<T> {
     T* ds_d = dscale ? dscale->mutable_data<T>(ctx.GetPlace()) : nullptr;
     T* db_d = dbias ? dbias->mutable_data<T>(ctx.GetPlace()) : nullptr;
+#ifdef PADDLE_WITH_HIP
+    const int block = 256;
+#else
     const int block = 1024;
+#endif  // PADDLE_WITH_HIP
     int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
     const int max_blocks = std::max(max_threads / block, 1);
     int grid1 = (num + block - 1) / block;
......
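In the gradient kernel above, block also feeds the grid clamp, so shrinking it from 1024 to 256 multiplies max_blocks by four on the same device. A worked example of the arithmetic; the max_threads and num values are made up for illustration (the real max_threads comes from GetMaxPhysicalThreadCount() at runtime):

#include <algorithm>
#include <cstdio>

int main() {
  const int max_threads = 2048;  // illustrative physical thread count
  const int num = 100000;        // illustrative element count
#ifdef PADDLE_WITH_HIP
  const int block = 256;         // ROCm value from this patch
#else
  const int block = 1024;        // CUDA value
#endif
  const int max_blocks = std::max(max_threads / block, 1);  // 8 vs 2
  const int grid1 = (num + block - 1) / block;              // 391 vs 98
  std::printf("block=%d max_blocks=%d grid1=%d\n", block, max_blocks, grid1);
  return 0;
}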
@@ -105,9 +105,24 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
     auto *table = table_t->data<T>();
     auto *output = output_t->mutable_data<T>(context.GetPlace());
+#ifdef PADDLE_WITH_HIP
+    dim3 threads(64, 4);
+#else
     dim3 threads(128, 8);
+#endif  // PADDLE_WITH_HIP
     dim3 grids(8, 1);
+#ifdef PADDLE_WITH_HIP
+    if (padding_idx == -1)
+      LookupTable<
+          T, 64, 4, 8,
+          false><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
+          output, table, ids, N, K, D, padding_idx);
+    else
+      LookupTable<
+          T, 64, 4, 8,
+          true><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
+          output, table, ids, N, K, D, padding_idx);
+#else
     if (padding_idx == -1)
       LookupTable<
           T, 128, 8, 8,
@@ -118,6 +133,7 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
           T, 128, 8, 8,
           true><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
           output, table, ids, N, K, D, padding_idx);
+#endif  // PADDLE_WITH_HIP
   }
 };
@@ -185,10 +201,20 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
       auto t = framework::EigenVector<T>::Flatten(*d_table_t);
       t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));
+#ifdef PADDLE_WITH_HIP
+      dim3 threads(64, 4);
+#else
       dim3 threads(128, 8);
+#endif  // PADDLE_WITH_HIP
       dim3 grids(8, 1);
+#ifdef PADDLE_WITH_HIP
+      LookupTableGrad<T, 64, 4, 8><<<grids, threads, 0, dev_ctx.stream()>>>(
+          d_table, d_output, ids, N, K, D);
+#else
       LookupTableGrad<T, 128, 8, 8><<<grids, threads, 0, dev_ctx.stream()>>>(
           d_table, d_output, ids, N, K, D);
+#endif  // PADDLE_WITH_HIP
     }
   }
 };
......
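The lookup_table changes duplicate the entire launch instead of only the dim3 values because the block and grid shape is also baked into LookupTable and LookupTableGrad as template parameters, and those compile-time constants must match the runtime launch. Note that 64 x 4 = 256 and 128 x 8 = 1024 threads per block, consistent with the affine_channel change. A hedged sketch of that coupling; TileCopy and Launch are illustrative stand-ins, not Paddle's kernels:

// The template constants mirror the launch shape so the kernel can use
// them as compile-time strides; they must agree with the dim3 arguments.
template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void TileCopy(T* out, const T* in, int n) {
  int stride = BlockDimX * BlockDimY * GridDimX;
  int idx = threadIdx.x + threadIdx.y * BlockDimX +
            blockIdx.x * BlockDimX * BlockDimY;
  for (int i = idx; i < n; i += stride) out[i] = in[i];  // grid-stride loop
}

template <typename T>
void Launch(T* out, const T* in, int n) {
#ifdef PADDLE_WITH_HIP
  dim3 threads(64, 4);   // 256 threads/block; must match <T, 64, 4, 8>
  TileCopy<T, 64, 4, 8><<<dim3(8, 1), threads>>>(out, in, n);
#else
  dim3 threads(128, 8);  // 1024 threads/block; must match <T, 128, 8, 8>
  TileCopy<T, 128, 8, 8><<<dim3(8, 1), threads>>>(out, in, n);
#endif
}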
@@ -58,7 +58,7 @@ class TestCholeskyOp(OpTest):
     def test_check_grad(self):
         places = [fluid.CPUPlace()]
-        if core.is_compiled_with_cuda():
+        if core.is_compiled_with_cuda() and (not core.is_compiled_with_rocm()):
             places.append(fluid.CUDAPlace(0))
         for p in places:
             self.func(p)
@@ -92,6 +92,9 @@ class TestCholeskyOp2D(TestCholeskyOp):
 class TestDygraph(unittest.TestCase):
     def test_dygraph(self):
-        paddle.disable_static()
+        if core.is_compiled_with_rocm():
+            paddle.disable_static(place=fluid.CPUPlace())
+        else:
+            paddle.disable_static()
         a = np.random.rand(3, 3)
         a_t = np.transpose(a, [1, 0])
@@ -103,7 +106,7 @@ class TestDygraph(unittest.TestCase):
 class TestCholeskySingularAPI(unittest.TestCase):
     def setUp(self):
         self.places = [fluid.CPUPlace()]
-        if core.is_compiled_with_cuda():
+        if core.is_compiled_with_cuda() and (not core.is_compiled_with_rocm()):
             self.places.append(fluid.CUDAPlace(0))

     def check_static_result(self, place, with_out=False):
......
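On the test side, the guard keeps CUDAPlace out of the sweep on ROCm builds. The diff itself implies that core.is_compiled_with_cuda() also returns True under ROCm, since the extra check would otherwise be redundant; presumably the cuSOLVER-backed cholesky path has no working ROCm counterpart here, though the commit does not say so. A standalone sketch of the guard; the helper name is illustrative, not from the test file:

import paddle.fluid as fluid
from paddle.fluid import core

def test_places():
    places = [fluid.CPUPlace()]
    # ROCm builds also report is_compiled_with_cuda(), so exclude them
    # explicitly to keep the GPU place out of the test sweep.
    if core.is_compiled_with_cuda() and not core.is_compiled_with_rocm():
        places.append(fluid.CUDAPlace(0))
    return places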
@@ -61,6 +61,9 @@ def create_test_class(op_type, typename, callback):
 for _type_name in {'float32', 'float64', 'int32', 'int64'}:
+    if _type_name == 'float64' and core.is_compiled_with_rocm():
+        _type_name = 'float32'
+
     create_test_class('less_than', _type_name, lambda _a, _b: _a < _b)
     create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b)
     create_test_class('greater_than', _type_name, lambda _a, _b: _a > _b)
......
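A note on the compare_op change, offered as an inference rather than anything stated in the commit: rebinding _type_name inside the loop means the float64 iteration on ROCm calls create_test_class with 'float32' again, re-registering classes that already exist. The net effect is that float64 comparison coverage is simply dropped on ROCm, with the float32 classes created twice and the second definition overwriting the first.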