diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index a72a0675c54d49855f556268f55c4b3fee9f9543..054c333b8b5a97f7475fdc771c3695b4f8c29a4d 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -167,7 +167,7 @@ paddle.fluid.layers.nce (ArgSpec(args=['input', 'label', 'num_total_classes', 's paddle.fluid.layers.sampled_softmax_with_cross_entropy (ArgSpec(args=['logits', 'label', 'num_samples', 'num_true', 'remove_accidental_hits', 'use_customized_samples', 'customized_samples', 'customized_probabilities', 'seed'], varargs=None, keywords=None, defaults=(1, True, False, None, None, 0)), ('document', 'd4435a63d34203339831ee6a86ef9242')) paddle.fluid.layers.hsigmoid (ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name', 'path_table', 'path_code', 'is_custom', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, False, False)), ('document', 'b83e7dfa81059b39bb137922dc914f50')) paddle.fluid.layers.beam_search (ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'is_accumulated', 'name', 'return_parent_idx'], varargs=None, keywords=None, defaults=(0, True, None, False)), ('document', '1270395ce97a4e1b556104abbb14f096')) -paddle.fluid.layers.row_conv (ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None)), ('document', '17485788fffe4e2d36dc58c2ac8d174e')) +paddle.fluid.layers.row_conv (ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None)), ('document', '1d8a1c8b686b55631ba1b77805e4eacf')) paddle.fluid.layers.multiplex (ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None), ('document', '2c4d1ae83da6ed35e3b36ba1b3b51d23')) paddle.fluid.layers.layer_norm (ArgSpec(args=['input', 'scale', 'shift', 'begin_norm_axis', 'epsilon', 'param_attr', 'bias_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(True, True, 1, 1e-05, None, None, None, None)), ('document', '79797f827d89ae72c77960e9696883a9')) paddle.fluid.layers.group_norm (ArgSpec(args=['input', 'groups', 'epsilon', 'param_attr', 'bias_attr', 'act', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(1e-05, None, None, None, 'NCHW', None)), ('document', '96b24820e8863d6044d5be4eaaddb9fd')) diff --git a/paddle/fluid/operators/row_conv_op.cc b/paddle/fluid/operators/row_conv_op.cc index 7e9611679ba9a988f40973aaa37f04bcfa48f1ad..1645c47e9660faa4d211c1fb05167a582e0fbc46 100644 --- a/paddle/fluid/operators/row_conv_op.cc +++ b/paddle/fluid/operators/row_conv_op.cc @@ -1,5 +1,4 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -43,13 +42,7 @@ class RowConvOp : public framework::OperatorWithKernel { auto x_dims = ctx->GetInputDim("X"); auto filter_dims = ctx->GetInputDim("Filter"); - PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2."); PADDLE_ENFORCE_EQ(filter_dims.size(), 2, "Input(Y)'s rank should be 2."); - if (ctx->IsRuntime() || (x_dims[1] > 0 && filter_dims[1] > 0)) { - PADDLE_ENFORCE_EQ( - x_dims[1], filter_dims[1], - "The 2nd dimension of Input(X) and Input(Filter) should be same."); - } ctx->SetOutputDim("Out", x_dims); ctx->ShareLoD("X", "Out"); @@ -84,11 +77,12 @@ class RowConvOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", - "the input(X) is a LodTensor, which supports " + "the input(X) is a LodTensor or tensor, LodTensor(X) supports " "variable time-length input sequences. The underlying tensor " "in this LoDTensor is a matrix with shape (T x N), where T " "is the total time steps in this mini-batch and N is the input " - "data dimension."); + "data dimension. the shape of Tensor input(X) has shape " + "(B x T x N), B is batch size;"); AddInput("Filter", "the input(Filter) is a learnable parameter. It " "is a 2-D tensor with shape (future_context x N), where, " @@ -152,8 +146,26 @@ class RowConvKernel out->mutable_data(context.GetPlace()); - auto batch_indices = x->lod()[0]; - auto input_dim = x->dims()[1]; // 'in' is of size T x N + bool is_tensor = x->lod().empty(); + int batch_size = 0; + if (is_tensor) { + batch_size = x->dims()[0]; + } else { + batch_size = x->lod()[0].size() - 1; + } + framework::Vector batch_indices(batch_size + 1); + int input_dim = 0; + int timesteps = 0; + if (is_tensor) { + for (int i = 0; i < batch_size + 1; i++) { + batch_indices[i] = i; + } + input_dim = x->dims()[2]; + timesteps = x->dims()[1]; + } else { + batch_indices = x->lod()[0]; + input_dim = x->dims()[1]; + } size_t num_sequence = batch_indices.size() - 1; auto future_context = filter->dims()[0]; @@ -162,11 +174,23 @@ class RowConvKernel for (size_t i = 0; i < num_sequence; i++) { int start = static_cast(batch_indices[i]); int end = static_cast(batch_indices[i + 1]); - int current_timesteps = end - start; + int current_timesteps = 0; + if (is_tensor) { + current_timesteps = timesteps; + } else { + current_timesteps = end - start; + } + // int current_timesteps = end - start; Tensor cur_input_sequence = x->Slice(start, end); // Current input sequence + cur_input_sequence = + cur_input_sequence.Resize({current_timesteps, input_dim}); + Tensor cur_output_sequence = out->Slice(start, end); // Current output sequence + cur_output_sequence = + cur_output_sequence.Resize({current_timesteps, input_dim}); + auto cip_seq = EigenMatrix::From(cur_input_sequence); auto cot_seq = EigenMatrix::From(cur_output_sequence); @@ -198,11 +222,30 @@ class RowConvGradKernel auto *dx = context.Output(framework::GradVarName("X")); auto *d_filter = context.Output(framework::GradVarName("Filter")); - auto input_dim = x->dims()[1]; // 'x' is of size T x N - auto batch_indices = x->lod()[0]; + auto &x_lod = x->lod(); + bool is_tensor = x_lod.empty(); + int batch_size = 0; + if (is_tensor) { + batch_size = x->dims()[0]; + } else { + batch_size = x->lod()[0].size() - 1; + } + framework::Vector batch_indices(batch_size + 1); + int timesteps = 0; + int input_dim = 0; + if (is_tensor) { + for (int i = 0; i < batch_size + 1; i++) { + batch_indices[i] = i; + } + input_dim = x->dims()[2]; + timesteps = x->dims()[1]; + } else { + batch_indices = x->lod()[0]; + input_dim = x->dims()[1]; + } + size_t num_sequence = batch_indices.size() - 1; auto future_context = filter->dims()[0]; - if (d_filter) { d_filter->mutable_data(context.GetPlace()); auto dweights = @@ -213,14 +256,19 @@ class RowConvGradKernel int start = static_cast(batch_indices[i]); int end = static_cast(batch_indices[i + 1]); + int current_timesteps = 0; + if (is_tensor) { + current_timesteps = timesteps; + } else { + current_timesteps = end - start; + } Tensor cur_input = x->Slice(start, end); // Current input sequence + cur_input = cur_input.Resize({current_timesteps, input_dim}); Tensor cur_doutput = d_out->Slice(start, end); // Current output grad sequence - + cur_doutput = cur_doutput.Resize({current_timesteps, input_dim}); auto cur_ip = EigenMatrix::From(cur_input); auto cur_dout = EigenMatrix::From(cur_doutput); - int current_timesteps = end - start; - for (int k = 0; k < current_timesteps; k++) { // For different time steps in the same sequence for (int w = 0; (w < future_context) && ((k + w) < current_timesteps); @@ -241,15 +289,23 @@ class RowConvGradKernel int start = static_cast(batch_indices[i]); int end = static_cast(batch_indices[i + 1]); + int current_timesteps = 0; + if (is_tensor) { + current_timesteps = timesteps; + } else { + current_timesteps = end - start; + } + Tensor cur_doutput = d_out->Slice(start, end); // Current output grad sequence + cur_doutput = cur_doutput.Resize({current_timesteps, input_dim}); Tensor cur_dinput = dx->Slice(start, end); // Current input grad sequence + cur_dinput = cur_dinput.Resize({current_timesteps, input_dim}); auto cur_dout = EigenMatrix::From(cur_doutput); auto cur_dip = EigenMatrix::From(cur_dinput); cur_dip.setZero(); - int current_timesteps = end - start; for (int k = 0; k < current_timesteps; k++) { // For different time steps in the same sequence diff --git a/paddle/fluid/operators/row_conv_op.cu b/paddle/fluid/operators/row_conv_op.cu index 9ae80da6550bcef39c07f05e35d4153c24738f09..a712878854298bc2eb372be155e1bd512aba7037 100644 --- a/paddle/fluid/operators/row_conv_op.cu +++ b/paddle/fluid/operators/row_conv_op.cu @@ -1,5 +1,4 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -47,11 +46,11 @@ __global__ void RowConvForwardSharedMemory(const T *in, const T *wt, (d < input_dim) ? wt[thy * input_dim + d] : static_cast(0); } __syncthreads(); - for (size_t i = 0; i < num_sequence; i++) { int start = static_cast(batch_indices[i]); int end = static_cast(batch_indices[i + 1]); int current_timesteps = end - start; + for (int k = thy; k < current_timesteps; k += bly) { T sum = 0; for (int w = 0; (w < future_context) && ((k + w) < current_timesteps); @@ -77,11 +76,11 @@ __global__ void RowConvForward(const T *in, const T *wt, int num_sequence, int thy = threadIdx.y; if (d >= input_dim) return; - for (size_t i = 0; i < num_sequence; i++) { int start = static_cast(batch_indices[i]); int end = static_cast(batch_indices[i + 1]); int current_timesteps = end - start; + for (int k = thy; k < current_timesteps; k += bly) { T sum = 0; for (int w = 0; (w < future_context) && ((k + w) < current_timesteps); @@ -114,10 +113,12 @@ __global__ void RowConvGradInputSharedMemory(const T *dout, const T *wt, } __syncthreads(); + int current_timesteps = 0; for (int i = 0; i < num_sequence; i++) { int start = static_cast(batch_indices[i]); int end = static_cast(batch_indices[i + 1]); - int current_timesteps = end - start; + current_timesteps = end - start; + for (int k = thy; k < current_timesteps; k += bly) { T sum = 0; for (int w = 0; (w < future_context) && ((k - w) >= 0); w++) { @@ -142,10 +143,13 @@ __global__ void RowConvGradInput(const T *dout, const T *wt, int num_sequence, int thy = threadIdx.y; if (d >= input_dim) return; + int current_timesteps = 0; + for (int i = 0; i < num_sequence; i++) { int start = static_cast(batch_indices[i]); int end = static_cast(batch_indices[i + 1]); - int current_timesteps = end - start; + current_timesteps = end - start; + for (int k = thy; k < current_timesteps; k += bly) { T sum = 0; for (int w = 0; (w < future_context) && ((k - w) >= 0); w++) { @@ -175,7 +179,6 @@ __global__ void RowConvGradFilterImproved(const T *in, const T *dout, int xdim_sh_in = block_y; int xdim_sh_dout = block_y; - // int xdim_sh_dfilter = future_context; int ydim_sh_in = block_x; int ydim_sh_dout = block_x + future_context - 1; int ydim_sh_dfilter = block_y; @@ -197,6 +200,7 @@ __global__ void RowConvGradFilterImproved(const T *in, const T *dout, int start = static_cast(batch_indices[i]); int end = static_cast(batch_indices[i + 1]); int current_timesteps = end - start; + int scaled_cur_steps = ((current_timesteps + block_x - 1) / block_x) * block_x; @@ -258,11 +262,11 @@ __global__ void RowConvGradFilter(const T *in, const T *dout, int num_sequence, // NOTE(zcd): temporary solution unsigned mask = 0u; CREATE_SHFL_MASK(mask, true); - for (int i = 0; i < num_sequence; i++) { int start = static_cast(batch_indices[i]); int end = static_cast(batch_indices[i + 1]); int current_timesteps = end - start; + int scaled_cur_steps = ((current_timesteps + block_x - 1) / block_x) * block_x; @@ -310,9 +314,26 @@ class RowConvKernel const T *in = X->data(); const T *weight = Filter->data(); T *out = Out->mutable_data(context.GetPlace()); + bool is_tensor = X->lod().empty(); + int batch_size = 0; + if (is_tensor) { + batch_size = X->dims()[0]; + } else { + batch_size = X->lod()[0].size() - 1; + } + int input_dim = 0; + framework::Vector batch_indices(batch_size + 1); + int timesteps = X->dims()[1]; + if (is_tensor) { + for (int i = 0; i < batch_size + 1; i++) { + batch_indices[i] = i * timesteps; + } + input_dim = X->dims()[2]; + } else { + batch_indices = X->lod()[0]; + input_dim = X->dims()[1]; + } - auto batch_indices = X->lod()[0]; - int input_dim = X->dims()[1]; int num_sequence = batch_indices.size() - 1; int future_context = Filter->dims()[0]; size_t *idx = batch_indices.CUDAMutableData(context.GetPlace()); @@ -348,9 +369,27 @@ class RowConvGradKernel Tensor *dX = context.Output(framework::GradVarName("X")); Tensor *dFilter = context.Output(framework::GradVarName("Filter")); + int batch_size = 0; + bool is_tensor = X->lod().empty(); + if (is_tensor) { + batch_size = X->dims()[0]; + } else { + batch_size = X->lod()[0].size() - 1; + } - auto batch_indices = X->lod()[0]; - int input_dim = X->dims()[1]; + int input_dim = 0; + framework::Vector batch_indices(batch_size + 1); + int timesteps = X->dims()[1]; + if (is_tensor) { + for (int i = 0; i < batch_size + 1; i++) { + batch_indices[i] = i * timesteps; + } + input_dim = X->dims()[2]; + } else { + batch_indices = X->lod()[0]; + input_dim = X->dims()[1]; + } + // int input_dim = X->dims()[1]; int num_sequence = batch_indices.size() - 1; int future_context = Filter->dims()[0]; size_t *idx = batch_indices.CUDAMutableData(context.GetPlace()); diff --git a/python/paddle/fluid/tests/unittests/test_row_conv_op.py b/python/paddle/fluid/tests/unittests/test_row_conv_op.py index 2f13f067ef313685227c7de9a49fae8640ca6b32..301d05260e0ae0852f420565edbffc77c51e1b38 100644 --- a/python/paddle/fluid/tests/unittests/test_row_conv_op.py +++ b/python/paddle/fluid/tests/unittests/test_row_conv_op.py @@ -94,7 +94,7 @@ class TestRowConvOp2(OpTest): self.check_output() #max_relative_error is increased from 0.05 to 0.06 as for higher - #dimensional input, the dX on CPU for some values has max_rel_error + #dimensional input, the dX on CPU for some values has max_rel_error #slightly more than 0.05 def test_check_grad_normal(self): self.check_grad(['X', 'Filter'], 'Out', max_relative_error=0.06) @@ -108,5 +108,52 @@ class TestRowConvOp2(OpTest): ['X'], 'Out', max_relative_error=0.06, no_grad_set=set('Filter')) +def row_conv_foward_Tensor(x, wt): + out = np.zeros_like(x) + num_sequence = x.shape[0] + timesteps = x.shape[1] + context_length = wt.shape[0] + for i in range(num_sequence): + cur_in = x[i:i + 1, :][0] + cur_out = out[i:i + 1, :][0] + for j in range(timesteps): + for k in range(context_length): + if j + k >= timesteps: + continue + cur_out[j, :] += cur_in[j + k, :] * wt[k, :] + return out + + +class TestRowOpWithTensorInput(OpTest): + def setUp(self): + self.op_type = "row_conv" + length = [3, 2, 4] + B = 2 + T = sum(length) + D = 16 + context_length = 2 + + x = np.random.random((B, T, D)).astype("float32") + wt = np.random.random((context_length, D)).astype("float32") + self.inputs = {'X': x, 'Filter': wt} + + out = row_conv_foward_Tensor(x, wt) + self.outputs = {'Out': out} + + def test_check_output(self): + self.check_output() + + def test_check_grad_ignore_x(self): + self.check_grad( + ['Filter'], 'Out', max_relative_error=0.05, no_grad_set=set('X')) + + def test_check_grad_normal(self): + self.check_grad(['X', 'Filter'], 'Out', max_relative_error=0.05) + + def test_check_grad_ignore_wt(self): + self.check_grad( + ['X'], 'Out', max_relative_error=0.05, no_grad_set=set('Filter')) + + if __name__ == '__main__': unittest.main()