Unverified commit 8082ba8a authored by ShenLiang, committed by GitHub

[BugFix] fix compute error in fused_dropout_add (#52261)

* fix bug

* add utest

* add utest
Parent 73df2b1e
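The kernel changed below computes the backward pass of `out = dropout(x) + y` in one fused kernel, regenerating the dropout mask from the saved random state rather than reading a stored mask tensor. As the hunks show, the previous main loop read the upstream gradient (`dst`) without adding the per-iteration `fix` offset and produced the second gradient through an extra `IdentityFunctor` copy; the fix renames the arguments to `x`, `y`, and `out_grad`, reads `out_grad + fix`, and writes both gradients straight from the shared `x_y` buffer. For orientation, a minimal NumPy sketch of the gradients the kernel is expected to produce (`dropout_add_grad_ref` is an illustrative helper, not a Paddle API, and it takes the mask explicitly instead of regenerating it):

```python
import numpy as np


def dropout_add_grad_ref(out_grad, mask, p, upscale_in_train=True):
    """Gradients of out = dropout(x) + y w.r.t. x and y (illustrative only).

    `mask` is 1 for kept elements; the CUDA kernel reproduces it from the
    saved seed/offset instead of taking it as an input.
    """
    if upscale_in_train:
        x_grad = out_grad * mask / (1.0 - p)  # kept elements rescaled by 1/(1-p)
    else:
        x_grad = out_grad * mask              # "downscale_in_infer": no rescaling while training
    y_grad = out_grad                         # the add is linear, so the gradient passes through
    return x_grad, y_grad


# With p=0.5 and upscale_in_train, kept positions carry twice the incoming gradient.
g = np.array([1.0, 2.0, 3.0, 4.0])
m = np.array([1.0, 0.0, 1.0, 1.0])
print(dropout_add_grad_ref(g, m, p=0.5))  # (array([2., 0., 6., 8.]), array([1., 2., 3., 4.]))
```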
@@ -91,9 +91,9 @@ struct NoMaskBwFunctor {
 template <typename T, typename Functor>
 __global__ void VectorizedDropoutBackward(const size_t n,
                                           uint64_t seed,
-                                          T* src,
-                                          T* res,
-                                          const T* dst,
+                                          T* x,
+                                          T* y,
+                                          const T* out_grad,
                                           uint64_t increment,
                                           size_t main_offset,
                                           Functor functor) {
@@ -112,44 +112,38 @@ __global__ void VectorizedDropoutBackward(const size_t n,
 #endif
   float rands[kCount];
-  T src_res[kCount * 2];
-  T res_grad[kCount];
+  T x_y[kCount * 2];
   using Rand = phi::funcs::uniform_distribution<float>;
   using Cast = kps::IdentityFunctor<T>;
   int deal_size = BLOCK_NUM_X * kCount;
   size_t fix = idx * kCount;
   for (; fix < main_offset; fix += stride) {
-    kps::ReadData<T, kCount, 1, false>(&src_res[0], dst, deal_size);
+    kps::ReadData<T, kCount, 1, false>(&x_y[0], out_grad + fix, deal_size);
     kps::ElementwiseRandom<SType, float, kCount, Rand>(
         &rands[0], Rand(), &state);
     // x_grad
     kps::OperatorTernary<T, float, T, Functor>(
-        &src_res[0], &src_res[0], &rands[0], functor, kCount);
-    kps::WriteData<T, kCount, 1, false>(src + fix, &src_res[0], deal_size);
-    // res
-    kps::ElementwiseUnary<T, T, kCount, 1, Cast>(
-        &res_grad[0], &src_res[kCount], Cast());
-    kps::WriteData<T, kCount, 1, false>(res + fix, &res_grad[0], deal_size);
+        &x_y[0], &x_y[0], &rands[0], functor, kCount);
+    kps::WriteData<T, kCount, 1, false>(x + fix, &x_y[0], deal_size);
+    kps::WriteData<T, kCount, 1, false>(y + fix, &x_y[kCount], deal_size);
     if (fix > idx * kCount + 1) {
       __syncthreads();
     }
   }
   int remainder = n - fix;
   if (remainder > 0) {
-    kps::ReadData<T, kCount, 1, true>(&src_res[0], dst + fix, remainder);
+    kps::ReadData<T, kCount, 1, true>(&x_y[0], out_grad + fix, remainder);
     kps::ElementwiseRandom<SType, float, kCount, Rand>(
         &rands[0], Rand(), &state);
     // x_grad
     kps::OperatorTernary<T, float, T, Functor>(
-        &src_res[0], &src_res[0], &rands[0], functor, kCount);
-    kps::WriteData<T, kCount, 1, true>(src + fix, &src_res[0], remainder);
-    // res
-    kps::ElementwiseUnary<T, T, kCount, 1, Cast>(
-        &res_grad[0], &src_res[kCount], Cast());
-    kps::WriteData<T, kCount, 1, true>(res + fix, &res_grad[0], remainder);
+    kps::WriteData<T, kCount, 1, true>(x + fix, &x_y[0], remainder);
+    kps::WriteData<T, kCount, 1, true>(y + fix, &x_y[kCount], remainder);
     __syncthreads();
   }
 }
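The compute error itself is easiest to see in the first `ReadData` call above: the old loop passed `dst` with no offset, so every iteration of the vectorized main loop re-read the first chunk of the upstream gradient, while the fixed version reads `out_grad + fix`. A toy Python sketch of that indexing difference (`strided_copy` is a made-up stand-in for the `ReadData`/`WriteData` pair; the grid-stride and per-thread details are simplified away):

```python
import numpy as np


def strided_copy(src, k_count, use_offset):
    # Toy model of the vectorized loop: each iteration handles k_count
    # elements starting at `fix`. Dropping the `+ fix` on the read side
    # (as the old ReadData call did) re-reads the first chunk every time.
    dst = np.empty_like(src)
    for fix in range(0, src.size, k_count):
        chunk = src[fix:fix + k_count] if use_offset else src[0:k_count]
        dst[fix:fix + k_count] = chunk
    return dst


a = np.arange(8.0)
print(strided_copy(a, 2, use_offset=True))   # [0. 1. 2. 3. 4. 5. 6. 7.]
print(strided_copy(a, 2, use_offset=False))  # [0. 1. 0. 1. 0. 1. 0. 1.]
```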
@@ -201,7 +195,6 @@ void FusedDropoutAddGradKernel(const Context& dev_ctx,
   size_t block_size = random_prop[1];
   size_t offset = random_prop[2];
   size_t main_offset = random_prop[3];
   auto functor = upscale_in_train
                      ? NoMaskBwFunctor<T, float>(1.0f - dropout_rate)
                      : NoMaskBwFunctor<T, float>(1.0f - dropout_rate, 1.0f);
......
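`VectorizedDropoutBackward` is launched with the seed and offsets recorded by the forward pass (the `random_prop` values above), which is presumably what lets `NoMaskBwFunctor` rebuild the keep/drop pattern on the fly instead of loading a saved mask; the single-argument constructor corresponds to `upscale_in_train`, and the two-argument form appears to pin the scale factor to 1.0. A small sketch of the regeneration idea, with NumPy's seeded generator standing in for the counter-based CUDA stream and the `>= p` keep rule treated as an assumption:

```python
import numpy as np


def keep_mask(seed, offset, p, n):
    # Stand-in for the counter-based RNG: the same (seed, offset) pair always
    # yields the same uniforms, so forward and backward can agree on which
    # elements were kept without materializing a mask tensor.
    # The ">= p" keep rule is an illustrative assumption.
    rng = np.random.default_rng([seed, offset])
    return rng.random(n) >= p


fwd = keep_mask(1027, 0, p=0.5, n=8)
bwd = keep_mask(1027, 0, p=0.5, n=8)
assert (fwd == bwd).all()  # identical pattern regenerated in the backward pass
```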
@@ -34,9 +34,9 @@ def paddle_dropout_add(x, y, p=0.5, training=True, mode="upscale_in_train"):
 )
 class TestFusedDropoutAdd(unittest.TestCase):
     def setUp(self):
-        self.shape = (2, 10, 10, 2)
-        self.dtype = 'float64'
-        self.dropout_rate = 0.9
+        self.shape = [2, 1024, 2, 1]
+        self.dtype = 'float16'
+        self.dropout_rate = 0.5
         self.training = True
         self.mode = "upscale_in_train"
         self.seed = 1027
@@ -66,9 +66,8 @@ class TestFusedDropoutAdd(unittest.TestCase):
                 mode=self.mode,
             )
             fw.append(out)
-        loss = paddle.mean(out)
-        loss.backward()
+        out_g = paddle.randn(self.shape, self.dtype)
+        paddle.autograd.backward([out], [out_g], True)
         for i in range(count):
             bw.append(data[i].grad)
         return fw, bw
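The test now seeds the backward pass with an explicit random upstream gradient instead of a mean loss. With `paddle.mean`, the incoming gradient is the same constant (1/numel) for every element, so a read that lands at the wrong offset still picks up the right value and the bug above stays invisible; a non-constant `out_g` exposes it. A minimal sketch of that vector-Jacobian pattern on a toy graph (not the fused op itself):

```python
import paddle

# Toy graph: supply an explicit upstream gradient instead of reducing to a scalar loss.
x = paddle.randn([2, 1024, 2, 1], dtype='float32')
x.stop_gradient = False
out = x * 3.0
out_g = paddle.randn(out.shape, dtype='float32')
paddle.autograd.backward([out], [out_g], True)  # retain_graph=True, as in the test
print(paddle.allclose(x.grad, 3.0 * out_g))     # True for this toy graph
```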
@@ -95,7 +94,7 @@ def create_test_class(parent, dtype, mode, training, p, seed):
     )
     class TestFusedDropoutAddCase(parent):
         def setUp(self):
-            self.shape = (2, 10, 10, 2)
+            self.shape = (2, 1024, 1, 1)
             self.dtype = dtype
             self.dropout_rate = p
             self.training = training
@@ -168,7 +167,7 @@ class TestFusedDropoutAddStatic(unittest.TestCase):
         y = paddle.randn(self.shape, self.dtype)
         fused_d_a = FusedDropoutAdd(p=0.5)
         d = paddle.nn.Dropout(p=0.5)
-        print(d)
+        print(d.extra_repr())
         paddle.seed(2048)
         fused_out = fused_d_a(x, y)
         paddle.seed(2048)
......
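For context, the comparison this check appears to build up to is that, under the same global seed, the fused layer matches the two-op composition `Dropout(p)(x) + y`. A sketch of that comparison (the `paddle.incubate.nn.FusedDropoutAdd` import path and the float16/CUDA requirement are assumptions taken from the rest of this test file):

```python
import paddle
from paddle.incubate.nn import FusedDropoutAdd  # assumed import path; the test uses FusedDropoutAdd directly

shape, dtype = [2, 1024, 2, 1], 'float16'       # fused op requires a CUDA build/device
x = paddle.randn(shape, dtype)
y = paddle.randn(shape, dtype)

paddle.seed(2048)
fused_out = FusedDropoutAdd(p=0.5)(x, y)

paddle.seed(2048)
ref_out = paddle.nn.Dropout(p=0.5)(x) + y       # the seed is reset so both ops see the same RNG state

print(paddle.allclose(fused_out.astype('float32'), ref_out.astype('float32')))
```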