提交 b8f7fa97 编写于 作者: C chengduoZH

replace __shfl with __shfl_sync

上级 7c90d7a3
...@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "hl_base.h" #include "paddle/cuda/include/hl_base.h"
#include "hl_sparse.ph" #include "paddle/cuda/include/hl_sparse.ph"
#include "hl_top_k.h" #include "paddle/cuda/include/hl_top_k.h"
#include "paddle/utils/Logging.h" #include "paddle/utils/Logging.h"
// using namespace hppl; // using namespace hppl;
...@@ -244,8 +244,9 @@ __device__ __forceinline__ void blockReduce(Pair* shTopK, ...@@ -244,8 +244,9 @@ __device__ __forceinline__ void blockReduce(Pair* shTopK,
if (--beamSize == 0) break; if (--beamSize == 0) break;
__syncthreads(); __syncthreads();
// temporary solution
unsigned mask = 0u; unsigned mask = 0u;
// CREATE_SHFL_MASK(mask, tid < len); CREATE_SHFL_MASK(mask, true);
if (tid == maxId[0]) { if (tid == maxId[0]) {
if (beam < maxLength) { if (beam < maxLength) {
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/top_k_op.h" #include "paddle/fluid/operators/top_k_op.h"
#include "paddle/fluid/platform/assert.h" #include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -235,8 +236,12 @@ __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid, ...@@ -235,8 +236,12 @@ __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
sh_topk[tid] = topk[*beam]; sh_topk[tid] = topk[*beam];
} }
} }
// temporary solution
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, true);
if (maxid[0] / 32 == warp) { if (maxid[0] / 32 == warp) {
if (__shfl(*beam, (maxid[0]) % 32, 32) == MaxLength) break; if (__shfl_sync(mask, *beam, (maxid[0]) % 32, 32) == MaxLength) break;
} }
} }
} }
......
...@@ -72,6 +72,13 @@ template <typename T> ...@@ -72,6 +72,13 @@ template <typename T>
__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) { __forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) {
return __shfl_down(val, delta); return __shfl_down(val, delta);
} }
template <typename T>
__forceinline__ __device__ T __shfl_sync(unsigned, T val, int src_line,
int width) {
return __shfl(val, src_line, width);
}
#define CREATE_SHFL_MASK(mask, predicate) mask = 0u; #define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
#else #else
#define FULL_WARP_MASK 0xFFFFFFFF #define FULL_WARP_MASK 0xFFFFFFFF
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册