diff --git a/paddle/fluid/operators/sequence_expand_op.cc b/paddle/fluid/operators/sequence_expand_op.cc index 786fe63e7580ce16b946d5049a490eed2c3c6ced..ae52849162ae4d78cc69ddbb98f58059f55683cb 100644 --- a/paddle/fluid/operators/sequence_expand_op.cc +++ b/paddle/fluid/operators/sequence_expand_op.cc @@ -84,12 +84,11 @@ class SequenceExpandOp : public framework::OperatorWithKernel { } } out_dims[0] = out_first_dim; - ctx->SetOutputDim("Out", out_dims); } else { out_dims[0] = -1; - ctx->SetOutputDim("Out", out_dims); - ctx->ShareLoD("X", /*->*/ "Out"); } + ctx->SetOutputDim("Out", out_dims); + ctx->ShareLoD("X", /*->*/ "Out"); } }; diff --git a/paddle/fluid/operators/sequence_expand_op.cu b/paddle/fluid/operators/sequence_expand_op.cu index bb51bb2902eea797de3449fcb6c8b52b4f0e7fbf..c00765e5d59af068e5682b39ebace5f3d7a62250 100644 --- a/paddle/fluid/operators/sequence_expand_op.cu +++ b/paddle/fluid/operators/sequence_expand_op.cu @@ -12,8 +12,135 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#define EIGEN_USE_GPU +#include #include "paddle/fluid/operators/sequence_expand_op.h" +#include "paddle/fluid/platform/cuda_helper.h" + +namespace paddle { +namespace operators { + +using LoDTensor = framework::LoDTensor; + +template +__global__ void sequence_expand_kernel(const T* x_data, const size_t* x_lod, + const size_t* ref_lod, + const size_t* offset, + const size_t lod_size, + /* default=1, + the instance length*/ + const int x_item_length, T* out_data) { + int bid = blockIdx.x; + if (bid >= lod_size - 1) return; + + int x_item_count = x_lod[bid + 1] - x_lod[bid]; + int repeats = ref_lod[bid + 1] - ref_lod[bid]; + int out_offset = static_cast(offset[bid]); + int x_offset = x_lod[bid]; + for (int tid_z = threadIdx.z; tid_z < repeats; tid_z += blockDim.z) { + for (int tid_y = threadIdx.y; tid_y < x_item_count; tid_y += blockDim.y) { + for (int tid_x = threadIdx.x; tid_x < x_item_length; + tid_x += blockDim.x) { + out_data[(out_offset + tid_z * x_item_count + tid_y) * x_item_length + + tid_x] = x_data[(x_offset + tid_y) * x_item_length + tid_x]; + } + } + } +} + +template +__global__ void sequence_expand_grad_kernel( + const T* dout_data, const size_t* ref_lod, const size_t* dx_lod, + const size_t* offset, const size_t lod_size, + /* default=1, + the instance length*/ + const int x_item_length, T* dx_data) { + int bid = blockIdx.x; + if (bid >= lod_size - 1) return; + int x_item_count = dx_lod[bid + 1] - dx_lod[bid]; + int repeats = ref_lod[bid + 1] - ref_lod[bid]; + int out_offset = static_cast(offset[bid]); + int x_offset = dx_lod[bid]; + + for (int tid_z = threadIdx.z; tid_z < repeats; tid_z += blockDim.z) { + for (int tid_y = threadIdx.y; tid_y < x_item_count; tid_y += blockDim.y) { + for (int tid_x = threadIdx.x; tid_x < x_item_length; + tid_x += blockDim.x) { + platform::CudaAtomicAdd( + &dx_data[(x_offset + tid_y) * x_item_length + tid_x], + dout_data[(out_offset + tid_z * x_item_count + tid_y) * + x_item_length + + tid_x]); + } + } + } +} + +void GetOutputOffset(const framework::Vector& x_lod, + const framework::Vector& ref_lod, + framework::Vector* out_offset) { + size_t offset = 0; + int lod_size = static_cast(x_lod.size()); + for (int i = 0; i < static_cast(x_lod.size()); ++i) { + (*out_offset)[i] = offset; + if (i < lod_size - 1) { + offset += (ref_lod[i + 1] - ref_lod[i]) * (x_lod[i + 1] - x_lod[i]); + } + } +} + +template +struct SequenceExpandFunctor { + void operator()( + const platform::CUDADeviceContext& context, const LoDTensor& x, + const framework::Vector& x_lod, /*expand source lod*/ + const framework::Vector& ref_lod, /*expand referenced lod*/ + LoDTensor* out) { + int x_item_length = x.numel() / x.dims()[0]; + framework::Vector out_offset(x_lod.size()); + GetOutputOffset(x_lod, ref_lod, &out_offset); + + int thread_x = std::min(32, std::max(static_cast(ref_lod.size()), 16)); + int thread_y = 16; + int thread_z = 1024 / thread_x / thread_y; + int block_x = static_cast(ref_lod.size()); + dim3 block_size(thread_x, thread_y, thread_z); + dim3 grid_size(block_x, 1); + + sequence_expand_kernel<<>>( + x.data(), x_lod.CUDAData(context.GetPlace()), + ref_lod.CUDAData(context.GetPlace()), + out_offset.CUDAData(context.GetPlace()), x_lod.size(), x_item_length, + out->mutable_data(context.GetPlace())); + } +}; + +template +struct SequenceExpandGradFunctor { + void operator()(const platform::CUDADeviceContext& context, + const LoDTensor& dout, + const framework::Vector& x_lod, /*expand source lod*/ + const framework::Vector& ref_lod, /*expand based lod*/ + LoDTensor* dx) { + int x_item_length = framework::product(dx->dims()) / dx->dims()[0]; + framework::Vector out_offset(x_lod.size()); + GetOutputOffset(x_lod, ref_lod, &out_offset); + + int thread_x = std::min(32, std::max(static_cast(ref_lod.size()), 16)); + int thread_y = 16; + int thread_z = 1024 / thread_x / thread_y; + int block_x = static_cast(ref_lod.size()); + dim3 block_size(thread_x, thread_y, thread_z); + dim3 grid_size(block_x, 1); + sequence_expand_grad_kernel<<>>( + dout.data(), ref_lod.CUDAData(context.GetPlace()), + x_lod.CUDAData(context.GetPlace()), + out_offset.CUDAData(context.GetPlace()), ref_lod.size(), x_item_length, + dx->mutable_data(context.GetPlace())); + } +}; + +} // namespace operators +} // namespace paddle namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( diff --git a/paddle/fluid/operators/sequence_expand_op.h b/paddle/fluid/operators/sequence_expand_op.h index db7d8bd6821fabd9714a160970558291ec47197f..d62c387c3eebf9df0ab532f4e891da006f239468 100644 --- a/paddle/fluid/operators/sequence_expand_op.h +++ b/paddle/fluid/operators/sequence_expand_op.h @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include // std::iota #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/memory/memcpy.h" @@ -26,6 +27,57 @@ template using EigenMatrix = framework::EigenMatrix; +template +struct SequenceExpandFunctor { + void operator()( + const DeviceContext& ctx, const LoDTensor& x, + const framework::Vector& x_lod, /*expand source lod*/ + const framework::Vector& ref_lod, /*expand referenced lod*/ + LoDTensor* out); +}; + +template +struct SequenceExpandGradFunctor { + void operator()( + const DeviceContext& ctx, const LoDTensor& dout, + const framework::Vector& x_lod, /*expand source lod*/ + const framework::Vector& ref_lod, /*expand referenced lod*/ + LoDTensor* dx); +}; + +template +struct SequenceExpandFunctor { + void operator()( + const platform::CPUDeviceContext& context, const LoDTensor& x, + const framework::Vector& x_lod, /*expand source lod*/ + const framework::Vector& ref_lod, /*expand referenced lod*/ + LoDTensor* out) { + int out_offset = 0; + auto& eigen_place = *context.eigen_device(); + for (size_t i = 1; i < ref_lod.size(); ++i) { + int repeat_num = ref_lod[i] - ref_lod[i - 1]; + int x_start = x_lod[i - 1]; + int x_end = x_lod[i]; + int x_seq_len = x_end - x_start; + if (repeat_num > 0) { + auto x_sub_tensor = x.Slice(x_start, x_end); + x_sub_tensor.Resize({1, x_sub_tensor.numel()}); + int out_start = out_offset; + if (out->lod().size() == 1) { + out_start = out->lod()[0][out_offset]; + } + auto out_sub_tensor = + out->Slice(out_start, out_start + x_seq_len * repeat_num); + out_sub_tensor.Resize({repeat_num, x_sub_tensor.dims()[1]}); + EigenMatrix::From(out_sub_tensor).device(eigen_place) = + EigenMatrix::From(x_sub_tensor) + .broadcast(Eigen::array({{repeat_num, 1}})); + } + out_offset += repeat_num; + } + } +}; + template class SequenceExpandKernel : public framework::OpKernel { public: @@ -47,45 +99,36 @@ class SequenceExpandKernel : public framework::OpKernel { return; } - auto& out_lod = *out->mutable_lod(); + // x lod level is at most 1. + framework::Vector out_lod; if (x_lod.size() == 1) { - out_lod.resize(1); - out_lod[0] = {0}; - } - - int out_offset = 0; - auto& eigen_place = - *context.template device_context().eigen_device(); - for (size_t i = 1; i < y_lod[ref_level].size(); ++i) { - int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1]; - int x_start = i - 1; - int x_end = i; - if (x_lod.size() == 1) { - x_start = x_lod[0][i - 1]; - x_end = x_lod[0][i]; - } - int x_seq_len = x_end - x_start; - if (repeat_num > 0) { - auto x_sub_tensor = x->Slice(x_start, x_end); - x_sub_tensor.Resize({1, x_sub_tensor.numel()}); - int out_start = out_offset; - if (x_lod.size() == 1) { - out_start = out_lod[0][out_offset]; - } - auto out_sub_tensor = - out->Slice(out_start, out_start + x_seq_len * repeat_num); - out_sub_tensor.Resize({repeat_num, x_sub_tensor.dims()[1]}); - EigenMatrix::From(out_sub_tensor).device(eigen_place) = - EigenMatrix::From(x_sub_tensor) - .broadcast(Eigen::array({{repeat_num, 1}})); - } - for (int j = 0; j < repeat_num; ++j) { - if (x_lod.size() == 1) { - out_lod[0].push_back(out_lod[0].back() + x_seq_len); + out_lod.push_back(0); + int out_offset = 0; + for (size_t i = 1; i < y_lod[ref_level].size(); ++i) { + int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1]; + int x_start = x_lod[0][i - 1]; + int x_end = x_lod[0][i]; + int x_seq_len = x_end - x_start; + for (int j = 0; j < repeat_num; ++j) { + out_lod.push_back(out_lod.back() + x_seq_len); + out_offset++; } - out_offset++; } + // write lod to out if x has lod + auto& ref_lod = *out->mutable_lod(); + ref_lod[0] = out_lod; } + framework::Vector ref_x_lod; + if (x->lod().size() == 1) { + ref_x_lod = x->lod()[0]; + } else { + // x_lod doesn't has lod, use fake x lod, level = 0 + ref_x_lod.resize(x->dims()[0] + 1); + std::iota(ref_x_lod.begin(), ref_x_lod.end(), 0); + } + SequenceExpandFunctor functor; + functor(context.template device_context(), *x, ref_x_lod, + y_lod[ref_level], out); } }; @@ -101,6 +144,36 @@ class SequenceExpandKernel : public framework::OpKernel { * Grad(X).lod = Input(X).lod * * */ +template +struct SequenceExpandGradFunctor { + void operator()( + const platform::CPUDeviceContext& context, const LoDTensor& dout, + const framework::Vector& x_lod, /*expand source lod*/ + const framework::Vector& ref_lod, /*expand referenced lod*/ + LoDTensor* dx) { + math::SetConstant set_zero; + set_zero(context, dx, static_cast(0)); + + int dout_offset = 0; + for (size_t i = 1; i < ref_lod.size(); ++i) { + int repeat_num = ref_lod[i] - ref_lod[i - 1]; + if (repeat_num > 0) { + int x_start = x_lod[i - 1]; + int x_end = x_lod[i]; + int x_seq_len = x_end - x_start; + auto dx_sub = dx->Slice(x_start, x_end); + dx_sub.Resize(flatten_to_1d(dx_sub.dims())); + int dout_end = dout_offset + repeat_num * x_seq_len; + auto dout_sub = dout.Slice(dout_offset, dout_end); + dout_sub.Resize({repeat_num, dx_sub.dims()[0]}); + math::ColwiseSum col_sum; + col_sum(context, dout_sub, &dx_sub); + dout_offset += repeat_num * x_seq_len; + } + } + } +}; + template class SequenceExpandGradKernel : public framework::OpKernel { public: @@ -114,43 +187,26 @@ class SequenceExpandGradKernel : public framework::OpKernel { g_x->mutable_data(context.GetPlace()); g_x->set_lod(x->lod()); - auto& x_lod = x->lod(); auto& y_lod = y->lod(); - if (ref_level == -1) ref_level = y_lod.size() - 1; - // just copy the gradient if (y_lod[ref_level].size() <= 1) { framework::TensorCopy(*g_out, context.GetPlace(), g_x); return; } - auto& dev_ctx = context.template device_context(); - - math::SetConstant set_zero; - set_zero(dev_ctx, g_x, static_cast(0)); - - int g_out_offset = 0; - for (size_t i = 1; i < y_lod[ref_level].size(); ++i) { - int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1]; - if (repeat_num > 0) { - int x_start = i - 1; - int x_end = i; - if (x_lod.size() == 1) { - x_start = x_lod[0][i - 1]; - x_end = x_lod[0][i]; - } - int x_seq_len = x_end - x_start; - auto g_x_sub = g_x->Slice(x_start, x_end); - g_x_sub.Resize(flatten_to_1d(g_x_sub.dims())); - int g_out_end = g_out_offset + repeat_num * x_seq_len; - auto g_out_sub = g_out->Slice(g_out_offset, g_out_end); - g_out_sub.Resize({repeat_num, g_x_sub.dims()[0]}); - math::ColwiseSum col_sum; - col_sum(dev_ctx, g_out_sub, &g_x_sub); - g_out_offset += repeat_num * x_seq_len; - } + framework::Vector ref_x_lod; + framework::Vector ref_lod = y_lod[ref_level]; + if (x->lod().size() == 1) { + ref_x_lod = x->lod()[0]; + } else { + // x_lod doesn't has lod, use fake x lod, level = 0 + ref_x_lod.resize(x->dims()[0] + 1); + std::iota(ref_x_lod.begin(), ref_x_lod.end(), 0); } + SequenceExpandGradFunctor functor; + functor(context.template device_context(), *g_out, ref_x_lod, + ref_lod, g_x); } }; diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 3bd24c98a22b5db9833a312f481ed74c3d26f0ad..356c3e64b3d03b520a1bec5b5e0174e1d8ee23e8 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -34,6 +34,8 @@ function(py_test_modules TARGET_NAME) endif() endfunction() +list(REMOVE_ITEM TEST_OPS test_sequence_expand) + # test time consuming OPs in a separate process for expliot parallism list(REMOVE_ITEM TEST_OPS test_parallel_executor) list(REMOVE_ITEM TEST_OPS test_warpctc_op) @@ -70,6 +72,8 @@ else() endforeach(TEST_OP) endif(WITH_FAST_BUNDLE_TEST) +# +py_test_modules(test_sequence_expand MODULES test_sequence_expand) # tests with high overhead py_test_modules(test_parallel_executor MODULES test_parallel_executor) py_test_modules(test_warpctc_op MODULES test_warpctc_op ENVS FLAGS_warpctc_dir=${WARPCTC_LIB_DIR}) diff --git a/python/paddle/fluid/tests/unittests/test_sequence_expand.py b/python/paddle/fluid/tests/unittests/test_sequence_expand.py index 7feb509c4d6f5768552fc2515081f7e68f420967..4c8ec1426c6e103498af544ea5928ec630707d46 100644 --- a/python/paddle/fluid/tests/unittests/test_sequence_expand.py +++ b/python/paddle/fluid/tests/unittests/test_sequence_expand.py @@ -47,8 +47,10 @@ class TestSequenceExpand(OpTest): x_len = x_idx[i] - x_idx[i - 1] if repeat_num > 0: x_sub = x_data[x_idx[i - 1]:x_idx[i], :] - x_sub = np.repeat(x_sub, repeat_num, axis=0) - out = np.vstack((out, x_sub)) + stacked_x_sub = x_sub + for r in range(repeat_num - 1): + stacked_x_sub = np.vstack((stacked_x_sub, x_sub)) + out = np.vstack((out, stacked_x_sub)) if x_lod is not None: for j in xrange(repeat_num): out_lod[0].append(out_lod[0][-1] + x_len) @@ -101,11 +103,11 @@ class TestSequenceExpandCase3(TestSequenceExpand): class TestSequenceExpandCase4(TestSequenceExpand): def set_data(self): - data = [0.1, 0.3, 0.2, 0.15, 0.25, 0.2, 0.15, 0.25, 0.1, 0.3] + data = np.random.uniform(0.1, 1, [5 * 2, 1]) x_data = np.array(data).reshape([5, 2]).astype('float32') x_lod = [[0, 2, 5]] - y_data = np.random.uniform(0.1, 1, [2, 1]).astype('float32') - y_lod = [[0, 1, 2], [0, 1, 2]] + y_data = np.random.uniform(0.1, 1, [3, 1]).astype('float32') + y_lod = [[0, 1, 3], [0, 1, 3]] self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)}