diff --git a/paddle/cuda/src/hl_top_k.cu b/paddle/cuda/src/hl_top_k.cu index 59ba552f560dab904d4983e0778ff57be9477c3e..4a737d5ba7db02b8424299191378a41fb4698821 100644 --- a/paddle/cuda/src/hl_top_k.cu +++ b/paddle/cuda/src/hl_top_k.cu @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "hl_base.h" -#include "hl_sparse.ph" -#include "hl_top_k.h" +#include "paddle/cuda/include/hl_base.h" +#include "paddle/cuda/include/hl_sparse.ph" +#include "paddle/cuda/include/hl_top_k.h" #include "paddle/utils/Logging.h" // using namespace hppl; @@ -244,8 +244,9 @@ __device__ __forceinline__ void blockReduce(Pair* shTopK, if (--beamSize == 0) break; __syncthreads(); + // temporary solution unsigned mask = 0u; - // CREATE_SHFL_MASK(mask, tid < len); + CREATE_SHFL_MASK(mask, true); if (tid == maxId[0]) { if (beam < maxLength) { diff --git a/paddle/fluid/operators/top_k_op.cu b/paddle/fluid/operators/top_k_op.cu index d7f4d383ce0d9e1ff42fc12c96aaf0ceb532e5db..a2e3973fe8d7aa27e2bd8142dd6fffbf6505a2d4 100644 --- a/paddle/fluid/operators/top_k_op.cu +++ b/paddle/fluid/operators/top_k_op.cu @@ -15,6 +15,7 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/top_k_op.h" #include "paddle/fluid/platform/assert.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { @@ -235,8 +236,12 @@ __device__ __forceinline__ void BlockReduce(Pair* sh_topk, int* maxid, sh_topk[tid] = topk[*beam]; } } + // temporary solution + unsigned mask = 0u; + CREATE_SHFL_MASK(mask, true); + if (maxid[0] / 32 == warp) { - if (__shfl(*beam, (maxid[0]) % 32, 32) == MaxLength) break; + if (__shfl_sync(mask, *beam, (maxid[0]) % 32, 32) == MaxLength) break; } } } diff --git a/paddle/fluid/platform/cuda_primitives.h b/paddle/fluid/platform/cuda_primitives.h index 866ff30a8be7be124a72a8dc7e70ef4140ee716a..0f6e6159b63c69096554a717ad383dcecec6a2f6 100644 --- a/paddle/fluid/platform/cuda_primitives.h +++ b/paddle/fluid/platform/cuda_primitives.h @@ -72,6 +72,13 @@ template <typename T> __forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) { return __shfl_down(val, delta); } + +template <typename T> +__forceinline__ __device__ T __shfl_sync(unsigned, T val, int src_line, + int width) { + return __shfl(val, src_line, width); +} + #define CREATE_SHFL_MASK(mask, predicate) mask = 0u; #else #define FULL_WARP_MASK 0xFFFFFFFF