Unverified commit 6e326ca2 authored by wangchaochaohu, committed by GitHub

optimize the realization of cuda dropout (#19136)

* cuda optimize for dropout

* remove tmp swp file

* fix compile error test=develop

* test=develop optimize the cuda realization of dropout op

* remove unused code test=develop

* remove tmp file test=develop
Parent 0865b5a9
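For context, this commit swaps the per-thread thrust::minstd_rand engine in the dropout CUDA kernel for cuRAND's counter-based Philox generator (curandStatePhilox4_32_10_t). A minimal standalone sketch of that pattern follows; the kernel name DropoutSketchKernel, the init-once placement of curand_init, and the launch values are illustrative assumptions, not code from this patch (the patched RandomGenerator re-seeds inside its loop, as the diff below shows).

#include <curand_kernel.h>

// Illustrative grid-stride dropout kernel using the Philox counter-based
// generator. Names and structure here are hypothetical; unlike the patched
// RandomGenerator, this sketch initializes the per-thread state only once.
template <typename T, typename MaskType>
__global__ void DropoutSketchKernel(const size_t n, const int seed,
                                    const float dropout_prob, const T* src,
                                    MaskType* mask, T* dst) {
  int idx = blockDim.x * blockIdx.x + threadIdx.x;
  // (seed, subsequence, offset): each thread gets its own subsequence, so no
  // shared engine and no serial discard() calls are needed.
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx, 0, &state);
  for (; idx < n; idx += blockDim.x * gridDim.x) {
    // curand_uniform returns a float in (0, 1].
    if (curand_uniform(&state) < dropout_prob) {
      mask[idx] = 0;
      dst[idx] = static_cast<T>(0);
    } else {
      mask[idx] = 1;
      dst[idx] = src[idx];
    }
  }
}

// Hypothetical host-side launch (values are illustrative):
//   const int threads = 256;
//   const int grid = (n + threads - 1) / threads;
//   DropoutSketchKernel<float, uint8_t><<<grid, threads, 0, stream>>>(
//       n, seed, dropout_prob, d_src, d_mask, d_dst);

Because Philox is counter-based, state setup is cheap and per-thread streams are independent by construction, which is what lets the commit drop the thrust engine and its discard() bookkeeping.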
@@ -11,14 +11,16 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#include <cuda.h>
+#include <curand_kernel.h>
 #include <thrust/device_ptr.h>
 #include <thrust/iterator/counting_iterator.h>
 #include <thrust/random.h>
 #include <thrust/transform.h>
 #include <string>
 #include "paddle/fluid/operators/dropout_op.h"
 #include "paddle/fluid/platform/dynload/curand.h"
 #include "paddle/fluid/platform/float16.h"
 namespace paddle {
 namespace operators {
@@ -27,10 +29,7 @@ __global__ void RandomGenerator(const size_t n, const int seed,
                                 const float dropout_prob, const T* src,
                                 MaskType* mask_data, T* dst,
                                 bool is_upscale_in_train) {
-  thrust::minstd_rand rng;
-  rng.seed(seed);
-  thrust::uniform_real_distribution<float> dist(0, 1);
+  curandStatePhilox4_32_10_t state;
   int idx = blockDim.x * blockIdx.x + threadIdx.x;
   int step_size = 0;
@@ -39,12 +38,12 @@ __global__ void RandomGenerator(const size_t n, const int seed,
   for (; idx < n; idx += blockDim.x * gridDim.x) {
     T s = src[idx];
     if (step_size == 0) {
-      rng.discard(idx);
+      curand_init(seed, idx, idx, &state);
       step_size = blockDim.x * gridDim.x;
     } else {
-      rng.discard(step_size);
+      curand_init(seed, idx, step_size, &state);
     }
-    if (dist(rng) < dropout_prob) {
+    if (curand_uniform(&state) < dropout_prob) {
       mask = 0;
       dest = 0;
     } else {
......