diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 7155d5ef2febc20aaa684c04a7a59f781857c9e5..5125072ddd3d59066e4616e51597dbf0203a6efb 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -44,7 +44,7 @@ struct ExecutorPrepareContext { ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id) : prog_(prog), block_id_(block_id) {} - const framework::ProgramDesc& prog_; + const framework::ProgramDesc prog_; size_t block_id_; std::vector> ops_; }; diff --git a/paddle/fluid/operators/sequence_expand_op.cu b/paddle/fluid/operators/sequence_expand_op.cu index cae0a6928455b425a78294b7507b1003ad198dbc..bf453ca7e8ea39a3f32d10d6d5527fdc9c180ad4 100644 --- a/paddle/fluid/operators/sequence_expand_op.cu +++ b/paddle/fluid/operators/sequence_expand_op.cu @@ -13,7 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. */ #define EIGEN_USE_GPU +#include +#include #include "paddle/fluid/operators/sequence_expand_op.h" +#include "paddle/fluid/platform/cuda_helper.h" namespace paddle { namespace operators { @@ -22,47 +25,71 @@ using LoDTensor = framework::LoDTensor; template __global__ void sequence_expand_kernel(const T* x_data, T* out_data, - const size_t* lod, size_t lod_size, - size_t element_len) { - int tid_x = blockIdx.x * blockDim.x + threadIdx.x; - for (; tid_x < static_cast(lod_size - 1); - tid_x += blockDim.x * gridDim.x) { - int scale = lod[tid_x + 1] - lod[tid_x]; - int tid_y = blockIdx.y * blockDim.y + threadIdx.y; - for (; tid_y < scale; tid_y += blockDim.y * gridDim.y) { - int tid_z = blockIdx.z * blockDim.z + threadIdx.z; - int item_start = tid_x / element_len; - for (; tid_z < element_len; tid_z += blockDim.z * gridDim.z) { - out_data[item_start * scale + tid_z] = x_data[item_start + tid_z]; - } + const size_t* lod, + const size_t* out_offset, + size_t lod_size, size_t element_len, + size_t x_size) { + int bid_x = blockIdx.x; + if (bid_x > lod_size) return; + int repeats = lod[bid_x]; + int offset = out_offset[bid_x]; + for (int tid_y = threadIdx.y; tid_y < repeats; tid_y += blockDim.y) { + for (int tid_x = threadIdx.x; tid_x < element_len; tid_x += blockDim.x) { + out_data[(offset + tid_y) * element_len + tid_x] = + x_data[bid_x * element_len + tid_x]; } } } template __global__ void sequence_expand_grad_kernel(const T* dout_data, T* dx_data, - const size_t* lod, size_t lod_size, - size_t element_len, - size_t dout_size) { + const size_t* lod, + const size_t* out_offset, + size_t lod_size, size_t element_len, + size_t dout_size, size_t dx_size) { + // reduce visit memory time. + // dout_shm = [0 - dout_size-1], dx_shm = [dout_size-1, dout_size + dx_size-1] + if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && + threadIdx.y == 0) { + printf("lod_size=%ld, element_size=%ld, dout_size=%ld, dx_size=%ld\n", + lod_size, element_len, dout_size, dx_size); + } extern __shared__ T shm[]; - int tid_x = blockIdx.x * blockDim.x + threadIdx.x; - for (; tid_x < static_cast(lod_size - 1); - tid_x += blockDim.x * gridDim.x) { - int scale = lod[tid_x + 1] - lod[tid_x]; - int tid_y = blockIdx.y * blockDim.y + threadIdx.y; - for (; tid_y < scale; tid_y += blockDim.y * gridDim.y) { - int tid_z = blockIdx.z * blockDim.z + threadIdx.z; - int item_start = tid_x / element_len; - for (; tid_z < element_len; tid_z += blockDim.z * gridDim.z) { - shm[item_start + tid_z] += dout_data[item_start * scale + tid_z]; - } + T* dout_shm = shm; + T* dx_shm = &shm[dout_size]; + + // int idx = threadIdx.x + blockIdx.x * blockDim.x; + for (int idx = 0; idx < dout_size; ++idx) { + if (idx < dx_size) { + dx_shm[idx] = 0.0; + } + if (idx < dout_size) { + dout_shm[idx] = dout_data[idx]; + } + } + + int bid_x = blockIdx.x; + if (bid_x > lod_size) return; + int repeats = lod[bid_x]; + int offset = out_offset[bid_x]; + if (threadIdx.x == 0) { + printf("repeats=%d, offset=%ld\n", repeats, offset); + } + for (int tid_y = threadIdx.y; tid_y < repeats; tid_y += blockDim.y) { + for (int tid_x = threadIdx.x; tid_x < element_len; tid_x += blockDim.x) { + T val = dout_shm[(offset + tid_y) * element_len + tid_x]; + platform::CudaAtomicAdd(&dx_shm[bid_x * element_len + tid_x], val); + int dx_idx = bid_x * element_len + tid_x; + int dout_idx = (offset + tid_y) * element_len + tid_x; + printf("dx_idx=%d, dout_idx=%d, dx_data=%f, dout_data=%f, val=%f \n", + dx_idx, dout_idx, dx_shm[dx_idx], dout_shm[dout_idx], val); } } - // synchronize before write to dx __syncthreads(); - for (int idx = blockDim.x * blockIdx.x + threadIdx.x; - idx < static_cast(dout_size); idx += blockDim.x * gridDim.x) { - dx_data[idx] = shm[idx]; + // copy shared memory back to dx + for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < dx_size; + idx += blockDim.x * gridDim.x) { + dx_data[idx] = dx_shm[idx]; } } @@ -72,15 +99,20 @@ struct SequenceExpandFunctor { const LoDTensor& x, LoDTensor* out) { auto x_dims = x.dims(); size_t element_len = framework::product(x_dims) / x_dims[0]; - T* out_data = out->mutable_data(context.GetPlace()); - auto out_starts = out->lod().back(); + auto lod = out->lod().back(); + framework::Vector out_lod; + for (size_t i = 0; i < lod.size() - 1; ++i) { + out_lod.push_back(lod[i + 1] - lod[i]); + } - dim3 block_size(16, 32, element_len); - dim3 grid_size(10, 10); + int thread_x = std::max(static_cast(element_len), 32); + int block_x = static_cast(out_lod.size()); + dim3 block_size(thread_x, 1024 / thread_x); + dim3 grid_size(block_x, 1); sequence_expand_kernel<<>>( x.data(), out->mutable_data(context.GetPlace()), - out_starts.CUDAData(context.GetPlace()), out_starts.size(), - element_len); + out_lod.CUDAData(context.GetPlace()), lod.CUDAData(context.GetPlace()), + out_lod.size(), element_len, framework::product(x_dims)); } }; @@ -91,16 +123,24 @@ struct SequenceExpandGradFunctor { const LoDTensor& dout, LoDTensor* dx) { auto x_dims = x.dims(); size_t element_len = framework::product(x_dims) / x_dims[0]; - auto out_starts = out.lod().back(); + auto lod = out.lod().back(); + framework::Vector out_lod; + for (size_t i = 0; i < lod.size() - 1; ++i) { + out_lod.push_back(lod[i + 1] - lod[i]); + } + size_t dout_size = framework::product(dout.dims()); + size_t dx_size = framework::product(dx->dims()); - dim3 block_size(16, 32, element_len); - dim3 grid_size(10, 10); - size_t out_size = framework::product(dx->dims()); - sequence_expand_grad_kernel<<(element_len), 32); + dim3 block_size(thread_x, 1024 / thread_x); + int block_x = static_cast(out_lod.size()); + dim3 grid_size(block_x, 1); + sequence_expand_grad_kernel<<>>( dout.data(), dx->mutable_data(context.GetPlace()), - out_starts.CUDAData(context.GetPlace()), out_starts.size(), element_len, - out_size); + out_lod.CUDAData(context.GetPlace()), lod.CUDAData(context.GetPlace()), + out_lod.size(), element_len, dout_size, dx_size); } }; diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 8393f7827b1c7d361ebea72f2cfc6033268772f0..555f188abb9b87accd0383ca4a1a29718eea967c 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -362,6 +362,9 @@ class OpTest(unittest.TestCase): for a, b, name in itertools.izip(numeric_grads, analytic_grads, names): abs_a = np.abs(a) abs_a[abs_a < 1e-3] = 1 + print("actual", a) + print("*****") + print("expected", b) diff_mat = np.abs(a - b) / abs_a max_diff = np.max(diff_mat) diff --git a/python/paddle/fluid/tests/unittests/test_sequence_expand.py b/python/paddle/fluid/tests/unittests/test_sequence_expand.py index 957fa5d2c4a795cfd01047c1b7845674e4c1d549..f984127b4d64f104002e9be55170f73ae4350582 100644 --- a/python/paddle/fluid/tests/unittests/test_sequence_expand.py +++ b/python/paddle/fluid/tests/unittests/test_sequence_expand.py @@ -19,8 +19,14 @@ from op_test import OpTest class TestSequenceExpand(OpTest): def set_data(self): - x_data = np.random.uniform(0.1, 1, [3, 1]).astype('float32') - y_data = np.random.uniform(0.1, 1, [8, 1]).astype('float32') + x = [i / 10.0 for i in range(3)] + y = [i / 10.0 for i in range(8)] + x_data = np.array(x).reshape(3, 1).astype('float32') + y_data = np.array(y).reshape(8, 1).astype('float32') + print(x_data) + print(y_data) + # x_data = np.random.uniform(0.1, 1, [3, 1]).astype('float32') + # y_data = np.random.uniform(0.1, 1, [8, 1]).astype('float32') y_lod = [[0, 1, 4, 8]] self.inputs = {'X': x_data, 'Y': (y_data, y_lod)} @@ -45,47 +51,43 @@ class TestSequenceExpand(OpTest): def test_check_grad(self): self.check_grad(["X"], "Out") - -class TestSequenceExpandCase1(TestSequenceExpand): - def set_data(self): - x_data = np.random.uniform(0.1, 1, [5, 1]).astype('float32') - x_lod = [[0, 2, 5]] - y_data = np.random.uniform(0.1, 1, [13, 1]).astype('float32') - y_lod = [[0, 2, 5], [0, 2, 4, 7, 10, 13]] - self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)} - - -class TestSequenceExpandCase2(TestSequenceExpand): - def set_data(self): - x_data = np.random.uniform(0.1, 1, [1, 2, 2]).astype('float32') - x_lod = [[0, 1]] - y_data = np.random.uniform(0.1, 1, [2, 2, 2]).astype('float32') - y_lod = [[0, 2]] - self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)} - - -class TestSequenceExpandCase3(TestSequenceExpand): - def set_data(self): - x_data = np.random.uniform(0.1, 1, [4, 1]).astype('float32') - x_lod = [[0, 1, 2, 3, 4]] - y_data = np.random.uniform(0.1, 1, [6, 1]).astype('float32') - y_lod = [[0, 2, 4, 4, 6]] - self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)} - - -class TestSequenceExpandCase4(TestSequenceExpand): - def set_data(self): - x_data = np.array( - [0.1, 0.3, 0.2, 0.15, 0.25, 0.2, 0.15, 0.25, 0.1, 0.3]).reshape( - [2, 5]).astype('float32') - x_lod = [[ - 0, - 1, - 2, - ]] - y_data = np.random.uniform(0.1, 1, [2, 1]).astype('float32') - y_lod = [[0, 1, 2], [0, 1, 2]] - self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)} + # class TestSequenceExpandCase1(TestSequenceExpand): + # def set_data(self): + # x_data = np.random.uniform(0.1, 1, [5, 1]).astype('float32') + # x_lod = [[0, 2, 5]] + # y_data = np.random.uniform(0.1, 1, [13, 1]).astype('float32') + # y_lod = [[0, 2, 5], [0, 2, 4, 7, 10, 13]] + # self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)} + + # class TestSequenceExpandCase2(TestSequenceExpand): + # def set_data(self): + # x_data = np.random.uniform(0.1, 1, [1, 2, 2]).astype('float32') + # x_lod = [[0, 1]] + # y_data = np.random.uniform(0.1, 1, [2, 2, 2]).astype('float32') + # y_lod = [[0, 2]] + # self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)} + + # class TestSequenceExpandCase3(TestSequenceExpand): + # def set_data(self): + # x_data = np.random.uniform(0.1, 1, [4, 1]).astype('float32') + # x_lod = [[0, 1, 2, 3, 4]] + # y_data = np.random.uniform(0.1, 1, [6, 1]).astype('float32') + # y_lod = [[0, 2, 4, 4, 6]] + # self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)} + + # class TestSequenceExpandCase4(TestSequenceExpand): + # def set_data(self): + # x_data = np.array( + # [0.1, 0.3, 0.2, 0.15, 0.25, 0.2, 0.15, 0.25, 0.1, 0.3]).reshape( + # [2, 5]).astype('float32') + # x_lod = [[ + # 0, + # 1, + # 2, + # ]] + # y_data = np.random.uniform(0.1, 1, [2, 1]).astype('float32') + # y_lod = [[0, 1, 2], [0, 1, 2]] + # self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)} if __name__ == '__main__':