diff --git a/paddle/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu
index e6fb4b97d66e153ea38d50e0f12dab95ecddda75..513e3d035c130e75551c06dd8bf866b830e355b5 100644
--- a/paddle/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu
+++ b/paddle/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu
@@ -91,9 +91,9 @@ struct NoMaskBwFunctor {
 template <typename T, typename Functor>
 __global__ void VectorizedDropoutBackward(const size_t n,
                                           uint64_t seed,
-                                          T* src,
-                                          T* res,
-                                          const T* dst,
+                                          T* x,
+                                          T* y,
+                                          const T* out_grad,
                                           uint64_t increment,
                                           size_t main_offset,
                                           Functor functor) {
@@ -112,44 +112,38 @@ __global__ void VectorizedDropoutBackward(const size_t n,
 #endif
 
   float rands[kCount];
-  T src_res[kCount * 2];
-  T res_grad[kCount];
+  T x_y[kCount * 2];
 
   using Rand = phi::funcs::uniform_distribution<float>;
   using Cast = kps::IdentityFunctor<T>;
   int deal_size = BLOCK_NUM_X * kCount;
   size_t fix = idx * kCount;
+
   for (; fix < main_offset; fix += stride) {
-    kps::ReadData<T, kCount, 1, false>(&src_res[0], dst, deal_size);
+    kps::ReadData<T, kCount, 1, false>(&x_y[0], out_grad + fix, deal_size);
     kps::ElementwiseRandom<SType, float, kCount, Rand>(
         &rands[0], Rand(), &state);
-    // x_grad
     kps::OperatorTernary<T, float, T, Functor>(
-        &src_res[0], &src_res[0], &rands[0], functor, kCount);
-    kps::WriteData<T, kCount, 1, false>(src + fix, &src_res[0], deal_size);
-    // res
-    kps::ElementwiseUnary<T, T, kCount, 1, Cast>(
-        &res_grad[0], &src_res[kCount], Cast());
-    kps::WriteData<T, kCount, 1, false>(res + fix, &res_grad[0], deal_size);
+        &x_y[0], &x_y[0], &rands[0], functor, kCount);
+
+    kps::WriteData<T, kCount, 1, false>(x + fix, &x_y[0], deal_size);
+    kps::WriteData<T, kCount, 1, false>(y + fix, &x_y[kCount], deal_size);
     if (fix > idx * kCount + 1) {
       __syncthreads();
     }
   }
+
   int remainder = n - fix;
   if (remainder > 0) {
-    kps::ReadData<T, kCount, 1, true>(&src_res[0], dst + fix, remainder);
+    kps::ReadData<T, kCount, 1, true>(&x_y[0], out_grad + fix, remainder);
     kps::ElementwiseRandom<SType, float, kCount, Rand>(
         &rands[0], Rand(), &state);
-    // x_grad
     kps::OperatorTernary<T, float, T, Functor>(
-        &src_res[0], &src_res[0], &rands[0], functor, kCount);
-    kps::WriteData<T, kCount, 1, true>(src + fix, &src_res[0], remainder);
-    // res
-    kps::ElementwiseUnary<T, T, kCount, 1, Cast>(
-        &res_grad[0], &src_res[kCount], Cast());
-    kps::WriteData<T, kCount, 1, true>(res + fix, &res_grad[0], remainder);
+        &x_y[0], &x_y[0], &rands[0], functor, kCount);
+    kps::WriteData<T, kCount, 1, true>(x + fix, &x_y[0], remainder);
+    kps::WriteData<T, kCount, 1, true>(y + fix, &x_y[kCount], remainder);
     __syncthreads();
   }
 }
@@ -201,7 +195,6 @@ void FusedDropoutAddGradKernel(const Context& dev_ctx,
   size_t block_size = random_prop[1];
   size_t offset = random_prop[2];
   size_t main_offset = random_prop[3];
-  auto functor = upscale_in_train ? NoMaskBwFunctor<T, float>(1.0f - dropout_rate) : NoMaskBwFunctor<T, float>(1.0f - dropout_rate, 1.0f);
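Note on the kernel hunks above: besides renaming `src`/`res`/`dst` to `x`/`y`/`out_grad`, the old main loop read `dst` without the `fix` offset, so every vectorized iteration re-read the start of `out_grad`; the new code reads `out_grad + fix` and writes `y`'s gradient straight from the same registers instead of routing it through an extra `ElementwiseUnary` pass. As a reading aid, here is a minimal NumPy sketch of the gradient semantics the kernel implements; the helper name and the explicit `keep_mask` argument are illustrative assumptions (the real kernel regenerates the mask from the saved Philox seed/offset rather than storing it):

```python
import numpy as np

def fused_dropout_add_grad_reference(out_grad, keep_mask, p, upscale_in_train=True):
    # Hypothetical reference, not Paddle API: gradients of
    # out = dropout(x, p) + y with respect to x and y.
    if upscale_in_train:
        # The forward pass scales kept elements by 1 / (1 - p),
        # so the gradient carries the same factor.
        x_grad = out_grad * keep_mask / (1.0 - p)
    else:
        # Mask-only mode, matching NoMaskBwFunctor(1 - p, 1.0f):
        # apply the mask without rescaling.
        x_grad = out_grad * keep_mask
    y_grad = out_grad  # the add branch is an identity for gradients
    return x_grad, y_grad
```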
diff --git a/python/paddle/fluid/tests/unittests/test_fused_dropout_add_op.py b/python/paddle/fluid/tests/unittests/test_fused_dropout_add_op.py
index 562765c1c2652287ebba30da327b1eefa13a85dc..792f0b2e877d088057701259aa8e63c5d7518ca9 100644
--- a/python/paddle/fluid/tests/unittests/test_fused_dropout_add_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fused_dropout_add_op.py
@@ -34,9 +34,9 @@ def paddle_dropout_add(x, y, p=0.5, training=True, mode="upscale_in_train"):
 )
 class TestFusedDropoutAdd(unittest.TestCase):
     def setUp(self):
-        self.shape = (2, 10, 10, 2)
-        self.dtype = 'float64'
-        self.dropout_rate = 0.9
+        self.shape = [2, 1024, 2, 1]
+        self.dtype = 'float16'
+        self.dropout_rate = 0.5
         self.training = True
         self.mode = "upscale_in_train"
         self.seed = 1027
@@ -66,9 +66,8 @@ class TestFusedDropoutAdd(unittest.TestCase):
                 mode=self.mode,
             )
             fw.append(out)
-
-            loss = paddle.mean(out)
-            loss.backward()
+            out_g = paddle.randn(self.shape, self.dtype)
+            paddle.autograd.backward([out], [out_g], True)
             for i in range(count):
                 bw.append(data[i].grad)
         return fw, bw
@@ -95,7 +94,7 @@ def create_test_class(parent, dtype, mode, training, p, seed):
     )
     class TestFusedDropoutAddCase(parent):
         def setUp(self):
-            self.shape = (2, 10, 10, 2)
+            self.shape = (2, 1024, 1, 1)
             self.dtype = dtype
             self.dropout_rate = p
             self.training = training
@@ -168,7 +167,7 @@ class TestFusedDropoutAddStatic(unittest.TestCase):
         y = paddle.randn(self.shape, self.dtype)
         fused_d_a = FusedDropoutAdd(p=0.5)
         d = paddle.nn.Dropout(p=0.5)
-        print(d)
+        print(d.extra_repr())
         paddle.seed(2048)
         fused_out = fused_d_a(x, y)
         paddle.seed(2048)
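For reviewers reproducing the test change: back-propagating a random cotangent via `paddle.autograd.backward([out], [out_g], True)` is a strictly stronger check than the old `paddle.mean(out).backward()`, whose uniform `out_grad` would not catch the missing `+ fix` offset above (every element of `out_grad` held the same value, so reading the wrong offset returned the right number). A standalone sketch of the new pattern, assuming a CUDA build of Paddle with the incubate fused op available; shapes and dtype mirror the updated `setUp`:

```python
import paddle
from paddle.incubate.nn.functional import fused_dropout_add

paddle.seed(1027)
x = paddle.randn([2, 1024, 2, 1], 'float16')
x.stop_gradient = False
y = paddle.randn([2, 1024, 2, 1], 'float16')
y.stop_gradient = False

out = fused_dropout_add(x, y, p=0.5, training=True)

# Feed a non-uniform out_grad so x.grad and y.grad depend on every
# element; retain_graph=True is passed positionally, as in the test.
out_g = paddle.randn([2, 1024, 2, 1], 'float16')
paddle.autograd.backward([out], [out_g], True)

print(x.grad.shape, y.grad.shape)
```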