Unverified · Commit 8082ba8a authored by ShenLiang, committed by GitHub

[BugFix] fix compute error in fused_dropout_add (#52261)

* fix bug

* add unit test
Parent 73df2b1e
Backward kernel: the pointer arguments of `VectorizedDropoutBackward` are renamed to say what they actually are — the two gradient outputs `x`/`y` and the incoming `out_grad` — replacing the misleading `src`/`res`/`dst`.

```diff
@@ -91,9 +91,9 @@ struct NoMaskBwFunctor {
 template <typename T, typename Functor>
 __global__ void VectorizedDropoutBackward(const size_t n,
                                           uint64_t seed,
-                                          T* src,
-                                          T* res,
-                                          const T* dst,
+                                          T* x,
+                                          T* y,
+                                          const T* out_grad,
                                           uint64_t increment,
                                           size_t main_offset,
                                           Functor functor) {
```
```diff
@@ -112,44 +112,38 @@ __global__ void VectorizedDropoutBackward(const size_t n,
 #endif
   float rands[kCount];
-  T src_res[kCount * 2];
-  T res_grad[kCount];
 
   using Rand = phi::funcs::uniform_distribution<float>;
   using Cast = kps::IdentityFunctor<T>;
+  T x_y[kCount * 2];
 
   int deal_size = BLOCK_NUM_X * kCount;
   size_t fix = idx * kCount;
 
   for (; fix < main_offset; fix += stride) {
-    kps::ReadData<T, kCount, 1, false>(&src_res[0], dst, deal_size);
+    kps::ReadData<T, kCount, 1, false>(&x_y[0], out_grad + fix, deal_size);
     kps::ElementwiseRandom<SType, float, kCount, Rand>(
         &rands[0], Rand(), &state);
-    // x_grad
     kps::OperatorTernary<T, float, T, Functor>(
-        &src_res[0], &src_res[0], &rands[0], functor, kCount);
-    kps::WriteData<T, kCount, 1, false>(src + fix, &src_res[0], deal_size);
-    // res
-    kps::ElementwiseUnary<T, T, kCount, 1, Cast>(
-        &res_grad[0], &src_res[kCount], Cast());
-    kps::WriteData<T, kCount, 1, false>(res + fix, &res_grad[0], deal_size);
+        &x_y[0], &x_y[0], &rands[0], functor, kCount);
+
+    kps::WriteData<T, kCount, 1, false>(x + fix, &x_y[0], deal_size);
+    kps::WriteData<T, kCount, 1, false>(y + fix, &x_y[kCount], deal_size);
     if (fix > idx * kCount + 1) {
       __syncthreads();
     }
   }
 
   int remainder = n - fix;
   if (remainder > 0) {
-    kps::ReadData<T, kCount, 1, true>(&src_res[0], dst + fix, remainder);
+    kps::ReadData<T, kCount, 1, true>(&x_y[0], out_grad + fix, remainder);
     kps::ElementwiseRandom<SType, float, kCount, Rand>(
         &rands[0], Rand(), &state);
-    // x_grad
     kps::OperatorTernary<T, float, T, Functor>(
-        &src_res[0], &src_res[0], &rands[0], functor, kCount);
-    kps::WriteData<T, kCount, 1, true>(src + fix, &src_res[0], remainder);
-    // res
-    kps::ElementwiseUnary<T, T, kCount, 1, Cast>(
-        &res_grad[0], &src_res[kCount], Cast());
-    kps::WriteData<T, kCount, 1, true>(res + fix, &res_grad[0], remainder);
+        &x_y[0], &x_y[0], &rands[0], functor, kCount);
+
+    kps::WriteData<T, kCount, 1, true>(x + fix, &x_y[0], remainder);
+    kps::WriteData<T, kCount, 1, true>(y + fix, &x_y[kCount], remainder);
     __syncthreads();
   }
 }
```
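The compute error is in the main-loop read: the old kernel read from `dst` with no offset, so every grid-stride iteration reloaded the same first chunk of the upstream gradient instead of its own slice; the fix reads from `out_grad + fix`. A minimal Python sketch of the pattern (single thread, `k_count` standing in for the kernel's per-iteration chunk; the dropout functor is elided):

```python
import numpy as np

def grid_stride_copy(out_grad, k_count, main_offset, buggy=False):
    # Mirrors the kernel's `for (; fix < main_offset; fix += stride)` loop
    # for a single thread whose stride equals its chunk size.
    x_grad = np.empty_like(out_grad)
    fix = 0
    while fix < main_offset:
        # Old code read from a fixed base (`dst`); the fix reads at
        # `out_grad + fix`, so each iteration consumes its own chunk.
        base = 0 if buggy else fix
        x_grad[fix:fix + k_count] = out_grad[base:base + k_count]
        fix += k_count
    return x_grad

g = np.arange(8.0)
print(grid_stride_copy(g, 2, 8))              # [0. 1. 2. 3. 4. 5. 6. 7.]
print(grid_stride_copy(g, 2, 8, buggy=True))  # [0. 1. 0. 1. 0. 1. 0. 1.]
```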
The last kernel hunk drops a blank line before the functor selection in `FusedDropoutAddGradKernel`:

```diff
@@ -201,7 +195,6 @@ void FusedDropoutAddGradKernel(const Context& dev_ctx,
   size_t block_size = random_prop[1];
   size_t offset = random_prop[2];
   size_t main_offset = random_prop[3];
-
   auto functor = upscale_in_train
                      ? NoMaskBwFunctor<T, float>(1.0f - dropout_rate)
                      : NoMaskBwFunctor<T, float>(1.0f - dropout_rate, 1.0f);
```
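For reference, the forward pass computes `out = dropout(x) + y`, so the backward must produce `x_grad = out_grad * mask * scale` and pass `out_grad` through unchanged as `y_grad`. That is why the rewritten kernel writes both halves of `x_y` from a single `out_grad` read, and why `NoMaskBwFunctor` takes the keep probability `1 - p` (with an explicit factor of `1.0f` in the non-upscale mode). A NumPy sketch of these semantics, assuming a regenerated keep-mask (names here are illustrative, not the kernel's API):

```python
import numpy as np

def fused_dropout_add_grad(out_grad, keep_mask, p, upscale_in_train=True):
    # Backward of out = dropout(x) + y. `keep_mask` is the Bernoulli(1 - p)
    # mask the forward pass drew; the kernel regenerates it from the same
    # RNG seed/offset instead of storing it (hence "NoMask" in the functor).
    if upscale_in_train:
        x_grad = out_grad * keep_mask / (1.0 - p)  # forward upscaled by 1/(1-p)
    else:
        x_grad = out_grad * keep_mask              # scale factor is 1.0
    y_grad = out_grad  # the add branch passes the gradient straight through
    return x_grad, y_grad
```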
Unit-test changes: the default `TestFusedDropoutAdd` case now exercises a larger float16 tensor at a 0.5 dropout rate.

```diff
@@ -34,9 +34,9 @@ def paddle_dropout_add(x, y, p=0.5, training=True, mode="upscale_in_train"):
 )
 class TestFusedDropoutAdd(unittest.TestCase):
     def setUp(self):
-        self.shape = (2, 10, 10, 2)
-        self.dtype = 'float64'
-        self.dropout_rate = 0.9
+        self.shape = [2, 1024, 2, 1]
+        self.dtype = 'float16'
+        self.dropout_rate = 0.5
         self.training = True
         self.mode = "upscale_in_train"
         self.seed = 1027
```
```diff
@@ -66,9 +66,8 @@ class TestFusedDropoutAdd(unittest.TestCase):
                 mode=self.mode,
             )
             fw.append(out)
-
-            loss = paddle.mean(out)
-            loss.backward()
+            out_g = paddle.randn(self.shape, self.dtype)
+            paddle.autograd.backward([out], [out_g], True)
             for i in range(count):
                 bw.append(data[i].grad)
         return fw, bw
```
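Backpropagating `paddle.mean(out)` feeds a constant upstream gradient, which cannot distinguish a chunk read from the right offset from one read from the wrong offset; the test now pushes a random gradient through `paddle.autograd.backward`, so element positions matter. A standalone sketch of the pattern (float32 here; the test itself runs float16 on GPU):

```python
import paddle

x = paddle.randn([2, 1024, 2, 1], dtype='float32')
x.stop_gradient = False
out = x * 3.0  # any differentiable op stands in for fused_dropout_add

# A non-constant upstream gradient: a wrong read offset in the backward
# kernel now changes x.grad, so the comparison against eager can catch it.
out_g = paddle.randn(out.shape, out.dtype)
paddle.autograd.backward([out], [out_g], True)
print(paddle.allclose(x.grad, out_g * 3.0))  # True
```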
```diff
@@ -95,7 +94,7 @@ def create_test_class(parent, dtype, mode, training, p, seed):
 )
 class TestFusedDropoutAddCase(parent):
     def setUp(self):
-        self.shape = (2, 10, 10, 2)
+        self.shape = (2, 1024, 1, 1)
         self.dtype = dtype
         self.dropout_rate = p
         self.training = training
```
```diff
@@ -168,7 +167,7 @@ class TestFusedDropoutAddStatic(unittest.TestCase):
         y = paddle.randn(self.shape, self.dtype)
         fused_d_a = FusedDropoutAdd(p=0.5)
         d = paddle.nn.Dropout(p=0.5)
-        print(d)
+        print(d.extra_repr())
         paddle.seed(2048)
         fused_out = fused_d_a(x, y)
         paddle.seed(2048)
```
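This test checks the fused layer against its unfused equivalent under the same seed; in dynamic graph the same equivalence check looks roughly like this (the `FusedDropoutAdd` import path is an assumption based on this test module, and a CUDA device is assumed since the fused kernel is GPU-only):

```python
import paddle
from paddle.incubate.nn import FusedDropoutAdd  # import path assumed

x = paddle.randn([2, 1024, 2, 1])
y = paddle.randn([2, 1024, 2, 1])

paddle.seed(2048)
fused_out = FusedDropoutAdd(p=0.5)(x, y)

paddle.seed(2048)  # reseed so both paths draw the same dropout mask
ref_out = paddle.nn.Dropout(p=0.5)(x) + y

print(paddle.allclose(fused_out, ref_out))
```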