Commit 62fed4cb authored by chengduo, committed by dzhwinter

fix __shfl_down (#10362)

Parent: 3000e994
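CUDA 9.0 deprecates the `__shfl` family of warp intrinsics in favor of `*_sync` variants that take an explicit mask of participating lanes. This commit replaces the remaining `__shfl_down` calls with `__shfl_down_sync`, builds the lane mask via `CREATE_SHFL_MASK` instead of passing `0`, adds a `__shfl_down_sync` fallback so the code still compiles against toolkits older than 9.0, and tags non-const reference parameters with `// NOLINT` for cpplint.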
@@ -229,6 +229,11 @@ extern __thread cudaStream_t default_stream;
 // __shfl has been deprecated as of CUDA 9.0.
 #if CUDA_VERSION < 9000
+template <typename T>
+__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) {
+  return __shfl_down(val, delta);
+}
+
 template <typename T>
 __forceinline__ __device__ T
 __shfl_sync(unsigned, T val, int src_line, int width) {
...
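For reference, the shim above is only half of the compatibility story: callers first build a participation mask with the CREATE_SHFL_MASK macro used throughout this patch. The macro's definition is not shown in this diff; a plausible sketch, assuming it ballots the active lanes on CUDA 9.0+ and degenerates to a dummy value on older toolkits:

// Sketch only; the real macro lives alongside the shims above.
#if CUDA_VERSION < 9000
// Pre-9.0 shuffles ignore the mask argument, so any value will do.
#define CREATE_SHFL_MASK(mask, predicate) mask = 0u
#else
#define FULL_WARP_MASK 0xFFFFFFFFu
// __ballot_sync sets one bit per lane whose predicate is non-zero,
// i.e. the set of threads that must take part in the shuffle.
#define CREATE_SHFL_MASK(mask, predicate) \
  mask = __ballot_sync(FULL_WARP_MASK, (predicate))
#endif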
@@ -189,6 +189,10 @@ __global__ void RowConvGradFilterImproved(const T *in, const T *dout,
   }
   __syncthreads();
 
+  // NOTE(zcd): temporary solution
+  unsigned mask = 0u;
+  CREATE_SHFL_MASK(mask, true);
+
   for (int i = 0; i < num_sequence; i++) {
     int start = static_cast<int>(batch_indices[i]);
     int end = static_cast<int>(batch_indices[i + 1]);
@@ -220,7 +224,7 @@ __global__ void RowConvGradFilterImproved(const T *in, const T *dout,
       for (int offset = 16; offset > 0;
            offset = offset / 2) {  // blockDim.x is 32.
-        val += platform::__shfl_down_sync(0, val, offset);
+        val += platform::__shfl_down_sync(mask, val, offset);
       }
       __syncthreads();
@@ -251,6 +255,10 @@ __global__ void RowConvGradFilter(const T *in, const T *dout, int num_sequence,
   T *sh_in = mem;
   T *sh_dout = &mem[block_x * block_y];
 
+  // NOTE(zcd): temporary solution
+  unsigned mask = 0u;
+  CREATE_SHFL_MASK(mask, true);
+
   for (int i = 0; i < num_sequence; i++) {
     int start = static_cast<int>(batch_indices[i]);
     int end = static_cast<int>(batch_indices[i + 1]);
@@ -276,7 +284,7 @@ __global__ void RowConvGradFilter(const T *in, const T *dout, int num_sequence,
       for (int offset = 16; offset > 0;
           offset = offset / 2) {  // blockDim.x is 32.
-        val += platform::__shfl_down_sync(0, val, offset);
+        val += platform::__shfl_down_sync(mask, val, offset);
       }
       __syncthreads();
...
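The pattern introduced above — build a participation mask once, then reduce with shuffle-down — is self-contained enough to demonstrate in isolation. A minimal sketch of a warp-level sum, assuming CUDA 9.0+ (the WarpSum kernel and its launch shape are illustrative, not part of the patch):

#include <cstdio>

#define FULL_WARP_MASK 0xFFFFFFFFu

// Sums 32 values with warp shuffles; blockDim.x is assumed to be 32,
// mirroring the "blockDim.x is 32" comments in the kernels above.
__global__ void WarpSum(const float* in, float* out) {
  float val = in[threadIdx.x];
  // Every lane is active here, so the ballot yields the full warp mask.
  unsigned mask = __ballot_sync(FULL_WARP_MASK, true);
  // The stride halves each step; after the last step lane 0 holds the sum.
  for (int offset = 16; offset > 0; offset /= 2)
    val += __shfl_down_sync(mask, val, offset);
  if (threadIdx.x == 0) *out = val;
}

int main() {
  float h_in[32], h_out = 0.f;
  for (int i = 0; i < 32; ++i) h_in[i] = 1.f;  // expected sum: 32
  float *d_in, *d_out;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  WarpSum<<<1, 32>>>(d_in, d_out);  // one warp, as in the kernels above
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("warp sum = %f\n", h_out);
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}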
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "RowConvOp.h"
-#include "hl_base.h"
+#include "paddle/cuda/include/hl_base.h"
+#include "paddle/function/RowConvOp.h"
 
 namespace paddle {
@@ -94,7 +94,7 @@ __global__ void KeRowConv2(real* y,
 }
 
 template <>
-void RowConv<DEVICE_TYPE_GPU>(GpuMatrix& out,
+void RowConv<DEVICE_TYPE_GPU>(GpuMatrix& out,  // NOLINT
                               const GpuMatrix& in,
                               const GpuMatrix& filter,
                               const GpuIVector& seq) {
@@ -144,6 +144,10 @@ __global__ void KeRowConvBwWeight(real* dw,
   }
   __syncthreads();
 
+  // NOTE(zcd): temporary solution
+  unsigned mask = 0u;
+  CREATE_SHFL_MASK(mask, true);
+
   for (int i = 0; i < numSeq; ++i) {
     const int start = starts[i];
     const int end = starts[i + 1];
@@ -170,11 +174,10 @@ __global__ void KeRowConvBwWeight(real* dw,
         real val = sh_x[tidy][tidx] * sh_dy[tidy][tidx + context - 1 - t];
         __syncthreads();
         // warp size and blockDim.x is 32.
-        val += __shfl_down(val, 16);
-        val += __shfl_down(val, 8);
-        val += __shfl_down(val, 4);
-        val += __shfl_down(val, 2);
-        val += __shfl_down(val, 1);
+        for (int offset = 16; offset > 0; offset /= 2)
+          val += __shfl_down_sync(mask, val, offset);
+
         __syncthreads();
         if (tidx == 0) {
           sh_dw[t][tidy] += val;
@@ -205,6 +208,10 @@ __global__ void KeRowConvBwWeight2(real* dw,
   __shared__ real sh_x[BLOCK_H][BLOCK_W];
   __shared__ real sh_dy[BLOCK_H][BLOCK_W];
 
+  // NOTE(zcd): temporary solution
+  unsigned mask = 0u;
+  CREATE_SHFL_MASK(mask, true);
+
   for (int i = 0; i < numSeq; ++i) {
     const int start = starts[i];
     const int end = starts[i + 1];
@@ -230,11 +237,9 @@ __global__ void KeRowConvBwWeight2(real* dw,
       real val = sh_x[tidy][tidx] * sh_dy[tidy][tidx];
       __syncthreads();
       // warp size and blockDim.x is 32.
-      val += __shfl_down(val, 16);
-      val += __shfl_down(val, 8);
-      val += __shfl_down(val, 4);
-      val += __shfl_down(val, 2);
-      val += __shfl_down(val, 1);
+      for (int offset = 16; offset > 0; offset /= 2)
+        val += __shfl_down_sync(mask, val, offset);
+
       __syncthreads();
       if (tidx == 0 && (gidx + tidy) < width) {
@@ -323,8 +328,8 @@ template <>
 void RowConvGrad<DEVICE_TYPE_GPU>(const GpuMatrix& outG,
                                   const GpuMatrix& in,
                                   const GpuMatrix& filter,
-                                  GpuMatrix& inG,
-                                  GpuMatrix& filterG,
+                                  GpuMatrix& inG,      // NOLINT
+                                  GpuMatrix& filterG,  // NOLINT
                                   const GpuIVector& seq) {
   const size_t numSeq = seq.getSize() - 1;
   const size_t contextLength = filter.getHeight();
...
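Note that the loop form of the reduction is behavior-preserving: unrolled, it issues exactly the five shuffles it replaces, only through the synchronized intrinsic. Spelled out for one 32-lane warp:

// Equivalent expansion of the reduction loop (offsets 16, 8, 4, 2, 1):
val += __shfl_down_sync(mask, val, 16);  // lanes 0-15 add lanes 16-31
val += __shfl_down_sync(mask, val, 8);   // lanes 0-7 add lanes 8-15
val += __shfl_down_sync(mask, val, 4);
val += __shfl_down_sync(mask, val, 2);
val += __shfl_down_sync(mask, val, 1);   // lane 0 now holds the warp total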