diff --git a/paddle/fluid/operators/dropout_op.cu b/paddle/fluid/operators/dropout_op.cu
index e26eba68f15a9934a64081fddfffd49086f7faa8..e3d758c3a245e0b129e28de41e0ccb4df66288dd 100644
--- a/paddle/fluid/operators/dropout_op.cu
+++ b/paddle/fluid/operators/dropout_op.cu
@@ -11,14 +11,16 @@
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#include <cuda.h>
+#include <curand_kernel.h>
 #include <thrust/device_ptr.h>
 #include <thrust/iterator/counting_iterator.h>
 #include <thrust/random.h>
 #include <thrust/transform.h>
 #include <string>
 #include "paddle/fluid/operators/dropout_op.h"
+#include "paddle/fluid/platform/dynload/curand.h"
 #include "paddle/fluid/platform/float16.h"
-
 namespace paddle {
 namespace operators {
@@ -27,10 +29,7 @@ __global__ void RandomGenerator(const size_t n, const int seed,
                                 const float dropout_prob, const T* src,
                                 MaskType* mask_data, T* dst,
                                 bool is_upscale_in_train) {
-  thrust::minstd_rand rng;
-  rng.seed(seed);
-  thrust::uniform_real_distribution<float> dist(0, 1);
-
+  curandStatePhilox4_32_10_t state;
   int idx = blockDim.x * blockIdx.x + threadIdx.x;
   int step_size = 0;

@@ -39,12 +38,12 @@ __global__ void RandomGenerator(const size_t n, const int seed,
   for (; idx < n; idx += blockDim.x * gridDim.x) {
     T s = src[idx];
     if (step_size == 0) {
-      rng.discard(idx);
+      curand_init(seed, idx, idx, &state);
       step_size = blockDim.x * gridDim.x;
     } else {
-      rng.discard(step_size);
+      curand_init(seed, idx, step_size, &state);
     }
-    if (dist(rng) < dropout_prob) {
+    if (curand_uniform(&state) < dropout_prob) {
       mask = 0;
       dest = 0;
     } else {
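
For reference, below is a minimal standalone sketch (not part of the patch above) of the curand Philox pattern the new kernel switches to: each thread initializes its own curandStatePhilox4_32_10_t on an independent subsequence, then draws uniforms inside a grid-stride loop. All names here (DropoutSketch, the launch configuration, the managed-memory harness) are invented for illustration. One deliberate difference: the sketch calls curand_init once per thread with offset 0, since Philox subsequences already give each thread an independent stream, whereas the patched kernel re-initializes with a growing offset to mirror the bookkeeping of the old rng.discard calls.

// sketch_dropout.cu -- illustrative only; compile with: nvcc sketch_dropout.cu
#include <cstdio>
#include <cuda_runtime.h>
#include <curand_kernel.h>

__global__ void DropoutSketch(size_t n, unsigned long long seed,
                              float dropout_prob, const float* src,
                              float* dst) {
  int idx = blockDim.x * blockIdx.x + threadIdx.x;
  // One init per thread: subsequence = idx keeps the per-thread streams
  // independent, so the loop below can draw repeatedly without re-seeding.
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx, /*offset=*/0, &state);
  for (size_t i = idx; i < n; i += blockDim.x * gridDim.x) {
    // curand_uniform returns a float in (0, 1]; drop with probability
    // dropout_prob, as the kernel in the diff does.
    float keep = (curand_uniform(&state) < dropout_prob) ? 0.f : 1.f;
    dst[i] = src[i] * keep;
  }
}

int main() {
  const size_t n = 1 << 20;
  float *src, *dst;
  cudaMallocManaged(&src, n * sizeof(float));
  cudaMallocManaged(&dst, n * sizeof(float));
  for (size_t i = 0; i < n; ++i) src[i] = 1.f;

  DropoutSketch<<<(n + 255) / 256, 256>>>(n, /*seed=*/42, /*prob=*/0.5f,
                                          src, dst);
  cudaDeviceSynchronize();

  // Roughly half the elements should survive with prob = 0.5.
  size_t kept = 0;
  for (size_t i = 0; i < n; ++i) kept += (dst[i] != 0.f);
  printf("kept %zu of %zu (~%.3f)\n", kept, n, (double)kept / (double)n);

  cudaFree(src);
  cudaFree(dst);
  return 0;
}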