From 834b82f109ee3a9e6370dc7e81b287d8f6b02754 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Wed, 18 Oct 2017 15:23:36 +0800 Subject: [PATCH] fix sequence_project_op forward and backward --- paddle/operators/sequence_project_op.cc | 28 +- paddle/operators/sequence_project_op.h | 267 ++++++++++++------ .../v2/framework/tests/test_seq_project.py | 123 ++++++-- 3 files changed, 292 insertions(+), 126 deletions(-) diff --git a/paddle/operators/sequence_project_op.cc b/paddle/operators/sequence_project_op.cc index c894f3f1f8f..b1351e8ac53 100644 --- a/paddle/operators/sequence_project_op.cc +++ b/paddle/operators/sequence_project_op.cc @@ -38,24 +38,23 @@ class SequenceProjectOp : public framework::OperatorWithKernel { PADDLE_ENFORCE( ctx->HasInput("PaddingData"), "Output(PaddingData) of SequenceProjectOp should not be null."); - framework::DDim padding_dim = ctx->GetOutputDim("PaddingData"); + framework::DDim padding_dim = ctx->GetInputDim("PaddingData"); int up_pad = std::max(0, -context_start); int down_pad = std::max(0, context_start + context_length - 1); int total_pad = up_pad + down_pad; int input_width = static_cast(in_dims[1]); + if (context_start == 0 && context_length == 1) { + PADDLE_THROW( + "if context_start == 0 && context_length == 1, padding_trainable " + "should be false."); + } PADDLE_ENFORCE(padding_dim.size() == 2, "Input(PaddingData) should be 2-D tensor."); PADDLE_ENFORCE( padding_dim[0] == total_pad && padding_dim[1] == input_width, "Input(PaddingData)'s shape is not consistent with 'context_start' " "and 'context_length'."); - - if (context_start == 0 && context_length == 1) { - PADDLE_THROW( - "if context_start == 0 && context_length == 1, padding_trainable " - "should be false."); - } } in_dims[1] = in_dims[1] * context_length; @@ -74,9 +73,11 @@ class SequenceProjectGradOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasInput("X"), "The input X should not be null."); if (ctx->Attrs().Get("padding_trainable")) { - PADDLE_ENFORCE( - ctx->HasOutput("PaddingData"), - "Output(PaddingData) of SequenceProjectOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("PaddingData")), + "Output(PaddingData@GRAD) of SequenceProjectGradOp should " + "not be null."); + auto padding_dims = ctx->GetInputDim("PaddingData"); + ctx->SetOutputDim(framework::GradVarName("PaddingData"), padding_dims); } ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); } @@ -93,8 +94,8 @@ class SequenceProjectOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput( "Out", "A float LoDTensor, the variable-length output of SequenceProjectOp."); - AddOutput("PaddingData", - "A float LoDTensor, the padding data of SequenceProjectOp."); + AddInput("PaddingData", // PaddingData can be a float tensor + "A float LoDTensor, the padding data of SequenceProjectOp."); AddAttr("padding_trainable", "(bool, default false) the padding data of SequenceProjectOp " @@ -110,7 +111,8 @@ class SequenceProjectOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("context_stride", "(int, default 1) the xx of SequenceProjectOp.") .SetDefault(1) - .GreaterThan(0); + .GreaterThan( + 0); // Currently, sequence_project_op only support context_stride=1 AddComment(R"DOC( SequenceProjectOp projects features of context_length time-steps of each instance. diff --git a/paddle/operators/sequence_project_op.h b/paddle/operators/sequence_project_op.h index 0a1b647070d..6cc57d894bb 100644 --- a/paddle/operators/sequence_project_op.h +++ b/paddle/operators/sequence_project_op.h @@ -23,6 +23,9 @@ namespace operators { using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; +template +using EigenVector = framework::EigenVector; template using EigenMatrix = framework::EigenMatrix; @@ -34,6 +37,13 @@ class SequenceProjectKernel : public framework::OpKernel { auto* in = context.Input("X"); auto* out = context.Output("Out"); out->mutable_data(context.GetPlace()); + + // need discuss, is it necessary to set zeros ? + // Because if padding_trainable is false, padding data should be zeros. + auto temp = framework::EigenVector::Flatten(*out); + temp.device(context.GetEigenDevice()) = + temp.constant(static_cast(0)); + auto place = context.GetEigenDevice(); int context_start = context.Attr("context_start"); @@ -45,10 +55,10 @@ class SequenceProjectKernel : public framework::OpKernel { PADDLE_ENFORCE_EQ(in->lod().size(), 1UL, "Only support one level sequence now."); auto lod_level_0 = in->lod()[0]; - int64_t input_stride = in->dims()[1]; - int64_t output_stride = out->dims()[1]; - int64_t padding_stride = 0; - PADDLE_ENFORCE(input_stride * context_length == output_stride, + int64_t input_width = in->dims()[1]; + int64_t output_width = out->dims()[1]; + int64_t padding_width = 0; + PADDLE_ENFORCE(input_width * context_length == output_width, "Input size and pooling size should be consistent."); const LoDTensor* padding_data = nullptr; @@ -56,73 +66,105 @@ class SequenceProjectKernel : public framework::OpKernel { padding_data = context.Input("PaddingData"); PADDLE_ENFORCE_EQ(padding_data->dims().size(), 2UL, "Only support one level sequence now."); - padding_stride = padding_data->dims()[1]; - PADDLE_ENFORCE(padding_stride == input_stride, + padding_width = padding_data->dims()[1]; + PADDLE_ENFORCE(padding_width == input_width, "Input size and pooling size should be consistent."); } int up_pad = std::max(0, -context_start); int down_pad = std::max(0, context_start + context_length - 1); + int sequence_height, sequence_width; + int input_row_begin, input_row_end; paddle::operators::math::Im2ColFunctor< paddle::operators::math::ColFormat::kOCF, Place, float> im2col_ocf; for (int i = 0; i < static_cast(lod_level_0.size()) - 1; ++i) { - Tensor in_t = in->Slice(static_cast(lod_level_0[i]), - static_cast(lod_level_0[i + 1])); + input_row_begin = (context_start > 0) + ? static_cast(lod_level_0[i]) + context_start + : static_cast(lod_level_0[i]); + input_row_end = static_cast(lod_level_0[i + 1]); + Tensor out_t = out->Slice(static_cast(lod_level_0[i]), static_cast(lod_level_0[i + 1])); - int sequence_height = in_t.dims()[0]; - int sequence_width = in_t.dims()[1]; + sequence_height = static_cast(out_t.dims()[0]); + sequence_width = static_cast(in->dims()[1]); + std::vector output_shape( {sequence_height, 1, 1, context_length, sequence_width}); // output_height, output_width, - // input_channels, - // filter_height, filter_width + // input_channels, filter_height, filter_width out_t.Resize(framework::make_ddim(output_shape)); - std::vector input_shape( - {1, sequence_height, - sequence_width}); // input_channels, input_height, input_width - in_t.Resize(framework::make_ddim(input_shape)); - for (int j = 0; j < context_length; ++j) { + + if (input_row_begin < input_row_end) { + Tensor in_t = in->Slice(input_row_begin, input_row_end); + std::vector input_shape( + {1, input_row_end - input_row_begin, + sequence_width}); // input_channels, input_height, input_width + in_t.Resize(framework::make_ddim(input_shape)); + im2col_ocf(context.device_context(), in_t, out_t, /*stride_height*/ context_stride, /*stride_width*/ 0, up_pad, down_pad); - if (padding_trainable) { - // add up trainable data - out_t.Resize(framework::make_ddim( - {sequence_height * context_length, sequence_width})); - if (up_pad != 0) { - for (int k = 0; k < up_pad; ++k) { - Tensor out_t_sub = out_t.Slice( - k * context_length, k * context_length + (up_pad - k)); - Tensor w_sub = padding_data->Slice(k, context_length - k); - auto out_t_sub_e = EigenMatrix::From(out_t_sub); - auto w_sub_e = EigenMatrix::From(w_sub); - out_t_sub_e.device(place) = w_sub_e; - } + } + + if (padding_trainable) { + // add up trainable data + out_t.Resize(framework::make_ddim( + {sequence_height * context_length, sequence_width})); + + if (up_pad > 0) { // add up pad + int padding_rows = std::min( + up_pad, static_cast(lod_level_0[i + 1] - lod_level_0[i])); + + for (int k = 0; k < padding_rows; ++k) { + int padding_size = + k + context_length < up_pad ? context_length : up_pad - k; + Tensor out_t_sub = out_t.Slice( + k * context_length, k * context_length + padding_size); + Tensor w_sub = padding_data->Slice(k, k + padding_size); + // in this block, using EigenVector::Flatten is ok too. + auto out_t_sub_e = EigenMatrix::From(out_t_sub); + auto w_sub_e = EigenMatrix::From(w_sub); + out_t_sub_e.device(place) = w_sub_e; } - if (down_pad != 0) { - int k = - (sequence_height + up_pad - context_length) / context_stride + - 1; - for (int t = 0; t + k < sequence_height; ++t) { - Tensor out_t_sub = - out_t.Slice((k + t) * context_length * sequence_width - - t * sequence_width, - (k + t) * context_length * sequence_width); - Tensor w_sub = padding_data->Slice(up_pad + 1, up_pad + 1 + t); - auto out_t_sub_e = EigenMatrix::From(out_t_sub); - auto w_sub_e = EigenMatrix::From(w_sub); - out_t_sub_e.device(place) = w_sub_e; + } + if (down_pad > 0) { // add down pad + int down_pad_begin_row = + std::max(0, + (sequence_height - context_start - context_length) + 1) + + 1; + int padding_begin = std::max(0, context_start - sequence_height); + int padding_size = + sequence_height - context_start >= context_length + ? 1 + : context_length - (sequence_height - context_start); + if (context_start >= sequence_height) padding_size = context_length; + int padding_idx = padding_begin; + for (int t = 0; t + down_pad_begin_row <= sequence_height; + ++t, ++padding_size) { + if (context_start >= sequence_height) padding_size = context_length; + if (padding_size > context_length) { + padding_size = context_length; + padding_idx++; } + if (padding_begin > 0 || sequence_height == context_start) + padding_idx = padding_begin + t; + Tensor out_t_sub = out_t.Slice( + (down_pad_begin_row + t) * context_length - padding_size, + (down_pad_begin_row + t) * context_length); + Tensor w_sub = padding_data->Slice( + up_pad + padding_idx, up_pad + padding_idx + padding_size); + auto out_t_sub_e = EigenMatrix::From(out_t_sub); + auto w_sub_e = EigenMatrix::From(w_sub); + out_t_sub_e.device(place) = w_sub_e; } - out_t.Resize(framework::make_ddim( - {sequence_height, context_length * sequence_width})); } } + out_t.Resize(framework::make_ddim( + {sequence_height, context_length * sequence_width})); } } }; @@ -131,95 +173,136 @@ template class SequenceProjectGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - // auto* in = context.Input("X"); auto* out_g = context.Input(framework::GradVarName("Out")); auto* in_g = context.Output(framework::GradVarName("X")); + auto* in = context.Input("X"); in_g->mutable_data(context.GetPlace()); auto place = context.GetEigenDevice(); int context_start = context.Attr("context_start"); int context_length = context.Attr("context_length"); bool padding_trainable = context.Attr("padding_trainable"); - int context_stride = context.Attr("context_stride"); + int context_stride = context.Attr("context_stride"); // InferShape by in_lod - PADDLE_ENFORCE_EQ(in_g->lod().size(), 1UL, + PADDLE_ENFORCE_EQ(in->lod().size(), 1UL, "Only support one level sequence now."); - auto lod_g_level_0 = in_g->lod()[0]; + auto lod_g_level_0 = in->lod()[0]; int64_t input_width = in_g->dims()[1]; int64_t output_width = out_g->dims()[1]; int64_t padding_width = 0; PADDLE_ENFORCE(input_width * context_length == output_width, "Input size and pooling size should be consistent."); - LoDTensor* padding_data = nullptr; + LoDTensor* padding_data_g = nullptr; if (padding_trainable) { - padding_data = context.Output("PaddingData"); - padding_data->mutable_data(context.GetPlace()); - PADDLE_ENFORCE_EQ(padding_data->dims().size(), 2UL, + padding_data_g = + context.Output(framework::GradVarName("PaddingData")); + padding_data_g->mutable_data(context.GetPlace()); + PADDLE_ENFORCE_EQ(padding_data_g->dims().size(), 2UL, "Only support one level sequence now."); - padding_width = padding_data->dims()[1]; + padding_width = padding_data_g->dims()[1]; PADDLE_ENFORCE(padding_width == input_width, "Input size and pooling size should be consistent."); } int up_pad = std::max(0, -context_start); int down_pad = std::max(0, context_start + context_length - 1); + int sequence_height, sequence_width; + int input_row_begin, input_row_end; paddle::operators::math::Col2ImFunctor< paddle::operators::math::ColFormat::kOCF, Place, float> col2im_ocf; for (int i = 0; i < static_cast(lod_g_level_0.size()) - 1; ++i) { - Tensor in_g_t = in_g->Slice(static_cast(lod_g_level_0[i]), - static_cast(lod_g_level_0[i + 1])); + input_row_begin = (context_start > 0) + ? static_cast(lod_g_level_0[i]) + context_start + : static_cast(lod_g_level_0[i]); + input_row_end = static_cast(lod_g_level_0[i + 1]); + Tensor out_g_t = out_g->Slice(static_cast(lod_g_level_0[i]), static_cast(lod_g_level_0[i + 1])); - int sequence_height = in_g_t.dims()[0]; - int sequence_width = in_g_t.dims()[1]; - - for (int j = 0; j < context_length; ++j) { - if (padding_trainable) { - out_g_t.Resize(framework::make_ddim( - {sequence_height * context_length, sequence_width})); - if (up_pad != 0) { - for (int k = 0; k < up_pad; ++k) { - Tensor out_t_sub = out_g_t.Slice( - k * context_length, k * context_length + (up_pad - k)); - Tensor w_sub = padding_data->Slice(k, context_length - k); - auto out_t_sub_e = EigenMatrix::From(out_t_sub); - auto w_sub_e = EigenMatrix::From(w_sub); - w_sub_e.device(place) = w_sub_e + out_t_sub_e; - // out_t_sub_e.device(place) = 0; - } + sequence_height = static_cast(out_g_t.dims()[0]); + sequence_width = static_cast(in_g->dims()[1]); + + if (padding_trainable) { + // add up trainable data + out_g_t.Resize(framework::make_ddim( + {sequence_height * context_length, sequence_width})); + + if (up_pad > 0) { // add up pad + int padding_rows = std::min( + up_pad, + static_cast(lod_g_level_0[i + 1] - lod_g_level_0[i])); + + for (int k = 0; k < padding_rows; ++k) { + int padding_size = + k + context_length < up_pad ? context_length : up_pad - k; + Tensor out_t_sub = out_g_t.Slice( + k * context_length, k * context_length + padding_size); + Tensor w_sub = padding_data_g->Slice(k, k + padding_size); + // in this block, using EigenVector::Flatten is ok too. + auto out_t_sub_e = EigenMatrix::From(out_t_sub); + auto w_sub_e = EigenMatrix::From(w_sub); + w_sub_e.device(place) = w_sub_e + out_t_sub_e; } - if (down_pad != 0) { - int k = - (sequence_height + up_pad - context_length) / context_stride + - 1; - for (int t = 0; t + k < sequence_height; ++t) { - Tensor out_t_sub = - out_g_t.Slice((k + t) * context_length * sequence_width - - t * sequence_width, - (k + t) * context_length * sequence_width); - Tensor w_sub = padding_data->Slice(up_pad + 1, up_pad + 1 + t); - auto out_t_sub_e = EigenMatrix::From(out_t_sub); - auto w_sub_e = EigenMatrix::From(w_sub); - w_sub_e.device(place) = w_sub_e + out_t_sub_e; - // out_t_sub_e.device(place) = 0; + } + if (down_pad > 0) { // add down pad + int down_pad_begin_row = + std::max(0, + (sequence_height - context_start - context_length) + 1) + + 1; + int padding_begin = std::max(0, context_start - sequence_height); + int padding_size = + sequence_height - context_start >= context_length + ? 1 + : context_length - (sequence_height - context_start); + if (context_start >= sequence_height) padding_size = context_length; + int padding_idx = padding_begin; + for (int t = 0; t + down_pad_begin_row <= sequence_height; + ++t, ++padding_size) { + if (context_start >= sequence_height) padding_size = context_length; + if (padding_size > context_length) { + padding_size = context_length; + padding_idx++; } + if (padding_begin > 0 || sequence_height == context_start) + padding_idx = padding_begin + t; + Tensor out_t_sub = out_g_t.Slice( + (down_pad_begin_row + t) * context_length - padding_size, + (down_pad_begin_row + t) * context_length); + Tensor w_sub = padding_data_g->Slice( + up_pad + padding_idx, up_pad + padding_idx + padding_size); + auto out_t_sub_e = EigenMatrix::From(out_t_sub); + auto w_sub_e = EigenMatrix::From(w_sub); + w_sub_e.device(place) = w_sub_e + out_t_sub_e; } } - out_g_t.Resize(framework::make_ddim( - {sequence_height, 1, 1, context_length, sequence_width})); + } + + if (in && input_row_begin < input_row_end) { + Tensor in_t = in_g->Slice(input_row_begin, input_row_end); - col2im_ocf(context.device_context(), in_g_t, out_g_t, + std::vector output_shape( + {sequence_height, 1, 1, context_length, + sequence_width}); // output_height, output_width, + // input_channels, filter_height, filter_width + out_g_t.Resize(framework::make_ddim(output_shape)); + + std::vector input_shape( + {1, input_row_end - input_row_begin, + sequence_width}); // input_channels, input_height, input_width + in_t.Resize(framework::make_ddim(input_shape)); + + col2im_ocf(context.device_context(), in_t, out_g_t, /*stride_height*/ context_stride, /*stride_width*/ 0, up_pad, down_pad); - - // out_g_t back to orign size } + + out_g_t.Resize(framework::make_ddim( + {sequence_height, context_length * sequence_width})); } } }; diff --git a/python/paddle/v2/framework/tests/test_seq_project.py b/python/paddle/v2/framework/tests/test_seq_project.py index 57e01e414db..4dbc02dbdd1 100644 --- a/python/paddle/v2/framework/tests/test_seq_project.py +++ b/python/paddle/v2/framework/tests/test_seq_project.py @@ -1,5 +1,6 @@ import unittest import numpy as np +import random from op_test import OpTest @@ -10,18 +11,22 @@ class TestSeqProject(OpTest): # one level, batch size x = np.random.uniform( 0.1, 1, [self.input_size[0], self.input_size[1]]).astype('float32') - lod = [[0, 4, 5, 8, self.input_size[0]]] self.begin_pad = np.max([0, -self.context_start]) self.end_pad = np.max([0, self.context_start + self.context_length - 1]) self.total_pad = self.begin_pad + self.end_pad - w = np.ones((self.total_pad, self.input_size[1])) * 100 - - self.inputs = {'X': (x, lod), 'PaddingData': w} + # w = np.ones((self.total_pad, self.input_size[1])) * 100 + w = np.array(range(self.total_pad * self.input_size[1])) + w.shape = self.total_pad, self.input_size[1] + self.inputs = { + 'X': (x, self.lod), + 'PaddingData': (w, [[0, self.total_pad]]) + } self.attrs = { 'context_start': self.context_start, 'context_length': self.context_length, - 'padding_trainable': self.padding_trainable + 'padding_trainable': self.padding_trainable, + 'context_stride': self.context_stride } out = np.zeros((self.input_size[0], self.input_size[1] * self.context_length)).astype('float32') @@ -30,9 +35,10 @@ class TestSeqProject(OpTest): def compute(self): x, lod = self.inputs['X'] - w = self.inputs['PaddingData'] + w, _ = self.inputs['PaddingData'] out = self.outputs['Out'] lod = lod[0] + begin_pad = np.max([0, -self.context_start]) for i in range(len(lod) - 1): for j in range(self.context_length): @@ -43,22 +49,20 @@ class TestSeqProject(OpTest): if in_begin < lod[i]: pad_size = np.min([lod[i] - in_begin, lod[i + 1] - lod[i]]) if self.padding_trainable: - sub_w = w[j:pad_size, :] + sub_w = w[j:j + pad_size, :] out[lod[i]:lod[i] + pad_size, j * self.input_size[1]:( j + 1) * self.input_size[1]] = sub_w - # pass out_begin = lod[i] + pad_size in_begin = lod[i] if in_end > lod[i + 1]: pad_size = np.min( [in_end - lod[i + 1], lod[i + 1] - lod[i]]) - out_sub = out[lod[i + 1] - pad_size:lod[i + 1], :] if self.padding_trainable: - sub_w = w[j - pad_size:j, :] + sub_w = w[begin_pad + self.context_start + j - pad_size: + begin_pad + self.context_start + j, :] out[lod[i + 1] - pad_size:lod[i + 1], j * self. input_size[1]:(j + 1) * self.input_size[1]] = sub_w - # pass in_end = lod[i + 1] out_end = lod[i + 1] - pad_size if in_end <= in_begin: @@ -69,28 +73,105 @@ class TestSeqProject(OpTest): self.input_size[1]] += in_sub def init_test_case(self): - self.input_size = [11, 23] + self.input_row = 11 + self.input_size = [self.input_row, 23] + self.lod = [[0, 4, 5, 8, self.input_row]] self.op_type = "sequence_project" self.context_start = -1 self.context_length = 3 - self.padding_trainable = False + self.padding_trainable = True + self.context_stride = 1 def test_check_output(self): self.check_output() # def test_check_grad(self): - # self.check_grad(["X"], "Out") + # self.check_grad( + # set(['X', 'PaddingData']), 'Out', max_relative_error=0.05) - # class TestSeqAvgPool2D(TestSeqProject): - # def init_test_case(self): - # self.input_size = [11, 23] - # self.op_type = "sequence_project" + # def test_check_grad_no_filter(self): + # self.check_grad( + # ['X'], + # 'Out', + # max_relative_error=0.05, + # no_grad_set=set(['PaddingData'])) # - # self.context_start = -1 - # self.context_length = 3 - # self.padding_trainable = True + # def test_check_grad_no_input(self): + # self.check_grad( + # ['PaddingData'], + # 'Out', + # max_relative_error=0.05, + # no_grad_set=set(['X'])) + + +''' +class TestSeqProjectCases(TestSeqProject): + def setUp(self): + self.init_test_case() + self.op_type = 'sequence_project' + + num = 0 + for context_start in [-5, -3, -1, 0, 3]: + for context_length in [1, 2, 5, 7]: + for batch_size in [1, 2, 5, 7]: + for padding_trainable in [False, True]: + + if context_length == 1 and context_start == 0 and padding_trainable: + continue + + self.context_start = context_start + self.context_length = context_length + self.padding_trainable = padding_trainable + self.input_size = [batch_size, 23] + x = np.random.uniform(0.1, 1, + self.input_size).astype('float32') + self.lod = [[0, self.input_size[0]]] + if self.input_size[0] > 2: + idx = range(self.input_size[0]) + del idx[0] + self.lod = [ + [0] + np.sort(random.sample(idx, 2)).tolist() + + [self.input_size[0]] + ] + + self.begin_pad = np.max([0, -self.context_start]) + self.end_pad = np.max( + [0, self.context_start + self.context_length - 1]) + self.total_pad = self.begin_pad + self.end_pad + # w = np.ones((self.total_pad, self.input_size[1])) * 100 + w = np.array(range(self.total_pad * self.input_size[1])) + w.shape = self.total_pad, self.input_size[1] + if self.total_pad * self.input_size[1] == 0: + w = np.random.uniform( + 0.1, 1, + (1, self.input_size[1])).astype('float32') + self.total_pad = 1 + + self.inputs = { + 'X': (x, self.lod), + 'PaddingData': (w, [[0, self.total_pad]]) + } + self.attrs = { + 'context_start': self.context_start, + 'context_length': self.context_length, + 'padding_trainable': self.padding_trainable, + 'context_stride': self.context_stride + } + out = np.zeros((self.input_size[0], self.input_size[1] * + self.context_length)).astype('float32') + self.outputs = {'Out': out} + print num + print self.attrs + print batch_size + print padding_trainable + print "$$$$$$$$$$$$$" + + self.compute() + self.test_check_output() + num += 1 +''' if __name__ == '__main__': unittest.main() -- GitLab