diff --git a/paddle/operators/dropout_op.cu b/paddle/operators/dropout_op.cu index 10c670751d026ef92e01aad7da31a8f59b8514c0..c31d2195e95b116451b0f620f6582f65c0dae706 100644 --- a/paddle/operators/dropout_op.cu +++ b/paddle/operators/dropout_op.cu @@ -71,7 +71,7 @@ class GPUDropoutKernel : public framework::OpKernel { auto M = EigenMatrix::Reshape(*mask, 1); Y.device(place) = X * M; } else { - Y.device(place) = X * dropout_prob; + Y.device(place) = X * (1.0f - dropout_prob); } } }; diff --git a/paddle/operators/dropout_op.h b/paddle/operators/dropout_op.h index 84ad39f0bb639975365d427aa205411ef79ecd46..9f6c4212d4f834bbb2d1c65c836f3d3d0f3e0c96 100644 --- a/paddle/operators/dropout_op.h +++ b/paddle/operators/dropout_op.h @@ -57,7 +57,7 @@ class CPUDropoutKernel : public framework::OpKernel { auto Y = EigenMatrix::Reshape(*y, 1); auto& place = *context.template device_context().eigen_device(); - Y.device(place) = X * dropout_prob; + Y.device(place) = X * (1.0f - dropout_prob); } } }; diff --git a/python/paddle/v2/fluid/tests/test_dropout_op.py b/python/paddle/v2/fluid/tests/test_dropout_op.py index 4f5ea836b44102e5599a2302efd669291ebe920b..2483200212686caf9c46f9c1129b5d8ffdcc9145 100644 --- a/python/paddle/v2/fluid/tests/test_dropout_op.py +++ b/python/paddle/v2/fluid/tests/test_dropout_op.py @@ -47,7 +47,9 @@ class TestDropoutOp4(OpTest): self.op_type = "dropout" self.inputs = {'X': np.random.random((32, 64)).astype("float32")} self.attrs = {'dropout_prob': 0.35, 'is_test': True} - self.outputs = {'Out': self.inputs['X'] * self.attrs['dropout_prob']} + self.outputs = { + 'Out': self.inputs['X'] * (1.0 - self.attrs['dropout_prob']) + } def test_check_output(self): self.check_output() @@ -58,7 +60,9 @@ class TestDropoutOp5(OpTest): self.op_type = "dropout" self.inputs = {'X': np.random.random((32, 64, 3)).astype("float32")} self.attrs = {'dropout_prob': 0.75, 'is_test': True} - self.outputs = {'Out': self.inputs['X'] * self.attrs['dropout_prob']} + self.outputs = { + 'Out': self.inputs['X'] * (1.0 - self.attrs['dropout_prob']) + } def test_check_output(self): self.check_output()