diff --git a/paddle/fluid/operators/math/sequence_pooling.cc b/paddle/fluid/operators/math/sequence_pooling.cc index 011d45c396579a26a804a4cf2ecd50734e7df945..cc3fbd587668b17b7edde50b157adca83e81eddc 100644 --- a/paddle/fluid/operators/math/sequence_pooling.cc +++ b/paddle/fluid/operators/math/sequence_pooling.cc @@ -37,18 +37,23 @@ class MaxSeqPoolFunctor { public: void operator()(const platform::CPUDeviceContext& context, const framework::LoDTensor& input, T pad_value, - framework::Tensor* output, framework::Tensor* index) { + framework::LoDTensor* output, framework::Tensor* index) { auto in_dims = input.dims(); auto out_dims = output->dims(); auto idx_dims = index->dims(); - PADDLE_ENFORCE_GT(in_dims.size(), 1); - PADDLE_ENFORCE_GT(out_dims.size(), 1); + PADDLE_ENFORCE_GT(in_dims.size(), 1, + "The rank of input shall be greater than 1."); + PADDLE_ENFORCE_GT(out_dims.size(), 1, + "The rank of output shall be greater than 1."); for (int64_t i = 1; i < in_dims.size(); ++i) { - PADDLE_ENFORCE_EQ(in_dims[i], out_dims[i]); + PADDLE_ENFORCE_EQ(in_dims[i], out_dims[i], + "The dimension of input and output shall be same."); } - PADDLE_ENFORCE_EQ(idx_dims, out_dims); + PADDLE_ENFORCE_EQ(idx_dims, out_dims, + "The dimension of index and output shall be same."); - auto starts = input.lod()[0]; + auto lod_level = input.lod().size(); + auto starts = input.lod()[lod_level - 1]; const T* in_data = input.data(); T* out_data = output->data(); int* max_index = index->data(); @@ -85,16 +90,20 @@ class MaxSeqPoolFunctor { public: void operator()(const platform::CPUDeviceContext& context, const framework::LoDTensor& input, T pad_value, - framework::Tensor* output, framework::Tensor* index) { + framework::LoDTensor* output, framework::Tensor* index) { auto in_dims = input.dims(); auto out_dims = output->dims(); - PADDLE_ENFORCE_GT(in_dims.size(), 1); - PADDLE_ENFORCE_GT(out_dims.size(), 1); + PADDLE_ENFORCE_GT(in_dims.size(), 1, + "The rank of input shall be greater than 1."); + PADDLE_ENFORCE_GT(out_dims.size(), 1, + "The rank of output shall be greater than 1."); for (int64_t i = 1; i < in_dims.size(); ++i) { - PADDLE_ENFORCE_EQ(in_dims[i], out_dims[i]); + PADDLE_ENFORCE_EQ(in_dims[i], out_dims[i], + "The dimension of input and output shall be same."); } - auto starts = input.lod()[0]; + auto lod_level = input.lod().size(); + auto starts = input.lod()[lod_level - 1]; const T* in_data = input.data(); T* out_data = output->data(); @@ -123,18 +132,23 @@ template class MaxSeqPoolGradFunctor { public: void operator()(const platform::CPUDeviceContext& context, - const framework::Tensor& out_grad, + const framework::LoDTensor& out_grad, const framework::Tensor& index, framework::LoDTensor* in_grad) { auto og_dims = out_grad.dims(); auto ig_dims = in_grad->dims(); auto idx_dims = index.dims(); - PADDLE_ENFORCE_GT(og_dims.size(), 1); - PADDLE_ENFORCE_GT(ig_dims.size(), 1); + PADDLE_ENFORCE_GT(og_dims.size(), 1, + "The rank of output@Grad shall be greater than 1."); + PADDLE_ENFORCE_GT(ig_dims.size(), 1, + "The rank of input@Grad shall be greater than 1."); for (int64_t i = 1; i < og_dims.size(); ++i) { - PADDLE_ENFORCE_EQ(og_dims[i], ig_dims[i]); + PADDLE_ENFORCE_EQ( + og_dims[i], ig_dims[i], + "The dimension of input@Grad and output@Grad shall be same."); } - PADDLE_ENFORCE_EQ(idx_dims, og_dims); + PADDLE_ENFORCE_EQ(idx_dims, og_dims, + "The dimension of index and output@Grad shall be same."); const T* og_data = out_grad.data(); const int* max_index = index.data(); @@ -159,14 +173,15 @@ class LastSeqPoolFunctor { public: void operator()(const platform::CPUDeviceContext& context, const framework::LoDTensor& input, T pad_value, - framework::Tensor* output) { + framework::LoDTensor* output) { // Create pointers to input and output data auto* in_data = input.data(); auto* out_data = output->data(); // Calculate the size of each item in sequence int64_t item_size = input.numel() / input.dims()[0]; - auto lod = input.lod()[0]; + auto lod_level = input.lod().size(); + auto lod = input.lod()[lod_level - 1]; int seq_num = static_cast(lod.size()) - 1; for (int i = 0; i < seq_num; ++i) { // Calculate the length of each sequence @@ -191,14 +206,15 @@ class FirstSeqPoolFunctor { public: void operator()(const platform::CPUDeviceContext& context, const framework::LoDTensor& input, T pad_value, - framework::Tensor* output) { + framework::LoDTensor* output) { // Create pointers to input and output data auto* in_data = input.data(); auto* out_data = output->data(); // Calculate the size of each item in sequence int64_t item_size = input.numel() / input.dims()[0]; - auto lod = input.lod()[0]; + auto lod_level = input.lod().size(); + auto lod = input.lod()[lod_level - 1]; int seq_num = static_cast(lod.size()) - 1; for (int i = 0; i < seq_num; ++i) { // Calculate the length of each sequence @@ -222,12 +238,15 @@ template class SumSeqPoolGradFunctor { public: void operator()(const platform::CPUDeviceContext& context, - const framework::Tensor& out_grad, + const framework::LoDTensor& out_grad, framework::LoDTensor* in_grad) { - auto lod = in_grad->lod()[0]; + auto lod_level = in_grad->lod().size(); + auto lod = in_grad->lod()[lod_level - 1]; int64_t out_w = out_grad.numel() / out_grad.dims()[0]; int64_t in_w = in_grad->numel() / in_grad->dims()[0]; - PADDLE_ENFORCE(in_w == out_w); + PADDLE_ENFORCE_EQ( + in_w, out_w, + "The feature size of input@Grad and output@Grad shall be same."); const T* out_g_data = out_grad.data(); T* in_g_data = in_grad->mutable_data(context.GetPlace()); auto blas = math::GetBlas(context); @@ -250,8 +269,9 @@ class SequencePoolFunctor { /* max pool has index output */ void operator()(const platform::CPUDeviceContext& context, const std::string pooltype, T pad_value, - const framework::LoDTensor& input, framework::Tensor* output, - bool is_test, framework::Tensor* index = nullptr) { + const framework::LoDTensor& input, + framework::LoDTensor* output, bool is_test, + framework::Tensor* index = nullptr) { if (pooltype == "MAX") { if (is_test) { math::MaxSeqPoolFunctor max_pool; @@ -272,11 +292,13 @@ class SequencePoolFunctor { first_pool(context, input, pad_value, output); return; } - - auto lod = input.lod()[0]; + auto lod_level = input.lod().size(); + auto lod = input.lod()[lod_level - 1]; if (pooltype == "SUM") { auto place = context.GetPlace(); - PADDLE_ENFORCE(platform::is_cpu_place(place)); + PADDLE_ENFORCE_EQ( + platform::is_cpu_place(place), true, + "Sequence_pool should run on CPU Device when pooltype is SUM"); const T* src = input.data(); T* dst = output->mutable_data(place); jit::seq_pool_attr_t attr( @@ -330,7 +352,8 @@ template class SequencePoolGradFunctor { public: void operator()(const platform::CPUDeviceContext& context, - const std::string pooltype, const framework::Tensor& out_grad, + const std::string pooltype, + const framework::LoDTensor& out_grad, framework::LoDTensor* in_grad, /* max pool has index */ const framework::Tensor* index = nullptr) { @@ -352,7 +375,8 @@ class SequencePoolGradFunctor { return; } - auto lod = in_grad->lod()[0]; + auto lod_level = in_grad->lod().size(); + auto lod = in_grad->lod()[lod_level - 1]; auto& place = *context.eigen_device(); for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { if (lod[i] == lod[i + 1]) continue; diff --git a/paddle/fluid/operators/math/sequence_pooling.cu b/paddle/fluid/operators/math/sequence_pooling.cu index 4de99ba677d5108e8b70e71e3dfefa17b6e18beb..91545131e4cbb5d6dcae9c111e97598ee54cc898 100644 --- a/paddle/fluid/operators/math/sequence_pooling.cu +++ b/paddle/fluid/operators/math/sequence_pooling.cu @@ -159,9 +159,11 @@ class SequencePoolFunctor { public: void operator()(const platform::CUDADeviceContext& context, const std::string pooltype, T pad_value, - const framework::LoDTensor& input, framework::Tensor* output, - bool is_test, framework::Tensor* index = nullptr) { - auto& lod = input.lod()[0]; + const framework::LoDTensor& input, + framework::LoDTensor* output, bool is_test, + framework::Tensor* index = nullptr) { + auto lod_level = input.lod().size(); + auto& lod = input.lod()[lod_level - 1]; const size_t item_dim = output->numel() / output->dims()[0]; dim3 threads(1024, 1); dim3 grid(lod.size(), 1); @@ -319,11 +321,13 @@ template class SequencePoolGradFunctor { public: void operator()(const platform::CUDADeviceContext& context, - const std::string pooltype, const framework::Tensor& out_grad, + const std::string pooltype, + const framework::LoDTensor& out_grad, framework::LoDTensor* in_grad, /* max pool has index */ const framework::Tensor* index = nullptr) { - auto& lod = in_grad->lod()[0]; + auto lod_level = in_grad->lod().size(); + auto& lod = in_grad->lod()[lod_level - 1]; const size_t item_dim = in_grad->numel() / in_grad->dims()[0]; dim3 threads(1024, 1); dim3 grid(lod.size(), 1); diff --git a/paddle/fluid/operators/math/sequence_pooling.h b/paddle/fluid/operators/math/sequence_pooling.h index 1dc02eae201413b9483b31129578be144f175aa3..847d0bca951a7e54a74a6c803a4f24d50672228f 100644 --- a/paddle/fluid/operators/math/sequence_pooling.h +++ b/paddle/fluid/operators/math/sequence_pooling.h @@ -28,7 +28,7 @@ class SequencePoolFunctor { /* max pool has index output */ void operator()(const DeviceContext& context, const std::string pooltype, T pad_value, const framework::LoDTensor& input, - framework::Tensor* output, bool is_test = false, + framework::LoDTensor* output, bool is_test = false, framework::Tensor* index = nullptr); }; @@ -36,7 +36,7 @@ template class SequencePoolGradFunctor { public: void operator()(const DeviceContext& context, const std::string pooltype, - const framework::Tensor& out_grad, + const framework::LoDTensor& out_grad, framework::LoDTensor* in_grad, /* max pool has index */ const framework::Tensor* index = nullptr); diff --git a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc index f3193fdc55609ee0cc608367c654b9d506217b6c..51e354dcd175845c3db2cce78dac6039361aed08 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc @@ -24,14 +24,15 @@ class SequencePoolOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of SequencePoolOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of SequencePoolOp should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, + "Input(X) of SequencePoolOp should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, + "Output(Out) of SequencePoolOp should not be null."); ctx->SetOutputDim("Out", ctx->GetInputDim("X")); if (ctx->Attrs().Get("pooltype") == "MAX") { - PADDLE_ENFORCE(ctx->HasOutput("MaxIndex"), - "Output(MaxIndex) of SequencePoolOp should not be null."); + PADDLE_ENFORCE_EQ( + ctx->HasOutput("MaxIndex"), true, + "Output(MaxIndex) of SequencePoolOp should not be null."); ctx->SetOutputDim("MaxIndex", ctx->GetInputDim("X")); } } @@ -102,9 +103,10 @@ class SequencePoolGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), - "Gradient of Out should not be null."); - PADDLE_ENFORCE(ctx->HasInput("X"), "The input X should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true, + "Gradient of Out should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, + "The input X should not be null."); auto og_dims = ctx->GetInputDim(framework::GradVarName("Out")); auto x_dims = ctx->GetInputDim("X"); PADDLE_ENFORCE_EQ(og_dims.size(), x_dims.size(), diff --git a/paddle/fluid/operators/sequence_ops/sequence_pool_op.h b/paddle/fluid/operators/sequence_ops/sequence_pool_op.h index c32734808c39313fcf0a0e624d246f2e52838edf..3eec4df121046e6c269cd950234c06b31b57d5a2 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_pool_op.h +++ b/paddle/fluid/operators/sequence_ops/sequence_pool_op.h @@ -30,19 +30,30 @@ class SequencePoolKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* in = context.Input("X"); - auto* out = context.Output("Out"); + auto* out = context.Output("Out"); std::string pooltype = context.Attr("pooltype"); T pad_value = static_cast(context.Attr("pad_value")); auto dims = in->dims(); auto lod = in->lod(); + auto lod_level = lod.size(); // InferShape by lod - PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now."); + PADDLE_ENFORCE_GE(lod_level, 1UL, + "The lod level of input shall be 1 at least."); + PADDLE_ENFORCE_LE(lod_level, 2UL, + "The lod level of input shall be no more than 2."); PADDLE_ENFORCE_GE( dims[0], - /*batch size = */ static_cast(lod[0].size() - 1), + /*batch size = */ static_cast(lod[lod_level - 1].size() - 1), "The first dimension of Input(X) must be large than batch size."); - dims[0] = lod[0].size() - 1; + if (lod_level > 1UL) { + PADDLE_ENFORCE_EQ(lod[0][lod[0].size() - 1], lod[1].size() - 1, + "The input lod information is illegal."); + framework::LoD out_lod; + out_lod.push_back(lod[0]); + out->set_lod(out_lod); + } + dims[0] = lod[lod_level - 1].size() - 1; out->Resize({dims}); out->mutable_data(context.GetPlace()); Tensor* index = nullptr; @@ -68,7 +79,7 @@ template class SequencePoolGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto* out_g = context.Input(framework::GradVarName("Out")); + auto* out_g = context.Input(framework::GradVarName("Out")); auto* in_g = context.Output(framework::GradVarName("X")); std::string pooltype = context.Attr("pooltype"); const Tensor* index = nullptr; diff --git a/python/paddle/fluid/tests/unittests/test_seq_pool.py b/python/paddle/fluid/tests/unittests/test_seq_pool.py index aa801b1f5d8c7e7c8acec7096db7010a058451ff..2de5d0345912ace44858de1be52dece846ef879a 100644 --- a/python/paddle/fluid/tests/unittests/test_seq_pool.py +++ b/python/paddle/fluid/tests/unittests/test_seq_pool.py @@ -21,30 +21,33 @@ from test_reorder_lod_tensor import convert_to_offset def compute_seqpool_sum(x, offset, out, pad_value=0.0): - for i in range(len(offset[0]) - 1): - if offset[0][i] == offset[0][i + 1]: + level = len(offset) - 1 + for i in range(len(offset[level]) - 1): + if offset[level][i] == offset[level][i + 1]: out[i] = pad_value else: - sub_x = x[offset[0][i]:offset[0][i + 1], :] + sub_x = x[offset[level][i]:offset[level][i + 1], :] out[i] = sub_x.sum(axis=0) def compute_seqpool_avg(x, offset, out, pad_value=0.0): - for i in range(len(offset[0]) - 1): - if offset[0][i] == offset[0][i + 1]: + level = len(offset) - 1 + for i in range(len(offset[level]) - 1): + if offset[level][i] == offset[level][i + 1]: out[i] = pad_value else: - sub_x = x[offset[0][i]:offset[0][i + 1], :] + sub_x = x[offset[level][i]:offset[level][i + 1], :] out[i] = sub_x.mean(axis=0) def compute_seqpool_sqrt(x, offset, out, pad_value=0.0): - for i in range(len(offset[0]) - 1): - if offset[0][i] == offset[0][i + 1]: + level = len(offset) - 1 + for i in range(len(offset[level]) - 1): + if offset[level][i] == offset[level][i + 1]: out[i] = pad_value else: - sub_x = x[offset[0][i]:offset[0][i + 1], :] - seq_len = offset[0][i + 1] - offset[0][i] + sub_x = x[offset[level][i]:offset[level][i + 1], :] + seq_len = offset[level][i + 1] - offset[level][i] out[i] = sub_x.sum(axis=0) / np.sqrt(seq_len) @@ -56,9 +59,10 @@ class TestSeqAvgPool(OpTest): self.op_type = 'sequence_pool' x = np.random.uniform(0.1, 1, [11, 23]).astype('float32') lod = self.set_lod() + level = len(lod) - 1 self.inputs = {'X': (x, lod)} offset = convert_to_offset(lod) - out = np.zeros((len(lod[0]), 23)).astype('float32') + out = np.zeros((len(lod[level]), 23)).astype('float32') self.outputs = {'Out': out} return x, offset, out @@ -69,14 +73,18 @@ class TestSeqAvgPool(OpTest): def setUp(self): x, offset, out = self.set_data() self.compute(x, offset, out) + if len(offset) > 1: + self.outputs = {'Out': (out, [self.set_lod()[0]])} def test_check_output(self): self.check_output() def test_check_grad(self): # Remove MaxIndex after check_grad is refined. + out = self.outputs['Out'] + if isinstance(out, tuple): out = out[0] self.outputs['MaxIndex'] = \ - np.zeros(self.outputs['Out'].shape).astype('int32') + np.zeros(out.shape).astype('int32') self.check_grad(["X"], "Out") @@ -85,6 +93,11 @@ class TestSeqAvgPoolLen0(TestSeqAvgPool): return [[0, 4, 0, 7, 0]] +class TestSeqAvgPoolLen0LoDLevel2(TestSeqAvgPool): + def set_lod(self): + return [[2, 0, 1, 2], [0, 4, 0, 7, 0]] + + class TestSeqSumPool(TestSeqAvgPool): def compute(self, x, offset, out): self.attrs = {"pad_value": 0.1, 'pooltype': "SUM"} @@ -96,6 +109,11 @@ class TestSeqSumPoolLen0(TestSeqSumPool): return [[0, 4, 0, 7, 0]] +class TestSeqSumPoolLen0LoDLevel2(TestSeqSumPool): + def set_lod(self): + return [[2, 0, 1, 2], [0, 4, 0, 7, 0]] + + class TestSeqMaxPool(TestSeqAvgPool): def set_lod(self): return [[13]] @@ -104,25 +122,27 @@ class TestSeqMaxPool(TestSeqAvgPool): self.op_type = 'sequence_pool' x = np.random.uniform(0.1, 1, [13, 23]).astype('float32') lod = self.set_lod() + level = len(lod) - 1 offset = convert_to_offset(lod) - for i in range(len(offset[0]) - 1): - l = offset[0][i + 1] - offset[0][i] + for i in range(len(offset[level]) - 1): + l = offset[level][i + 1] - offset[level][i] if l > 0: - x[offset[0][i] + np.random.randint(l), :] += 2.0 + x[offset[level][i] + np.random.randint(l), :] += 2.0 self.inputs = {'X': (x, lod)} - out = np.zeros((len(lod[0]), 23)).astype('float32') + out = np.zeros((len(lod[level]), 23)).astype('float32') self.outputs = {'Out': out} return x, offset, out def compute(self, x, offset, out): self.attrs = {"pad_value": 0.5, 'pooltype': "MAX"} - for i in range(len(offset[0]) - 1): - if offset[0][i] == offset[0][i + 1]: + level = len(offset) - 1 + for i in range(len(offset[level]) - 1): + if offset[level][i] == offset[level][i + 1]: out[i] = self.attrs["pad_value"] else: - sub_x = x[offset[0][i]:offset[0][i + 1], :] + sub_x = x[offset[level][i]:offset[level][i + 1], :] out[i] = np.amax(sub_x, axis=0) @@ -131,6 +151,11 @@ class TestSeqMaxPoolLen0(TestSeqMaxPool): return [[0, 1, 1, 5, 6, 0]] +class TestSeqMaxPoolLen0LoDLevel2(TestSeqMaxPool): + def set_lod(self): + return [[2, 0, 3, 1], [0, 1, 1, 5, 6, 0]] + + class TestSeqSqrtPool(TestSeqAvgPool): def compute(self, x, offset, out): self.attrs = {"pad_value": 0.0, 'pooltype': "SQRT"} @@ -142,14 +167,20 @@ class TestSeqSqrtPoolLen0(TestSeqSqrtPool): return [[0, 7, 0, 2, 2, 0]] +class TestSeqSqrtPoolLen0LoDLevel2(TestSeqSqrtPool): + def set_lod(self): + return [[1, 2, 0, 3], [0, 7, 0, 2, 2, 0]] + + class TestSeqLastPool(TestSeqAvgPool): def compute(self, x, offset, out): self.attrs = {"pad_value": 0.0, 'pooltype': "LAST"} - for i in range(len(offset[0]) - 1): - if offset[0][i] == offset[0][i + 1]: + level = len(offset) - 1 + for i in range(len(offset[level]) - 1): + if offset[level][i] == offset[level][i + 1]: out[i] = self.attrs["pad_value"] else: - sub_x = x[offset[0][i]:offset[0][i + 1], :] + sub_x = x[offset[level][i]:offset[level][i + 1], :] out[i] = sub_x[-1, :] @@ -158,14 +189,20 @@ class TestSeqLastPoolLen0(TestSeqLastPool): return [[0, 3, 4, 0, 4, 0]] +class TestSeqLastPoolLen0LoDLevel2(TestSeqLastPool): + def set_lod(self): + return [[1, 0, 2, 3], [0, 3, 4, 0, 4, 0]] + + class TestSeqFirstPool(TestSeqAvgPool): def compute(self, x, offset, out): self.attrs = {"pad_value": 0.3, 'pooltype': "FIRST"} - for i in range(len(offset[0]) - 1): - if offset[0][i] == offset[0][i + 1]: + level = len(offset) - 1 + for i in range(len(offset[level]) - 1): + if offset[level][i] == offset[level][i + 1]: out[i] = self.attrs["pad_value"] else: - sub_x = x[offset[0][i]:offset[0][i + 1], :] + sub_x = x[offset[level][i]:offset[level][i + 1], :] out[i] = sub_x[0, :] @@ -174,6 +211,11 @@ class TestSeqFirstPoolLen0(TestSeqFirstPool): return [[0, 2, 0, 3, 6, 0]] +class TestSeqFirstPoolLen0LoDLevel2(TestSeqFirstPool): + def set_lod(self): + return [[1, 0, 2, 3], [0, 2, 0, 3, 6, 0]] + + class TestSeqAvgPool2D(TestSeqAvgPool): def set_lod(self): return [[4, 1, 3, 5]] @@ -182,20 +224,22 @@ class TestSeqAvgPool2D(TestSeqAvgPool): self.op_type = 'sequence_pool' x = np.random.uniform(0.1, 1, [13, 3, 17]).astype('float32') lod = self.set_lod() + level = len(lod) - 1 self.inputs = {'X': (x, lod)} offset = convert_to_offset(lod) - out = np.zeros((len(lod[0]), 3, 17)).astype('float32') + out = np.zeros((len(lod[level]), 3, 17)).astype('float32') self.outputs = {'Out': out} return x, offset, out def compute(self, x, offset, out): self.attrs = {"pad_value": 0.0, 'pooltype': "AVERAGE"} - for i in range(len(offset[0]) - 1): - if offset[0][i] == offset[0][i + 1]: + level = len(offset) - 1 + for i in range(len(offset[level]) - 1): + if offset[level][i] == offset[level][i + 1]: out[i] = self.attrs["pad_value"] * np.ones((3, 17)) else: - sub_x = np.reshape(x[offset[0][i]:offset[0][i + 1], :], + sub_x = np.reshape(x[offset[level][i]:offset[level][i + 1], :], (-1, 3 * 17)) out[i] = np.reshape(sub_x.mean(axis=0), (3, 17)) @@ -205,14 +249,20 @@ class TestSeqAvgPool2DLen0(TestSeqAvgPool2D): return [[0, 5, 0, 8, 0]] +class TestSeqAvgPool2DLen0LoDLevel2(TestSeqAvgPool2D): + def set_lod(self): + return [[1, 0, 4], [0, 5, 0, 8, 0]] + + class TestSeqSumPool2D(TestSeqAvgPool2D): def compute(self, x, offset, out): self.attrs = {"pad_value": 0.2, 'pooltype': "SUM"} - for i in range(len(offset[0]) - 1): - if offset[0][i] == offset[0][i + 1]: + level = len(offset) - 1 + for i in range(len(offset[level]) - 1): + if offset[level][i] == offset[level][i + 1]: out[i] = self.attrs["pad_value"] * np.ones((3, 17)) else: - sub_x = np.reshape(x[offset[0][i]:offset[0][i + 1], :], + sub_x = np.reshape(x[offset[level][i]:offset[level][i + 1], :], (-1, 3 * 17)) out[i] = np.reshape(sub_x.sum(axis=0), (3, 17)) @@ -222,23 +272,32 @@ class TestSeqSumPool2DLen0(TestSeqSumPool2D): return [[0, 8, 0, 5, 0]] +class TestSeqSumPool2DLen0LoDLevel2(TestSeqSumPool2D): + def set_lod(self): + return [[1, 0, 4], [0, 8, 0, 5, 0]] + + class TestSeqSqrtPool2D(TestSeqAvgPool2D): def compute(self, x, offset, out): self.attrs = {"pad_value": 0.0, 'pooltype': "SQRT"} - for i in range(len(offset[0]) - 1): - if offset[0][i] == offset[0][i + 1]: + level = len(offset) - 1 + for i in range(len(offset[level]) - 1): + if offset[level][i] == offset[level][i + 1]: out[i] = self.attrs["pad_value"] * np.ones((3, 17)) else: - sub_x = np.reshape(x[offset[0][i]:offset[0][i + 1], :], + sub_x = np.reshape(x[offset[level][i]:offset[level][i + 1], :], (-1, 3 * 17)) - seq_len = offset[0][i + 1] - offset[0][i] + seq_len = offset[level][i + 1] - offset[level][i] out[i] = np.reshape( sub_x.sum(axis=0) / np.sqrt(seq_len), (3, 17)) def test_check_grad(self): # Remove MaxIndex after check_grad is refined. + out = self.outputs['Out'] + if isinstance(out, tuple): + out = out[0] self.outputs['MaxIndex'] = \ - np.zeros(self.outputs['Out'].shape).astype('int32') + np.zeros(out.shape).astype('int32') self.check_grad(["X"], "Out", max_relative_error=0.06) @@ -247,6 +306,11 @@ class TestSeqSqrtPool2DLen0(TestSeqSqrtPool2D): return [[0, 8, 0, 5, 0]] +class TestSeqSqrtPool2DLen0LoDLevel2(TestSeqSqrtPool2D): + def set_lod(self): + return [[1, 0, 2, 2], [0, 8, 0, 5, 0]] + + class TestSeqMaxPool2D(TestSeqAvgPool2D): def set_lod(self): return [[4, 1, 3, 5]] @@ -255,25 +319,27 @@ class TestSeqMaxPool2D(TestSeqAvgPool2D): self.op_type = 'sequence_pool' x = np.random.uniform(0.1, 1, [13, 3, 11]).astype('float32') self.lod = self.set_lod() + level = len(self.lod) - 1 self.inputs = {'X': (x, self.lod)} offset = convert_to_offset(self.lod) - for i in range(len(offset[0]) - 1): - l = offset[0][i + 1] - offset[0][i] + for i in range(len(offset[level]) - 1): + l = offset[level][i + 1] - offset[level][i] if l == 0: continue - x[offset[0][i] + np.random.randint(l), :] += 1.0 + x[offset[level][i] + np.random.randint(l), :] += 1.0 - out = np.zeros((len(self.lod[0]), 3, 11)).astype('float32') + out = np.zeros((len(self.lod[level]), 3, 11)).astype('float32') self.outputs = {'Out': out} return x, offset, out def compute(self, x, offset, out): self.attrs = {"pad_value": 0.0, 'pooltype': "MAX"} - for i in range(len(offset[0]) - 1): - if offset[0][i] == offset[0][i + 1]: + level = len(offset) - 1 + for i in range(len(offset[level]) - 1): + if offset[level][i] == offset[level][i + 1]: out[i] = self.attrs["pad_value"] * np.ones((3, 11)) continue - sub_x = np.reshape(x[offset[0][i]:offset[0][i + 1], :], + sub_x = np.reshape(x[offset[level][i]:offset[level][i + 1], :], (-1, 3 * 11)) out[i] = np.reshape(np.amax(sub_x, axis=0), (3, 11)) @@ -283,14 +349,20 @@ class TestSeqMaxPool2DLen0(TestSeqMaxPool2D): return [[0, 3, 0, 10, 0]] +class TestSeqMaxPool2DLen0LoDLevel2(TestSeqMaxPool2D): + def set_lod(self): + return [[1, 0, 2, 2], [0, 3, 0, 10, 0]] + + class TestSeqMaxPool2DInference(TestSeqMaxPool2D): def compute(self, x, offset, out): self.attrs = {"pad_value": 1.0, 'pooltype': "MAX", 'is_test': True} - for i in range(len(offset[0]) - 1): - if offset[0][i] == offset[0][i + 1]: + level = len(offset) - 1 + for i in range(len(offset[level]) - 1): + if offset[level][i] == offset[level][i + 1]: out[i] = self.attrs["pad_value"] * np.ones((3, 11)) else: - sub_x = np.reshape(x[offset[0][i]:offset[0][i + 1], :], + sub_x = np.reshape(x[offset[level][i]:offset[level][i + 1], :], (-1, 3 * 11)) out[i] = np.reshape(np.amax(sub_x, axis=0), (3, 11)) @@ -305,14 +377,20 @@ class TestSeqMaxPool2DInferenceLen0(TestSeqMaxPool2DInference): return [[0, 3, 0, 10, 0]] +class TestSeqMaxPool2DInferenceLen0LoDLevel2(TestSeqMaxPool2DInference): + def set_lod(self): + return [[1, 0, 2, 2], [0, 3, 0, 10, 0]] + + class TestSeqLastPool2D(TestSeqAvgPool2D): def compute(self, x, offset, out): self.attrs = {"pad_value": 0.0, 'pooltype': "LAST"} - for i in range(len(offset[0]) - 1): - if offset[0][i] == offset[0][i + 1]: + level = len(offset) - 1 + for i in range(len(offset[level]) - 1): + if offset[level][i] == offset[level][i + 1]: out[i] = self.attrs["pad_value"] * np.ones((3, 17)) else: - sub_x = np.reshape(x[offset[0][i]:offset[0][i + 1], :], + sub_x = np.reshape(x[offset[level][i]:offset[level][i + 1], :], (-1, 3 * 17)) out[i] = np.reshape(sub_x[-1, :], (3, 17)) @@ -322,14 +400,20 @@ class TestSeqLastPool2DLen0(TestSeqLastPool2D): return [[0, 3, 0, 1, 9, 0]] +class TestSeqLastPool2DLen0LoDLevel2(TestSeqLastPool2D): + def set_lod(self): + return [[1, 0, 2, 3], [0, 3, 0, 1, 9, 0]] + + class TestSeqFirstPool2D(TestSeqAvgPool2D): def compute(self, x, offset, out): self.attrs = {"pad_value": 0.0, 'pooltype': "FIRST"} - for i in range(len(offset[0]) - 1): - if offset[0][i] == offset[0][i + 1]: + level = len(offset) - 1 + for i in range(len(offset[level]) - 1): + if offset[level][i] == offset[level][i + 1]: out[i] = self.attrs["pad_value"] * np.ones((3, 17)) else: - sub_x = np.reshape(x[offset[0][i]:offset[0][i + 1], :], + sub_x = np.reshape(x[offset[level][i]:offset[level][i + 1], :], (-1, 3 * 17)) out[i] = np.reshape(sub_x[0, :], (3, 17)) @@ -339,5 +423,10 @@ class TestSeqFirstPool2DLen0(TestSeqFirstPool2D): return [[0, 3, 0, 3, 7, 0]] +class TestSeqFirstPool2DLen0LoDLevel2(TestSeqFirstPool2D): + def set_lod(self): + return [[1, 0, 2, 3], [0, 3, 0, 3, 7, 0]] + + if __name__ == '__main__': unittest.main()