提交 b8f7fa97 编写于 作者: C chengduoZH

replace __shfl with __shfl_sync

上级 7c90d7a3
......@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "hl_base.h"
#include "hl_sparse.ph"
#include "hl_top_k.h"
#include "paddle/cuda/include/hl_base.h"
#include "paddle/cuda/include/hl_sparse.ph"
#include "paddle/cuda/include/hl_top_k.h"
#include "paddle/utils/Logging.h"
// using namespace hppl;
......@@ -244,8 +244,9 @@ __device__ __forceinline__ void blockReduce(Pair* shTopK,
if (--beamSize == 0) break;
__syncthreads();
// temporary solution
unsigned mask = 0u;
// CREATE_SHFL_MASK(mask, tid < len);
CREATE_SHFL_MASK(mask, true);
if (tid == maxId[0]) {
if (beam < maxLength) {
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/top_k_op.h"
#include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle {
namespace operators {
......@@ -235,8 +236,12 @@ __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
sh_topk[tid] = topk[*beam];
}
}
// temporary solution
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, true);
if (maxid[0] / 32 == warp) {
if (__shfl(*beam, (maxid[0]) % 32, 32) == MaxLength) break;
if (__shfl_sync(mask, *beam, (maxid[0]) % 32, 32) == MaxLength) break;
}
}
}
......
......@@ -72,6 +72,13 @@ template <typename T>
__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) {
return __shfl_down(val, delta);
}
template <typename T>
__forceinline__ __device__ T __shfl_sync(unsigned, T val, int src_line,
int width) {
return __shfl(val, src_line, width);
}
#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
#else
#define FULL_WARP_MASK 0xFFFFFFFF
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册