diff --git a/paddle/fluid/operators/sequence_expand_op.cc b/paddle/fluid/operators/sequence_expand_op.cc
index 786fe63e7580ce16b946d5049a490eed2c3c6ced..ae52849162ae4d78cc69ddbb98f58059f55683cb 100644
--- a/paddle/fluid/operators/sequence_expand_op.cc
+++ b/paddle/fluid/operators/sequence_expand_op.cc
@@ -84,12 +84,11 @@ class SequenceExpandOp : public framework::OperatorWithKernel {
         }
       }
       out_dims[0] = out_first_dim;
-      ctx->SetOutputDim("Out", out_dims);
     } else {
       out_dims[0] = -1;
-      ctx->SetOutputDim("Out", out_dims);
-      ctx->ShareLoD("X", /*->*/ "Out");
     }
+    ctx->SetOutputDim("Out", out_dims);
+    ctx->ShareLoD("X", /*->*/ "Out");
   }
 };
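
Note on the InferShape change above: both branches now fall through to a single SetOutputDim/ShareLoD pair, so the only branch-specific work left is choosing the first dimension (-1 at compile time, the fully expanded length at runtime). The sketch below is illustrative only and not part of the patch; it shows how that expanded first dimension follows from the two offset-style LoD rows used throughout this change:

    # Hypothetical helper, illustrative sketch only.
    def expanded_first_dim(x_lod, ref_lod):
        # x_lod and ref_lod are offset-style LoD rows, e.g. [0, 1, 4, 8].
        total = 0
        for i in range(len(ref_lod) - 1):
            repeat_num = ref_lod[i + 1] - ref_lod[i]  # times sequence i is repeated
            x_seq_len = x_lod[i + 1] - x_lod[i]       # rows in source sequence i
            total += repeat_num * x_seq_len
        return total

    # expanded_first_dim([0, 1, 2, 3], [0, 1, 4, 8]) == 8
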
diff --git a/paddle/fluid/operators/sequence_expand_op.cu b/paddle/fluid/operators/sequence_expand_op.cu
index 743e3bbc297c5da574c25ded69ef8d27325944a5..1bd73426522bcac608e54f979a5c049b2d2fa62b 100644
--- a/paddle/fluid/operators/sequence_expand_op.cu
+++ b/paddle/fluid/operators/sequence_expand_op.cu
@@ -24,123 +24,128 @@ namespace operators {
 
 using LoDTensor = framework::LoDTensor;
 
 template <typename T>
-__global__ void sequence_expand_kernel(const T* x_data, T* out_data,
-                                       const size_t* lod,
-                                       const size_t* out_offset,
-                                       size_t lod_size, size_t element_len,
-                                       size_t x_size) {
-  int bid_x = blockIdx.x;
-  if (bid_x > lod_size) return;
-  int repeats = lod[bid_x];
-  int offset = out_offset[bid_x];
-  for (int tid_y = threadIdx.y; tid_y < repeats; tid_y += blockDim.y) {
-    for (int tid_x = threadIdx.x; tid_x < element_len; tid_x += blockDim.x) {
-      out_data[(offset + tid_y) * element_len + tid_x] =
-          x_data[bid_x * element_len + tid_x];
-    }
-  }
-}
+__global__ void sequence_expand_kernel(const T* x_data, const size_t* x_lod,
+                                       const size_t* ref_lod,
+                                       const size_t lod_size,
+                                       /* default=1,
+                                          the instance length*/
+                                       const int x_item_length, T* out_data) {
+  constexpr int N = 1024;
+  __shared__ int mem[N];
+  int offset = 0;
+  for (int i = 0; i < lod_size; ++i) {
+    mem[i] = offset;
+    if (i < lod_size - 1) {
+      offset += (ref_lod[i + 1] - ref_lod[i]) * (x_lod[i + 1] - x_lod[i]);
+    }
+  }
+  __syncthreads();
+
+  int bid = blockIdx.x;
+  if (bid >= lod_size - 1) return;
+
+  int x_item_count = x_lod[bid + 1] - x_lod[bid];
+  int repeats = ref_lod[bid + 1] - ref_lod[bid];
+  int out_offset = mem[bid];
+  int x_offset = x_lod[bid];
+  for (int tid_z = threadIdx.z; tid_z < repeats; tid_z += blockDim.z) {
+    for (int tid_y = threadIdx.y; tid_y < x_item_count; tid_y += blockDim.y) {
+      for (int tid_x = threadIdx.x; tid_x < x_item_length;
+           tid_x += blockDim.x) {
+        out_data[(out_offset + tid_z * x_item_count + tid_y) * x_item_length +
+                 tid_x] = x_data[(x_offset + tid_y) * x_item_length + tid_x];
+      }
+    }
+  }
+}
 
 template <typename T>
-__global__ void sequence_expand_grad_kernel(const T* dout_data, T* dx_data,
-                                            const size_t* lod,
-                                            const size_t* out_offset,
-                                            size_t lod_size, size_t element_len,
-                                            size_t dout_size, size_t dx_size) {
-  // reduce visit memory time.
-  // dout_shm = [0 - dout_size-1], dx_shm = [dout_size-1, dout_size + dx_size-1]
-  if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 &&
-      threadIdx.y == 0) {
-    printf("lod_size=%ld, element_size=%ld, dout_size=%ld, dx_size=%ld\n",
-           lod_size, element_len, dout_size, dx_size);
-  }
-  extern __shared__ T shm[];
-  T* dout_shm = shm;
-  T* dx_shm = &shm[dout_size];
-
-  // int idx = threadIdx.x + blockIdx.x * blockDim.x;
-  for (int idx = 0; idx < dout_size; ++idx) {
-    if (idx < dx_size) {
-      dx_shm[idx] = 0.0;
-    }
-    if (idx < dout_size) {
-      dout_shm[idx] = dout_data[idx];
-    }
-  }
-
-  int bid_x = blockIdx.x;
-  if (bid_x > lod_size) return;
-  int repeats = lod[bid_x];
-  int offset = out_offset[bid_x];
-  if (threadIdx.x == 0) {
-    printf("repeats=%d, offset=%ld\n", repeats, offset);
-  }
-  for (int tid_y = threadIdx.y; tid_y < repeats; tid_y += blockDim.y) {
-    for (int tid_x = threadIdx.x; tid_x < element_len; tid_x += blockDim.x) {
-      T val = dout_shm[(offset + tid_y) * element_len + tid_x];
-      platform::CudaAtomicAdd(&dx_shm[bid_x * element_len + tid_x], val);
-      int dx_idx = bid_x * element_len + tid_x;
-      int dout_idx = (offset + tid_y) * element_len + tid_x;
-      printf("dx_idx=%d, dout_idx=%d, dx_data=%f, dout_data=%f, val=%f \n",
-             dx_idx, dout_idx, dx_shm[dx_idx], dout_shm[dout_idx], val);
-    }
-  }
-  __syncthreads();
-  // copy shared memory back to dx
-  for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < dx_size;
-       idx += blockDim.x * gridDim.x) {
-    dx_data[idx] = dx_shm[idx];
-  }
-}
+__global__ void sequence_expand_grad_kernel(const T* dout_data,
+                                            const size_t* ref_lod,
+                                            const size_t* dx_lod,
+                                            const size_t lod_size,
+                                            /* default=1,
+                                               the instance length*/
+                                            const int x_item_length,
+                                            T* dx_data) {
+  // TODO(dzhwinter) : too many atomicAdd
+  // use shared memory to reduce memory visits
+  constexpr int N = 1024;
+  __shared__ int mem[N];
+  int offset = 0;
+  for (int i = 0; i < lod_size; ++i) {
+    mem[i] = offset;
+    if (i < lod_size - 1) {
+      offset += (ref_lod[i + 1] - ref_lod[i]) * (dx_lod[i + 1] - dx_lod[i]);
+    }
+  }
+  __syncthreads();
+
+  int bid = blockIdx.x;
+  if (bid >= lod_size - 1) return;
+  int x_item_count = dx_lod[bid + 1] - dx_lod[bid];
+  int repeats = ref_lod[bid + 1] - ref_lod[bid];
+  int out_offset = mem[bid];
+  int x_offset = dx_lod[bid];
+
+  for (int tid_z = threadIdx.z; tid_z < repeats; tid_z += blockDim.z) {
+    for (int tid_y = threadIdx.y; tid_y < x_item_count; tid_y += blockDim.y) {
+      for (int tid_x = threadIdx.x; tid_x < x_item_length;
+           tid_x += blockDim.x) {
+        platform::CudaAtomicAdd(
+            &dx_data[(x_offset + tid_y) * x_item_length + tid_x],
+            dout_data[(out_offset + tid_z * x_item_count + tid_y) *
+                          x_item_length +
+                      tid_x]);
+      }
+    }
+  }
+}
 
 template <typename T>
 struct SequenceExpandFunctor<platform::CUDADeviceContext, T> {
-  void operator()(const platform::CUDADeviceContext& context,
-                  const LoDTensor& x, LoDTensor* out) {
-    auto x_dims = x.dims();
-    size_t element_len = framework::product(x_dims) / x_dims[0];
-    auto lod = out->lod().back();
-    framework::Vector<size_t> out_lod;
-    for (size_t i = 0; i < lod.size() - 1; ++i) {
-      out_lod.push_back(lod[i + 1] - lod[i]);
-    }
-
-    int thread_x = std::max(static_cast<int>(element_len), 32);
-    int block_x = static_cast<int>(out_lod.size());
-    dim3 block_size(thread_x, 1024 / thread_x);
+  void operator()(
+      const platform::CUDADeviceContext& context, const LoDTensor& x,
+      const framework::Vector<size_t>& x_lod,   /*expand source lod*/
+      const framework::Vector<size_t>& ref_lod, /*expand referenced lod*/
+      LoDTensor* out) {
+    int x_item_length = 1;
+    x_item_length = x.numel() / x.dims()[0];
+    VLOG(0) << "x_item_length" << x_item_length;
+    int thread_x = std::max(static_cast<int>(ref_lod.size()), 32);
+    int thread_y = std::max(1024 / thread_x, 16);
+    int thread_z = std::min(1024 / thread_x / thread_y, 16);
+    int block_x = static_cast<int>(ref_lod.size());
+    dim3 block_size(thread_x, thread_y, thread_z);
     dim3 grid_size(block_x, 1);
+
     sequence_expand_kernel<<<grid_size, block_size, 0, context.stream()>>>(
-        x.data<T>(), out->mutable_data<T>(context.GetPlace()),
-        out_lod.CUDAData(context.GetPlace()), lod.CUDAData(context.GetPlace()),
-        out_lod.size(), element_len, framework::product(x_dims));
+        x.data<T>(), x_lod.CUDAData(context.GetPlace()),
+        ref_lod.CUDAData(context.GetPlace()), x_lod.size(), x_item_length,
+        out->mutable_data<T>(context.GetPlace()));
   }
 };
 
 template <typename T>
 struct SequenceExpandGradFunctor<platform::CUDADeviceContext, T> {
   void operator()(const platform::CUDADeviceContext& context,
-                  const LoDTensor& x, const LoDTensor& out,
-                  const LoDTensor& dout, LoDTensor* dx) {
-    auto x_dims = x.dims();
-    size_t element_len = framework::product(x_dims) / x_dims[0];
-    auto lod = out.lod().back();
-    framework::Vector<size_t> out_lod;
-    for (size_t i = 0; i < lod.size() - 1; ++i) {
-      out_lod.push_back(lod[i + 1] - lod[i]);
-    }
-    size_t dout_size = framework::product(dout.dims());
-    size_t dx_size = framework::product(dx->dims());
-
-    int thread_x = std::max(static_cast<int>(element_len), 32);
-    dim3 block_size(thread_x, 1024 / thread_x);
-    int block_x = static_cast<int>(out_lod.size());
+                  const LoDTensor& dout,
+                  const framework::Vector<size_t>& x_lod,   /*expand source lod*/
+                  const framework::Vector<size_t>& ref_lod, /*expand based lod*/
+                  LoDTensor* dx) {
+    int x_item_length = 1;
+    x_item_length = framework::product(dx->dims()) / dx->dims()[0];
+
+    int thread_x = std::max(static_cast<int>(ref_lod.size()), 32);
+    int thread_y = std::max(1024 / thread_x, 16);
+    int thread_z = std::min(1024 / thread_x / thread_y, 16);
+    int block_x = static_cast<int>(ref_lod.size());
+    dim3 block_size(thread_x, thread_y, thread_z);
     dim3 grid_size(block_x, 1);
-    sequence_expand_grad_kernel<<<grid_size, block_size,
-                                  (dout_size + dx_size) * sizeof(T),
-                                  context.stream()>>>(
-        dout.data<T>(), dx->mutable_data<T>(context.GetPlace()),
-        out_lod.CUDAData(context.GetPlace()), lod.CUDAData(context.GetPlace()),
-        out_lod.size(), element_len, dout_size, dx_size);
+    sequence_expand_grad_kernel<<<grid_size, block_size, 0, context.stream()>>>(
+        dout.data<T>(), ref_lod.CUDAData(context.GetPlace()),
+        x_lod.CUDAData(context.GetPlace()), ref_lod.size(), x_item_length,
+        dx->mutable_data<T>(context.GetPlace()));
   }
 };
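
In the rewritten CUDA kernels above, each block handles one source sequence: the block first rebuilds the per-sequence output offsets in static shared memory (mem[i] is the running sum of repeat_num * x_seq_len over the preceding sequences), then threadIdx.z walks the repeats, threadIdx.y the rows of the sub-sequence, and threadIdx.x the instance length; the backward kernel uses the same indexing and scatters gradients back into dx with CudaAtomicAdd. A NumPy sketch of the forward mapping, for reference only (names are illustrative, not part of the patch):

    import numpy as np

    def sequence_expand_ref(x, x_lod, ref_lod):
        # x_lod / ref_lod are offset-style LoD rows, e.g. [0, 1, 4, 8].
        out_rows = []
        for i in range(len(ref_lod) - 1):
            repeat_num = ref_lod[i + 1] - ref_lod[i]
            x_sub = x[x_lod[i]:x_lod[i + 1], :]
            for _ in range(repeat_num):
                out_rows.append(x_sub)  # tile the whole sub-sequence
        return np.vstack(out_rows) if out_rows else x[:0]
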
diff --git a/paddle/fluid/operators/sequence_expand_op.h b/paddle/fluid/operators/sequence_expand_op.h
index 5cab367988dd26c6e0af77292bc6195c62997bfa..c55c3e215abdf91e3fe9b1bdb23823747865839a 100644
--- a/paddle/fluid/operators/sequence_expand_op.h
+++ b/paddle/fluid/operators/sequence_expand_op.h
@@ -13,8 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
-#include <numeric>  // std::itoa
+#include <numeric>  // std::iota
+#include
+#include
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/memory/memcpy.h"
 #include "paddle/fluid/operators/math/math_function.h"
@@ -29,40 +31,42 @@ using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
 
 template <typename DeviceContext, typename T>
 struct SequenceExpandFunctor {
-  void operator()(const DeviceContext& ctx, const LoDTensor& x, LoDTensor* out);
+  void operator()(
+      const DeviceContext& ctx, const LoDTensor& x,
+      const framework::Vector<size_t>& x_lod,   /*expand source lod*/
+      const framework::Vector<size_t>& ref_lod, /*expand referenced lod*/
+      LoDTensor* out);
 };
 
 template <typename DeviceContext, typename T>
 struct SequenceExpandGradFunctor {
-  void operator()(const DeviceContext& ctx, const LoDTensor& x,
-                  const LoDTensor& out, const LoDTensor& dout, LoDTensor* dx);
+  void operator()(
+      const DeviceContext& ctx, const LoDTensor& dout,
+      const framework::Vector<size_t>& x_lod,   /*expand source lod*/
+      const framework::Vector<size_t>& ref_lod, /*expand referenced lod*/
+      LoDTensor* dx);
 };
 
 template <typename T>
 struct SequenceExpandFunctor<platform::CPUDeviceContext, T> {
-  void operator()(const platform::CPUDeviceContext& context, const LoDTensor& x,
-                  LoDTensor* out) {
-    auto& out_lod = out->lod()[0];
-    framework::Vector<size_t> x_lod;
-    if (x.lod() == 1) {
-      x_lod = x.lod()[0];
-    } else {
-      x_lod.reserve(out_lod.size());
-      std::itoa(x_lod.begin(), x_lod.end(), 0);  // fill 0 ~ out_lod.size()-1
-    }
+  void operator()(
+      const platform::CPUDeviceContext& context, const LoDTensor& x,
+      const framework::Vector<size_t>& x_lod,   /*expand source lod*/
+      const framework::Vector<size_t>& ref_lod, /*expand referenced lod*/
+      LoDTensor* out) {
     int out_offset = 0;
     auto& eigen_place = *context.eigen_device();
-    for (size_t i = 1; i < out_lod.size(); ++i) {
-      int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1];
+    for (size_t i = 1; i < ref_lod.size(); ++i) {
+      int repeat_num = ref_lod[i] - ref_lod[i - 1];
       int x_start = x_lod[i - 1];
       int x_end = x_lod[i];
       int x_seq_len = x_end - x_start;
       if (repeat_num > 0) {
-        auto x_sub_tensor = x->Slice(x_start, x_end);
+        auto x_sub_tensor = x.Slice(x_start, x_end);
         x_sub_tensor.Resize({1, x_sub_tensor.numel()});
         int out_start = out_offset;
-        if (x_lod.size() == 1) {
-          out_start = out_lod[0][out_offset];
+        if (out->lod().size() == 1) {
+          out_start = out->lod()[0][out_offset];
         }
         auto out_sub_tensor =
             out->Slice(out_start, out_start + x_seq_len * repeat_num);
@@ -71,6 +75,7 @@ struct SequenceExpandFunctor {
             EigenMatrix<T>::From(x_sub_tensor)
                 .broadcast(Eigen::array<int, 2>({{repeat_num, 1}}));
       }
+      out_offset += repeat_num;
     }
   }
 };
@@ -96,13 +101,10 @@ class SequenceExpandKernel : public framework::OpKernel<T> {
       return;
     }
 
-    auto& out_lod = *out->mutable_lod();
     // x lod level is at most 1.
-    if (x_lod.size() == 0) {
-      out_lod = y_lod[ref_level];
-    } else if (x_lod.size() == 1) {
-      out_lod.resize(1);
-      out_lod[0] = {0};
+    framework::Vector<size_t> out_lod;
+    if (x_lod.size() == 1) {
+      out_lod.push_back(0);
       int out_offset = 0;
       for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
         int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1];
@@ -110,14 +112,25 @@ class SequenceExpandKernel : public framework::OpKernel<T> {
         int x_end = x_lod[0][i];
         int x_seq_len = x_end - x_start;
         for (int j = 0; j < repeat_num; ++j) {
-          out_lod[0].push_back(out_lod[0].back() + x_seq_len);
+          out_lod.push_back(out_lod.back() + x_seq_len);
           out_offset++;
         }
       }
+      // write lod to out if x has lod
+      auto& ref_lod = *out->mutable_lod();
+      ref_lod[0] = out_lod;
+    }
+    framework::Vector<size_t> ref_x_lod;
+    if (x->lod().size() == 1) {
+      ref_x_lod = x->lod()[0];
+    } else {
+      // x does not have lod, use a fake x lod, level = 0
+      ref_x_lod.resize(x->dims()[0] + 1);
+      std::iota(ref_x_lod.begin(), ref_x_lod.end(), 0);
     }
     SequenceExpandFunctor<DeviceContext, T> functor;
-    functor(context.template device_context<DeviceContext>(), *x, out);
+    functor(context.template device_context<DeviceContext>(), *x, ref_x_lod,
+            y_lod[ref_level], out);
   }
 };
@@ -135,32 +148,29 @@ class SequenceExpandKernel : public framework::OpKernel<T> {
  * */
 template <typename T>
 struct SequenceExpandGradFunctor<platform::CPUDeviceContext, T> {
-  void operator()(const platform::CPUDeviceContext& context, const LoDTensor& x,
-                  const LoDTensor& out, const LoDTensor& dout, LoDTensor* dx) {
-    auto& dev_ctx = context.template device_context<platform::CPUDeviceContext>();
-
-    math::SetConstant<platform::CPUDeviceContext, T> set_zero;
-    set_zero(dev_ctx, g_x, static_cast<T>(0));
-
-    int g_out_offset = 0;
-    for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
-      int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1];
+  void operator()(
+      const platform::CPUDeviceContext& context, const LoDTensor& dout,
+      const framework::Vector<size_t>& x_lod,   /*expand source lod*/
+      const framework::Vector<size_t>& ref_lod, /*expand referenced lod*/
+      LoDTensor* dx) {
+    math::SetConstant<platform::CPUDeviceContext, T> set_zero;
+    set_zero(context, dx, static_cast<T>(0));
+
+    int dout_offset = 0;
+    for (size_t i = 1; i < ref_lod.size(); ++i) {
+      int repeat_num = ref_lod[i] - ref_lod[i - 1];
       if (repeat_num > 0) {
-        int x_start = i - 1;
-        int x_end = i;
-        if (x_lod.size() == 1) {
-          x_start = x_lod[0][i - 1];
-          x_end = x_lod[0][i];
-        }
+        int x_start = x_lod[i - 1];
+        int x_end = x_lod[i];
         int x_seq_len = x_end - x_start;
-        auto g_x_sub = g_x->Slice(x_start, x_end);
-        g_x_sub.Resize(flatten_to_1d(g_x_sub.dims()));
-        int g_out_end = g_out_offset + repeat_num * x_seq_len;
-        auto g_out_sub = g_out->Slice(g_out_offset, g_out_end);
-        g_out_sub.Resize({repeat_num, g_x_sub.dims()[0]});
-        math::ColwiseSum<platform::CPUDeviceContext, T> col_sum;
-        col_sum(dev_ctx, g_out_sub, &g_x_sub);
-        g_out_offset += repeat_num * x_seq_len;
+        auto dx_sub = dx->Slice(x_start, x_end);
+        dx_sub.Resize(flatten_to_1d(dx_sub.dims()));
+        int dout_end = dout_offset + repeat_num * x_seq_len;
+        auto dout_sub = dout.Slice(dout_offset, dout_end);
+        dout_sub.Resize({repeat_num, dx_sub.dims()[0]});
+        math::ColwiseSum<platform::CPUDeviceContext, T> col_sum;
+        col_sum(context, dout_sub, &dx_sub);
+        dout_offset += repeat_num * x_seq_len;
       }
     }
   }
@@ -179,20 +189,26 @@ class SequenceExpandGradKernel : public framework::OpKernel<T> {
     g_x->mutable_data<T>(context.GetPlace());
     g_x->set_lod(x->lod());
 
-    auto& x_lod = x->lod();
     auto& y_lod = y->lod();
-
     if (ref_level == -1) ref_level = y_lod.size() - 1;
-
     // just copy the gradient
     if (y_lod[ref_level].size() <= 1) {
      framework::TensorCopy(*g_out, context.GetPlace(), g_x);
       return;
     }
+    framework::Vector<size_t> ref_x_lod;
+    framework::Vector<size_t> ref_lod = y_lod[ref_level];
+    if (x->lod().size() == 1) {
+      ref_x_lod = x->lod()[0];
+    } else {
+      // x does not have lod, use a fake x lod, level = 0
+      ref_x_lod.resize(x->dims()[0] + 1);
+      std::iota(ref_x_lod.begin(), ref_x_lod.end(), 0);
+    }
     SequenceExpandGradFunctor<DeviceContext, T> functor;
-    functor(context.template device_context<DeviceContext>(), *x, *y, *g_out,
-            g_x);
+    functor(context.template device_context<DeviceContext>(), *g_out, ref_x_lod,
+            ref_lod, g_x);
   }
 };
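
On the CPU path above, the functors now receive the source LoD and the reference LoD explicitly; when X carries no LoD the kernels hand in a fake per-row LoD built with std::iota, and the backward functor collapses each block of repeat_num copies in dout back into the matching rows of dx with a column-wise sum. A NumPy sketch of that reduction for a 2-D gradient, illustrative only and not part of the patch:

    import numpy as np

    def sequence_expand_grad_ref(dout, x_lod, ref_lod):
        # dout has one row per expanded instance; x_lod / ref_lod are offset rows.
        dx = np.zeros((x_lod[-1], dout.shape[1]), dtype=dout.dtype)
        dout_offset = 0
        for i in range(len(ref_lod) - 1):
            repeat_num = ref_lod[i + 1] - ref_lod[i]
            x_seq_len = x_lod[i + 1] - x_lod[i]
            if repeat_num > 0:
                block = dout[dout_offset:dout_offset + repeat_num * x_seq_len, :]
                # sum the repeat_num tiled copies back onto the source rows
                dx[x_lod[i]:x_lod[i + 1], :] = block.reshape(
                    repeat_num, x_seq_len, -1).sum(axis=0)
                dout_offset += repeat_num * x_seq_len
        return dx
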
diff --git a/python/paddle/fluid/tests/unittests/test_sequence_expand.py b/python/paddle/fluid/tests/unittests/test_sequence_expand.py
index d1cebc4ea2df9af39282ffcaecdc4707b0e24460..4c8ec1426c6e103498af544ea5928ec630707d46 100644
--- a/python/paddle/fluid/tests/unittests/test_sequence_expand.py
+++ b/python/paddle/fluid/tests/unittests/test_sequence_expand.py
@@ -19,14 +19,8 @@ from op_test import OpTest
 
 class TestSequenceExpand(OpTest):
     def set_data(self):
-        x = [i / 10.0 for i in range(3)]
-        y = [i / 10.0 for i in range(8)]
-        x_data = np.array(x).reshape(3, 1).astype('float32')
-        y_data = np.array(y).reshape(8, 1).astype('float32')
-        print(x_data)
-        print(y_data)
-        # x_data = np.random.uniform(0.1, 1, [3, 1]).astype('float32')
-        # y_data = np.random.uniform(0.1, 1, [8, 1]).astype('float32')
+        x_data = np.random.uniform(0.1, 1, [3, 1]).astype('float32')
+        y_data = np.random.uniform(0.1, 1, [8, 1]).astype('float32')
         y_lod = [[0, 1, 4, 8]]
         self.inputs = {'X': x_data, 'Y': (y_data, y_lod)}
 
@@ -53,8 +47,10 @@ class TestSequenceExpand(OpTest):
             x_len = x_idx[i] - x_idx[i - 1]
             if repeat_num > 0:
                 x_sub = x_data[x_idx[i - 1]:x_idx[i], :]
-                x_sub = np.repeat(x_sub, repeat_num, axis=0)
-                out = np.vstack((out, x_sub))
+                stacked_x_sub = x_sub
+                for r in range(repeat_num - 1):
+                    stacked_x_sub = np.vstack((stacked_x_sub, x_sub))
+                out = np.vstack((out, stacked_x_sub))
                 if x_lod is not None:
                     for j in xrange(repeat_num):
                         out_lod[0].append(out_lod[0][-1] + x_len)
@@ -107,11 +103,11 @@ class TestSequenceExpandCase3(TestSequenceExpand):
 
 class TestSequenceExpandCase4(TestSequenceExpand):
     def set_data(self):
-        data = [0.1, 0.3, 0.2, 0.15, 0.25, 0.2, 0.15, 0.25, 0.1, 0.3]
+        data = np.random.uniform(0.1, 1, [5 * 2, 1])
         x_data = np.array(data).reshape([5, 2]).astype('float32')
         x_lod = [[0, 2, 5]]
-        y_data = np.random.uniform(0.1, 1, [2, 1]).astype('float32')
-        y_lod = [[0, 1, 2], [0, 1, 2]]
+        y_data = np.random.uniform(0.1, 1, [3, 1]).astype('float32')
+        y_lod = [[0, 1, 3], [0, 1, 3]]
         self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)}
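
A note on the test change above: np.repeat(x_sub, repeat_num, axis=0) repeats each row in place, while sequence_expand repeats the whole sub-sequence, so the two only agree when a sub-sequence has a single row or is repeated at most once. The vstack loop (equivalently np.tile) matches the operator for the updated multi-row Case4. Illustration only:

    import numpy as np

    x_sub = np.array([[1.0], [2.0]])
    np.repeat(x_sub, 2, axis=0)   # -> [[1], [1], [2], [2]]  (row-wise repeat)
    np.tile(x_sub, (2, 1))        # -> [[1], [2], [1], [2]]  (whole block tiled)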