diff --git a/paddle/cuda/src/hl_top_k.cu b/paddle/cuda/src/hl_top_k.cu index 59ba552f560dab904d4983e0778ff57be9477c3e..b17290557c4f635a963d88525409b9373b057a4b 100644 --- a/paddle/cuda/src/hl_top_k.cu +++ b/paddle/cuda/src/hl_top_k.cu @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "hl_base.h" -#include "hl_sparse.ph" -#include "hl_top_k.h" +#include "paddle/cuda/include/hl_base.h" +#include "paddle/cuda/include/hl_sparse.ph" +#include "paddle/cuda/include/hl_top_k.h" #include "paddle/utils/Logging.h" // using namespace hppl; @@ -244,8 +244,9 @@ __device__ __forceinline__ void blockReduce(Pair* shTopK, if (--beamSize == 0) break; __syncthreads(); + // NOTE(zcd): temporary solution unsigned mask = 0u; - // CREATE_SHFL_MASK(mask, tid < len); + CREATE_SHFL_MASK(mask, true); if (tid == maxId[0]) { if (beam < maxLength) { diff --git a/paddle/fluid/operators/top_k_op.cu b/paddle/fluid/operators/top_k_op.cu index 2ea9fd1d29936357bbb5f6819c813e4bd1ebe44c..faaae1f9b6ae1e2c577017c34116dab95e261ab2 100644 --- a/paddle/fluid/operators/top_k_op.cu +++ b/paddle/fluid/operators/top_k_op.cu @@ -15,6 +15,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/top_k_op.h" #include "paddle/fluid/platform/assert.h" +#include "paddle/fluid/platform/cuda_device_function.h" namespace paddle { namespace operators { @@ -235,8 +236,13 @@ __device__ __forceinline__ void BlockReduce(Pair* sh_topk, int* maxid, sh_topk[tid] = topk[*beam]; } } + // NOTE(zcd): temporary solution + unsigned mask = 0u; + CREATE_SHFL_MASK(mask, true); + if (maxid[0] / 32 == warp) { - if (__shfl(*beam, (maxid[0]) % 32, 32) == MaxLength) break; + if (platform::__shfl_sync(mask, *beam, (maxid[0]) % 32, 32) == MaxLength) + break; } } } diff --git a/paddle/fluid/platform/cuda_primitives.h b/paddle/fluid/platform/cuda_primitives.h index 8758af0804ae08fec6fa64d7387f197f046ce20e..d535ed2f89df6a0b311ec068ecd92c8e3183cee7 100644 --- a/paddle/fluid/platform/cuda_primitives.h +++ b/paddle/fluid/platform/cuda_primitives.h @@ -65,6 +65,5 @@ CUDA_ATOMIC_WRAPPER(Add, double) { return __longlong_as_double(old); } #endif - } // namespace platform } // namespace paddle