未验证 提交 54797abd 编写于 作者: C chengduo 提交者: GitHub

Merge pull request #10347 from chengduoZH/replace___shfl_with__shfl_sync

Wrap __shfl
...@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "hl_base.h" #include "paddle/cuda/include/hl_base.h"
#include "hl_sparse.ph" #include "paddle/cuda/include/hl_sparse.ph"
#include "hl_top_k.h" #include "paddle/cuda/include/hl_top_k.h"
#include "paddle/utils/Logging.h" #include "paddle/utils/Logging.h"
// using namespace hppl; // using namespace hppl;
...@@ -244,8 +244,9 @@ __device__ __forceinline__ void blockReduce(Pair* shTopK, ...@@ -244,8 +244,9 @@ __device__ __forceinline__ void blockReduce(Pair* shTopK,
if (--beamSize == 0) break; if (--beamSize == 0) break;
__syncthreads(); __syncthreads();
// NOTE(zcd): temporary solution
unsigned mask = 0u; unsigned mask = 0u;
// CREATE_SHFL_MASK(mask, tid < len); CREATE_SHFL_MASK(mask, true);
if (tid == maxId[0]) { if (tid == maxId[0]) {
if (beam < maxLength) { if (beam < maxLength) {
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/top_k_op.h" #include "paddle/fluid/operators/top_k_op.h"
#include "paddle/fluid/platform/assert.h" #include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/cuda_device_function.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -235,8 +236,13 @@ __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid, ...@@ -235,8 +236,13 @@ __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
sh_topk[tid] = topk[*beam]; sh_topk[tid] = topk[*beam];
} }
} }
// NOTE(zcd): temporary solution
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, true);
if (maxid[0] / 32 == warp) { if (maxid[0] / 32 == warp) {
if (__shfl(*beam, (maxid[0]) % 32, 32) == MaxLength) break; if (platform::__shfl_sync(mask, *beam, (maxid[0]) % 32, 32) == MaxLength)
break;
} }
} }
} }
......
...@@ -65,6 +65,5 @@ CUDA_ATOMIC_WRAPPER(Add, double) { ...@@ -65,6 +65,5 @@ CUDA_ATOMIC_WRAPPER(Add, double) {
return __longlong_as_double(old); return __longlong_as_double(old);
} }
#endif #endif
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册