diff --git a/paddle/fluid/operators/cumsum_op.cu b/paddle/fluid/operators/cumsum_op.cu index a1388f20dc5203867134b952d82fea7f1c87337f..7ca5ba3289b26f9b01774b1ab0e85f075c4cfc90 100644 --- a/paddle/fluid/operators/cumsum_op.cu +++ b/paddle/fluid/operators/cumsum_op.cu @@ -13,11 +13,334 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/cum_op.h" +#include "paddle/fluid/platform/gpu_launch_param_config.h" -namespace ops = paddle::operators; -using CUDA = paddle::platform::CUDADeviceContext; +using Tensor = paddle::framework::Tensor; +using LoDTensor = paddle::framework::LoDTensor; + +namespace paddle { +namespace operators { + +template +__global__ void OuterScan(const T* in, T* out, int inner_dim_size, + int outer_dim_size, int scan_dim_size, bool exclusive, + bool reverse) { + int id = blockIdx.y * blockDim.x + threadIdx.x; + + for (int outer_index = blockIdx.x; outer_index < outer_dim_size; + outer_index += gridDim.x) { + for (int inner_index = blockIdx.y * blockDim.x + threadIdx.x; + inner_index < inner_dim_size; inner_index += gridDim.y * blockDim.x) { + int scan_index_init = 0; + int forward_direction = 1; + int src_index = + outer_index * scan_dim_size * inner_dim_size + inner_index; + int dst_index = + outer_index * scan_dim_size * inner_dim_size + inner_index; + if (reverse) { + src_index = src_index + (scan_dim_size - 1) * inner_dim_size; + dst_index = dst_index + (scan_dim_size - 1) * inner_dim_size; + forward_direction = -1; + } + if (exclusive) { + scan_index_init = 1; + out[dst_index] = 0; + dst_index = dst_index + (forward_direction * inner_dim_size); + } + T acc = 0; + + for (int scan_index = scan_index_init; scan_index < scan_dim_size; + ++scan_index) { + acc = in[src_index] + acc; + out[dst_index] = acc; + src_index += (forward_direction * inner_dim_size); + dst_index += (forward_direction * inner_dim_size); + } + } + } +} + +// inclusive scan +template +__global__ void InnerMostDimInclusiveScan(const T* in, T* out, + int inner_dim_size, + int outer_dim_size, int scan_dim_size, + bool reverse) { + __shared__ T share_data[num_threads_y][num_threads_x * 2]; + T* share_row = share_data[threadIdx.y]; + int forward_direction = 1; + if (reverse) forward_direction = -1; + + for (int block_row = blockIdx.x * blockDim.y; block_row < outer_dim_size; + block_row += blockDim.y * gridDim.x) { + int row = block_row + threadIdx.y; + T acc = 0; + const T* row_src = in + row * scan_dim_size; + T* row_dst = out + row * scan_dim_size; + int block_col = 0; + bool loop_condition = (block_col < scan_dim_size); + if (reverse) { + loop_condition = (block_col >= 0); + block_col = scan_dim_size - 1; + } + while (loop_condition) { + // Load data into share memory(two value per thread) + int col1 = block_col + threadIdx.x * forward_direction; + int col2 = block_col + (num_threads_x + threadIdx.x) * forward_direction; + if (row < outer_dim_size) { + if (col1 < scan_dim_size && col1 >= 0) { + share_row[threadIdx.x] = row_src[col1]; + } else { + share_row[threadIdx.x] = 0; + } + + if (col2 < scan_dim_size && col2 >= 0) { + share_row[num_threads_x + threadIdx.x] = row_src[col2]; + } else { + share_row[num_threads_x + threadIdx.x] = 0; + } + + // Add the previous block acc to the result + if (threadIdx.x == 0) { + share_row[0] = share_row[0] + acc; + } + } + __syncthreads(); + + // Up-Sweep + for (unsigned s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) { + if (row < outer_dim_size && threadIdx.x < s) { + unsigned offset = (2 * threadIdx.x + 1) * d - 1; + share_row[offset + d] = share_row[offset] + share_row[offset + d]; + } + __syncthreads(); + } + // Down-Sweep + for (unsigned s = 2, d = blockDim.x / 2; d >= 1; s <<= 1, d >>= 1) { + if (row < outer_dim_size && threadIdx.x < s - 1) { + unsigned offset = 2 * (threadIdx.x + 1) * d - 1; + share_row[offset + d] = share_row[offset] + share_row[offset + d]; + } + __syncthreads(); + } + + // Write to the output + if (row < outer_dim_size) { + if (col1 < scan_dim_size && col1 >= 0) + row_dst[col1] = share_row[threadIdx.x]; + if (col2 < scan_dim_size && col2 >= 0) + row_dst[col2] = share_row[num_threads_x + threadIdx.x]; + } + acc = share_row[2 * num_threads_x - 1]; + __syncthreads(); + block_col += 2 * num_threads_x * forward_direction; + if (reverse) + loop_condition = (block_col >= 0); + else + loop_condition = (block_col < scan_dim_size); + } + } +} + +// exclusive block scan and store block sum for large scan +template +__global__ void InnerMostDimExclusiveScan(const T* in, T* out, T* sum_data, + int inner_dim_size, + int outer_dim_size, int scan_dim_size, + int two_power, bool reverse) { + // https://stackoverflow.com/questions/27570552/templated-cuda-kernel-with-dynamic-shared-memory + extern __shared__ __align__(sizeof(T)) unsigned char raw_tmp[]; + T* share_tmp = reinterpret_cast(raw_tmp); + int thread_id = threadIdx.x; + int block_id = blockIdx.x; + int block_scan_size = blockDim.x * 2; + int remain = scan_dim_size % (2 * blockDim.x); + if (block_id == gridDim.x - 1 && remain != 0) block_scan_size = remain; + int col1 = thread_id; + int col2 = thread_id + (block_scan_size) / 2; + int index1 = blockIdx.y * (scan_dim_size) + block_id * blockDim.x * 2 + col1; + int index2 = blockIdx.y * (scan_dim_size) + block_id * blockDim.x * 2 + col2; + if (reverse) { + index1 = blockIdx.y * (scan_dim_size) + scan_dim_size - 1 - + (block_id * blockDim.x * 2 + col1); + index2 = blockIdx.y * (scan_dim_size) + scan_dim_size - 1 - + (block_id * blockDim.x * 2 + col2); + } + int sum_index = blockIdx.y * gridDim.x + block_id; + if (thread_id < block_scan_size) { + share_tmp[col1 + (col1 >> 5)] = in[index1]; + share_tmp[col2 + (col2 >> 5)] = in[index2]; + } else { + share_tmp[col1 + (col1 >> 5)] = 0; + share_tmp[col2 + (col2 >> 5)] = 0; + } + + // Up-Sweep + int offset = 1; + for (int d = (two_power / 2); d > 0; d >>= 1) { + __syncthreads(); + if (thread_id < d) { + int tmp_index1 = offset * (2 * thread_id + 1) - 1; + int tmp_index2 = offset * (2 * thread_id + 2) - 1; + tmp_index1 = tmp_index1 + (tmp_index1 >> 5); + tmp_index2 = tmp_index2 + (tmp_index2 >> 5); + + share_tmp[tmp_index2] += share_tmp[tmp_index1]; + } + offset *= 2; + } + __syncthreads(); + + if (thread_id == 0) { + int tmp_index = (two_power - 1) + ((two_power - 1) >> 5); + sum_data[sum_index] = share_tmp[tmp_index]; + share_tmp[tmp_index] = 0; + } -REGISTER_OP_CUDA_KERNEL(cumsum, ops::CumKernel>, - ops::CumKernel>, - ops::CumKernel>, - ops::CumKernel>); + // Down Sweep + for (int d = 1; d < two_power; d *= 2) { + offset >>= 1; + __syncthreads(); + if (thread_id < d) { + int tmp_index1 = offset * (2 * thread_id + 1) - 1; + int tmp_index2 = offset * (2 * thread_id + 2) - 1; + tmp_index1 = tmp_index1 + (tmp_index1 >> 5); + tmp_index2 = tmp_index2 + (tmp_index2 >> 5); + + T tmp = share_tmp[tmp_index1]; + share_tmp[tmp_index1] = share_tmp[tmp_index2]; + share_tmp[tmp_index2] += tmp; + } + } + + __syncthreads(); + + if (col1 < block_scan_size) out[index1] = share_tmp[col1 + (col1 >> 5)]; + if (col2 < block_scan_size) out[index2] = share_tmp[col2 + (col2 >> 5)]; +} + +// for large scan_dim_size array we need to add for correct result +template +__global__ void AddBlockScan(T* result, T* sum, int size, int scan_dim_size, + int sum_size, bool reverse) { + int idx = threadIdx.x + blockDim.x * (blockIdx.x + blockIdx.y * gridDim.x); + int block_id_start = blockIdx.y * sum_size; + int block_id_end = blockIdx.x + blockIdx.y * sum_size; + int block_id = blockIdx.x; + int thread_id = threadIdx.x; + + int col = block_id * blockDim.x + thread_id + size; + int index = blockIdx.y * (scan_dim_size) + col; + if (reverse) { + index = blockIdx.y * (scan_dim_size) + scan_dim_size - 1 - col; + } + + if (col >= scan_dim_size || col < 0) return; + for (int i = block_id_start; i <= block_id_end; i++) { + result[index] += sum[i]; + } +} + +template +class CumCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in = context.Input("X"); + auto* out = context.Output("Out"); + + int axis = context.Attr("axis"); + bool exclusive = context.Attr("exclusive"); + bool reverse = context.Attr("reverse"); + auto in_dims = in->dims(); + auto size = in->numel(); + + if (axis == -1) { + axis = in_dims.size() - 1; + } + PADDLE_ENFORCE_LT( + axis, in_dims.size(), + platform::errors::InvalidArgument("axis(%d) should be less than the " + "dimension(%d) of the input tensor.", + axis, in_dims.size())); + + int scan_dim_size = in_dims[axis]; + bool optimize_condition = (axis == (in_dims.size() - 1)) ? true : false; + int outer_dim_size = 1; + int inner_dim_size = 1; + // treat all dim index < axis as outer_dim_size + for (size_t i = 0; i < axis; i++) { + outer_dim_size *= in_dims[i]; + } + // treat all dim index > axis as innner_dim_size + for (size_t i = axis + 1; i < in_dims.size(); i++) { + inner_dim_size *= in_dims[i]; + } + + T* out_data = out->mutable_data(context.GetPlace()); + const T* in_data = in->data(); + + auto& dev_ctx = context.template device_context(); + if (optimize_condition) { + auto nextPowerOfTwo = [](int x) -> int { + int ret = 1; + while (ret < x) ret = ret * 2; + return ret; + }; + if (exclusive) { + int element_per_block = nextPowerOfTwo(scan_dim_size) / 2; + if (element_per_block > 512 || element_per_block < 32) { + element_per_block = 64; + } + int two_power = element_per_block * 2; + dim3 block(element_per_block); + dim3 grid(((scan_dim_size + 1) / 2 + block.x - 1) / block.x, + outer_dim_size); + int offset_size = (element_per_block * 2) >> 5; + int share_mem_size = (element_per_block * 2 + offset_size) * sizeof(T); + Tensor scan_sum; + paddle::framework::DDim dims{ + ((scan_dim_size + 1) / 2 + block.x - 1) / block.x, outer_dim_size}; + scan_sum.Resize(dims); + T* sum_data = scan_sum.mutable_data(context.GetPlace()); + InnerMostDimExclusiveScan< + T><<>>( + in_data, out_data, sum_data, inner_dim_size, outer_dim_size, + scan_dim_size, two_power, reverse); + // for large scan array we need to do add for correct result + int element_size = element_per_block * 2; + if (scan_dim_size > element_size) { + dim3 sum_block(element_per_block * 2); + dim3 sum_grid((scan_dim_size - element_size + block.x - 1) / block.x, + outer_dim_size); + int sum_size = ((scan_dim_size + 1) / 2 + block.x - 1) / block.x; + AddBlockScan<<>>( + out_data, sum_data, element_size, scan_dim_size, sum_size, + reverse); + } + + } else { + dim3 block(32, 16); + dim3 grid((outer_dim_size + block.y - 1) / block.y); + InnerMostDimInclusiveScan<<>>( + in_data, out_data, inner_dim_size, outer_dim_size, scan_dim_size, + reverse); + } + } else { + dim3 block(std::min(512, inner_dim_size)); + dim3 grid(outer_dim_size, (inner_dim_size + block.x - 1) / block.x); + OuterScan<<>>( + in_data, out_data, inner_dim_size, outer_dim_size, scan_dim_size, + exclusive, reverse); + } + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + cumsum, ops::CumCUDAKernel, + ops::CumCUDAKernel, + ops::CumCUDAKernel, + ops::CumCUDAKernel); diff --git a/python/paddle/fluid/tests/unittests/test_cumsum_op.py b/python/paddle/fluid/tests/unittests/test_cumsum_op.py index dc023df4ff0782e8362accb046ace5f333bcee71..a1a80bfdb549fe509171d4ed3d320547aa5aec51 100644 --- a/python/paddle/fluid/tests/unittests/test_cumsum_op.py +++ b/python/paddle/fluid/tests/unittests/test_cumsum_op.py @@ -108,24 +108,108 @@ class TestSumOp7(OpTest): self.check_grad(['X'], 'Out') -class TestSumOp8(OpTest): +class TestSumOpExclusive1(OpTest): def setUp(self): self.op_type = "cumsum" self.attrs = {'axis': 2, "exclusive": True} - a = np.random.random((5, 6, 4)).astype("float64") + a = np.random.random((4, 5, 65)).astype("float64") self.inputs = {'X': a} self.outputs = { 'Out': np.concatenate( (np.zeros( - (5, 6, 1), dtype=np.float64), a[:, :, :-1].cumsum(axis=2)), + (4, 5, 1), dtype=np.float64), a[:, :, :-1].cumsum(axis=2)), axis=2) } def test_check_output(self): self.check_output() - def test_check_grad(self): - self.check_grad(['X'], 'Out') + +class TestSumOpExclusive2(OpTest): + def setUp(self): + self.op_type = "cumsum" + self.attrs = {'axis': 2, "exclusive": True} + a = np.random.random((1, 1, 888)).astype("float64") + self.inputs = {'X': a} + self.outputs = { + 'Out': np.concatenate( + (np.zeros( + (1, 1, 1), dtype=np.float64), a[:, :, :-1].cumsum(axis=2)), + axis=2) + } + + def test_check_output(self): + self.check_output() + + +class TestSumOpExclusive3(OpTest): + def setUp(self): + self.op_type = "cumsum" + self.attrs = {'axis': 2, "exclusive": True} + a = np.random.random((4, 5, 888)).astype("float32") + self.inputs = {'X': a} + self.outputs = { + 'Out': np.concatenate( + (np.zeros( + (4, 5, 1), dtype=np.float64), a[:, :, :-1].cumsum(axis=2)), + axis=2) + } + + def test_check_output(self): + self.check_output() + + +class TestSumOpExclusive4(OpTest): + def setUp(self): + self.op_type = "cumsum" + self.attrs = {'axis': 2, "exclusive": True} + a = np.random.random((1, 1, 3049)).astype("float64") + self.inputs = {'X': a} + self.outputs = { + 'Out': np.concatenate( + (np.zeros( + (1, 1, 1), dtype=np.float64), a[:, :, :-1].cumsum(axis=2)), + axis=2) + } + + def test_check_output(self): + self.check_output() + + +class TestSumOpExclusive5(OpTest): + def setUp(self): + self.op_type = "cumsum" + self.attrs = {'axis': 2, "exclusive": True} + a = np.random.random((4, 5, 3096)).astype("float64") + self.inputs = {'X': a} + self.outputs = { + 'Out': np.concatenate( + (np.zeros( + (4, 5, 1), dtype=np.float64), a[:, :, :-1].cumsum(axis=2)), + axis=2) + } + + def test_check_output(self): + self.check_output() + + +class TestSumOpReverseExclusive(OpTest): + def setUp(self): + self.op_type = "cumsum" + self.attrs = {'axis': 2, 'reverse': True, "exclusive": True} + a = np.random.random((4, 5, 6)).astype("float64") + self.inputs = {'X': a} + a = np.flip(a, axis=2) + self.outputs = { + 'Out': np.concatenate( + (np.flip( + a[:, :, :-1].cumsum(axis=2), axis=2), np.zeros( + (4, 5, 1), dtype=np.float64)), + axis=2) + } + + def test_check_output(self): + self.check_output() class BadInputTest(unittest.TestCase): @@ -133,7 +217,7 @@ class BadInputTest(unittest.TestCase): with fluid.program_guard(fluid.Program()): def test_bad_x(): - data = [1, 2, 3] + data = [1, 2, 4] result = fluid.layers.cumsum(data, axis=0) self.assertRaises(TypeError, test_bad_x)