diff --git a/paddle/fluid/operators/math/cos_sim_functor.cu b/paddle/fluid/operators/math/cos_sim_functor.cu index 4e6ff5ee0a449b42762748ba1a103876beee01f2..537c7e47155fe9a12196869ceaed84fca198335b 100644 --- a/paddle/fluid/operators/math/cos_sim_functor.cu +++ b/paddle/fluid/operators/math/cos_sim_functor.cu @@ -51,7 +51,7 @@ struct CosSimDyFunctor { T* dy) const { const int block_size = 512; dim3 threads(block_size, 1); - dim3 grid(1, (rows + block_size - 1) / block_size); + dim3 grid((rows + block_size - 1) / block_size, 1); CosSimDyKernel<<>>( x_norm, y_norm, x, y, z, dz, rows, cols, dy); } diff --git a/paddle/fluid/operators/math/selected_rows_functor.cu b/paddle/fluid/operators/math/selected_rows_functor.cu index a4fa6f5c898c541120a874f962b0f6a817736510..c4fccdbf862fda8a599869c30ae598573ca367aa 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cu +++ b/paddle/fluid/operators/math/selected_rows_functor.cu @@ -81,7 +81,7 @@ template __global__ void SelectedRowsAddTensorKernel(const T* selected_rows, const int64_t* rows, T* tensor_out, int64_t row_numel) { - const int ty = blockIdx.y; + const int ty = blockIdx.x; int tid = threadIdx.x; selected_rows += ty * row_numel; @@ -123,7 +123,7 @@ struct SelectedRowsAddTensor { const int block_size = 256; dim3 threads(block_size, 1); - dim3 grid(1, in1_rows.size()); + dim3 grid(in1_rows.size(), 1); SelectedRowsAddTensorKernel< T, block_size><<>>( in1_data, in1_rows.CUDAData(context.GetPlace()), out_data, @@ -188,7 +188,7 @@ __global__ void SelectedRowsAddToTensorKernel(const T* selected_rows, const int64_t* rows, T* tensor_out, int64_t row_numel) { - const int ty = blockIdx.y; + const int ty = blockIdx.x; int tid = threadIdx.x; selected_rows += ty * row_numel; @@ -221,7 +221,7 @@ struct SelectedRowsAddToTensor { auto* in2_data = input2->data(); const int block_size = 256; dim3 threads(block_size, 1); - dim3 grid(1, in1_rows.size()); + dim3 grid(in1_rows.size(), 1); SelectedRowsAddToTensorKernel< T, block_size><<>>( in1_data, in1_rows.CUDAData(context.GetPlace()), in2_data, @@ -388,7 +388,7 @@ template __global__ void UpdateToTensorKernel(const T* selected_rows, const int64_t* rows, const ScatterOps& op, T* tensor_out, int64_t row_numel) { - const int ty = blockIdx.y; + const int ty = blockIdx.x; int tid = threadIdx.x; selected_rows += ty * row_numel; @@ -457,7 +457,7 @@ struct UpdateToTensor { auto* in2_data = input2->data(); dim3 threads(platform::PADDLE_CUDA_NUM_THREADS, 1); - dim3 grid(1, in1_rows.size()); + dim3 grid(in1_rows.size(), 1); UpdateToTensorKernel<<< grid, threads, 0, context.stream()>>>(in1_data, in1_rows.cuda_data(), op, in2_data, in1_row_numel);