diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt index 4e951f6318cc9cdfb41bfa7db2f38fd0d38a6438..871240aa15fce0212e60e47d7f46861e304ee4ae 100644 --- a/paddle/fluid/operators/detection/CMakeLists.txt +++ b/paddle/fluid/operators/detection/CMakeLists.txt @@ -17,14 +17,15 @@ endfunction() if (WITH_ASCEND_CL) detection_library(box_coder_op SRCS box_coder_op.cc box_coder_op.cu box_coder_op_npu.cc) + detection_library(density_prior_box_op SRCS density_prior_box_op.cc density_prior_box_op.cu density_prior_box_op_npu.cc) else() detection_library(box_coder_op SRCS box_coder_op.cc box_coder_op.cu) + detection_library(density_prior_box_op SRCS density_prior_box_op.cc density_prior_box_op.cu) endif() detection_library(bipartite_match_op SRCS bipartite_match_op.cc) detection_library(mine_hard_examples_op SRCS mine_hard_examples_op.cc) detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu) -detection_library(density_prior_box_op SRCS density_prior_box_op.cc density_prior_box_op.cu) detection_library(anchor_generator_op SRCS anchor_generator_op.cc anchor_generator_op.cu) detection_library(target_assign_op SRCS target_assign_op.cc diff --git a/paddle/fluid/operators/detection/density_prior_box_op_npu.cc b/paddle/fluid/operators/detection/density_prior_box_op_npu.cc new file mode 100644 index 0000000000000000000000000000000000000000..cb58640056438bdaf7928d727a15cdc92cee2c65 --- /dev/null +++ b/paddle/fluid/operators/detection/density_prior_box_op_npu.cc @@ -0,0 +1,379 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/detection/density_prior_box_op.h" +#include "paddle/fluid/operators/npu_op_runner.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using fp16 = paddle::platform::float16; + +template +struct DensityPriorBoxFunction { + public: + explicit DensityPriorBoxFunction(const framework::ExecutionContext& ctx) + : ctx(ctx) { + place = ctx.GetPlace(); + stream = ctx.template device_context().stream(); + t0.mutable_data({1}, place); + t1.mutable_data({1}, place); + tn.mutable_data({1}, place); + FillNpuTensorWithConstant(&t0, static_cast(0)); + FillNpuTensorWithConstant(&t1, static_cast(1)); + } + void Arange(int n, Tensor* x) { + // x should be init first + FillNpuTensorWithConstant(&tn, static_cast(n)); + const auto& runner = NpuOpRunner("Range", {t0, tn, t1}, {*x}, {}); + runner.Run(stream); + } + void Add(const Tensor* x, const Tensor* y, Tensor* z) { + // z should be init first + const auto& runner = NpuOpRunner("AddV2", {*x, *y}, {*z}, {}); + runner.Run(stream); + } + void Cast(const Tensor* x, Tensor* y) { + auto dst_dtype = ConvertToNpuDtype(y->type()); + const auto& runner = NpuOpRunner( + "Cast", {*x}, {*y}, {{"dst_type", static_cast(dst_dtype)}}); + runner.Run(stream); + } + void Sub(const Tensor* x, const Tensor* y, Tensor* z) { + // z should be init first + const auto& runner = NpuOpRunner("Sub", {*x, *y}, {*z}, {}); + runner.Run(stream); + } + void Mul(const Tensor* x, const Tensor* y, Tensor* z) { + // y should be init first + const auto& runner = NpuOpRunner("Mul", {*x, *y}, {*z}, {}); + runner.Run(stream); + } + void Adds(const Tensor* x, float scalar, Tensor* y) { + // y should be init first + const auto& runner = NpuOpRunner("Adds", {*x}, {*y}, {{"value", scalar}}); + runner.Run(stream); + } + void Muls(const Tensor* x, float scalar, Tensor* y) { + // y should be init first + const auto& runner = NpuOpRunner("Muls", {*x}, {*y}, {{"value", scalar}}); + runner.Run(stream); + } + void Maximum(const Tensor* x, const Tensor* y, Tensor* z) { + // y should be init first + const auto& runner = NpuOpRunner("Maximum", {*x, *y}, {*z}, {}); + runner.Run(stream); + } + void Minimum(const Tensor* x, const Tensor* y, Tensor* z) { + // y should be init first + const auto& runner = NpuOpRunner("Minimum", {*x, *y}, {*z}, {}); + runner.Run(stream); + } + void Concat(const std::vector& inputs, int axis, Tensor* output) { + // output should be init first + std::vector names; + for (size_t i = 0; i < inputs.size(); i++) { + names.push_back("x" + std::to_string(i)); + } + NpuOpRunner runner{ + "ConcatD", + {inputs}, + {*output}, + {{"concat_dim", axis}, {"N", static_cast(inputs.size())}}}; + runner.AddInputNames(names); + runner.Run(stream); + } + void Tile(const Tensor* x, Tensor* y, const std::vector& multiples) { + // y should be init first + if (x->dims() == y->dims()) { + framework::TensorCopy( + *x, place, ctx.template device_context(), + y); + return; + } + const auto& runner = + NpuOpRunner("TileD", {*x}, {*y}, {{"multiples", multiples}}); + runner.Run(stream); + } + void FloatVec2Tsr(const std::vector& vec, Tensor* tsr_dst) { + // + framework::TensorFromVector(vec, ctx.device_context(), tsr_dst); + ctx.template device_context().Wait(); + } + + private: + platform::Place place; + aclrtStream stream; + const framework::ExecutionContext& ctx; + Tensor t0; + Tensor t1; + Tensor tn; +}; + +template <> +void DensityPriorBoxFunction::Arange(int n, Tensor* x) { + Tensor x_fp32(framework::proto::VarType::FP32); + x_fp32.mutable_data(x->dims(), place); + FillNpuTensorWithConstant(&tn, static_cast(n)); + const auto& runner = NpuOpRunner("Range", {t0, tn, t1}, {x_fp32}, {}); + runner.Run(stream); + Cast(&x_fp32, x); +} + +template <> +void DensityPriorBoxFunction::FloatVec2Tsr(const std::vector& vec, + Tensor* tsr_dst) { + Tensor tsr_fp32(framework::proto::VarType::FP32); + tsr_fp32.mutable_data(tsr_dst->dims(), place); + framework::TensorFromVector(vec, ctx.device_context(), &tsr_fp32); + ctx.template device_context().Wait(); + Cast(&tsr_fp32, tsr_dst); +} + +template +class DensityPriorBoxOpNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("Input"); + auto* image = ctx.Input("Image"); + auto* boxes = ctx.Output("Boxes"); + auto* vars = ctx.Output("Variances"); + + auto variances = ctx.Attr>("variances"); + auto clip = ctx.Attr("clip"); + + auto fixed_sizes = ctx.Attr>("fixed_sizes"); + auto fixed_ratios = ctx.Attr>("fixed_ratios"); + auto densities = ctx.Attr>("densities"); + + float step_w = ctx.Attr("step_w"); + float step_h = ctx.Attr("step_h"); + float offset = ctx.Attr("offset"); + + int image_w = image->dims()[3]; + int image_h = image->dims()[2]; + int layer_w = input->dims()[3]; + int layer_h = input->dims()[2]; + + auto _type = input->type(); + auto place = ctx.GetPlace(); + DensityPriorBoxFunction F(ctx); + + Tensor h(_type); + h.mutable_data({layer_h}, place); + Tensor w(_type); + w.mutable_data({layer_w}, place); + F.Arange(layer_h, &h); + F.Arange(layer_w, &w); + h.Resize({layer_h, 1, 1, 1}); + w.Resize({1, layer_w, 1, 1}); + + step_w = step_w > 0 ? step_w : static_cast(image_w) / layer_w; + step_h = step_h > 0 ? step_h : static_cast(image_h) / layer_h; + int step_average = static_cast((step_w + step_h) * 0.5); + + int ratios_size = fixed_ratios.size(); + int num_priors_per_ratio = 0; + for (size_t i = 0; i < densities.size(); ++i) { + num_priors_per_ratio += densities[i] * densities[i]; + } + Tensor di(_type); + Tensor dj(_type); + Tensor shifts(_type); + Tensor box_w_ratio(_type); + Tensor box_h_ratio(_type); + di.mutable_data({ratios_size * num_priors_per_ratio}, place); + dj.mutable_data({ratios_size * num_priors_per_ratio}, place); + shifts.mutable_data({ratios_size * num_priors_per_ratio}, place); + box_w_ratio.mutable_data({ratios_size * num_priors_per_ratio}, place); + box_h_ratio.mutable_data({ratios_size * num_priors_per_ratio}, place); + + int64_t start = 0; + std::vector vec_tile = {0, 0, 0}; + for (size_t i = 0; i < densities.size(); ++i) { + // Range = start:start+ratios_size*density_sqr, density = densities[i] + int density_sqr = densities[i] * densities[i]; + // shifts[Range] = [step_average/density]*ratios_size*density_sqr + Tensor shifts_part = + shifts.Slice(start, start + ratios_size * density_sqr); + FillNpuTensorWithConstant(&shifts_part, + static_cast(step_average / densities[i])); + + // di[Range] = [ i // density for i in range(density_sqr) ] * ratios_size + // dj[Range] = [ i % density for i in range(density_sqr) ] * ratios_size + Tensor di_part = di.Slice(start, start + ratios_size * density_sqr); + Tensor dj_part = dj.Slice(start, start + ratios_size * density_sqr); + if (densities[i] > 1) { + di_part.Resize({ratios_size, densities[i], densities[i]}); + dj_part.Resize({ratios_size, densities[i], densities[i]}); + Tensor range_n(_type); + range_n.mutable_data({densities[i]}, place); + F.Arange(densities[i], &range_n); + range_n.Resize({1, densities[i], 1}); + vec_tile[0] = ratios_size; + vec_tile[1] = 1; + vec_tile[2] = densities[i]; + F.Tile(&range_n, &di_part, vec_tile); + range_n.Resize({1, 1, densities[i]}); + vec_tile[1] = densities[i]; + vec_tile[2] = 1; + F.Tile(&range_n, &dj_part, vec_tile); + } else { + FillNpuTensorWithConstant(&di_part, static_cast(0)); + FillNpuTensorWithConstant(&dj_part, static_cast(0)); + } + + int start_box_ratio = start; + for (float ar : fixed_ratios) { + // Range_mini = start_box_ratio:start_box_ratio+density_sqr + // box_h_ratio[Range_mini] = [fixed_sizes[i] * sqrt(ar)] * density_sqr + // box_w_ratio[Range_mini] = [fixed_sizes[i] / sqrt(ar)] * density_sqr + Tensor box_h_ratio_part = + box_h_ratio.Slice(start_box_ratio, start_box_ratio + density_sqr); + Tensor box_w_ratio_part = + box_w_ratio.Slice(start_box_ratio, start_box_ratio + density_sqr); + FillNpuTensorWithConstant(&box_w_ratio_part, + static_cast(fixed_sizes[i] * sqrt(ar))); + FillNpuTensorWithConstant(&box_h_ratio_part, + static_cast(fixed_sizes[i] / sqrt(ar))); + start_box_ratio += density_sqr; + } + start = start_box_ratio; + } + di.Resize({1, 1, ratios_size * num_priors_per_ratio, 1}); + dj.Resize({1, 1, ratios_size * num_priors_per_ratio, 1}); + shifts.Resize({1, 1, ratios_size * num_priors_per_ratio, 1}); + box_w_ratio.Resize({1, 1, ratios_size * num_priors_per_ratio, 1}); + box_h_ratio.Resize({1, 1, ratios_size * num_priors_per_ratio, 1}); + + // c_x = (w+offset)*step_w - 0.5*step_average + 0.5*shifts + dj*shifts + // c_y = (h+offset)*step_h - 0.5*step_average + 0.5*shifts + di*shifts + Tensor c_x(_type); + Tensor c_y(_type); + auto dim0 = framework::make_ddim( + {1, layer_w, ratios_size * num_priors_per_ratio, 1}); + auto dim1 = framework::make_ddim( + {layer_h, 1, ratios_size * num_priors_per_ratio, 1}); + c_x.mutable_data(dim0, place); + c_y.mutable_data(dim1, place); + F.Adds(&w, offset, &w); + F.Muls(&w, step_w, &w); + F.Adds(&w, static_cast(-step_average) * static_cast(0.5), &w); + F.Adds(&h, offset, &h); + F.Muls(&h, step_h, &h); + F.Adds(&h, static_cast(-step_average) * static_cast(0.5), &h); + F.Mul(&di, &shifts, &di); + F.Mul(&dj, &shifts, &dj); + F.Muls(&shifts, static_cast(0.5), &shifts); + F.Add(&di, &shifts, &di); + F.Add(&dj, &shifts, &dj); + F.Add(&dj, &w, &c_x); + F.Add(&di, &h, &c_y); + + // box_w_ratio = box_w_ratio / 2 + // box_h_ratio = box_h_ratio / 2 + F.Muls(&box_w_ratio, static_cast(0.5), &box_w_ratio); + F.Muls(&box_h_ratio, static_cast(0.5), &box_h_ratio); + + Tensor zero_t(_type); + Tensor one_t(_type); + zero_t.mutable_data({1}, place); + one_t.mutable_data({1}, place); + FillNpuTensorWithConstant(&zero_t, static_cast(0)); + FillNpuTensorWithConstant(&one_t, static_cast(1)); + + Tensor outbox0(_type); + Tensor outbox1(_type); + Tensor outbox2(_type); + Tensor outbox3(_type); + outbox0.mutable_data(dim0, place); + outbox1.mutable_data(dim1, place); + outbox2.mutable_data(dim0, place); + outbox3.mutable_data(dim1, place); + + // outbox0 = max ( (c_x - box_w_ratio)/image_w, 0 ) + // outbox1 = max ( (c_y - box_h_ratio)/image_h, 0 ) + // outbox2 = min ( (c_x + box_w_ratio)/image_w, 1 ) + // outbox3 = min ( (c_y + box_h_ratio)/image_h, 1 ) + F.Sub(&c_x, &box_w_ratio, &outbox0); + F.Sub(&c_y, &box_h_ratio, &outbox1); + F.Add(&c_x, &box_w_ratio, &outbox2); + F.Add(&c_y, &box_h_ratio, &outbox3); + F.Muls(&outbox0, static_cast(1.0 / image_w), &outbox0); + F.Muls(&outbox1, static_cast(1.0 / image_h), &outbox1); + F.Muls(&outbox2, static_cast(1.0 / image_w), &outbox2); + F.Muls(&outbox3, static_cast(1.0 / image_h), &outbox3); + + F.Maximum(&outbox0, &zero_t, &outbox0); + F.Maximum(&outbox1, &zero_t, &outbox1); + F.Minimum(&outbox2, &one_t, &outbox2); + F.Minimum(&outbox3, &one_t, &outbox3); + if (clip) { + // outbox0 = min ( outbox0, 1 ) + // outbox1 = min ( outbox1, 1 ) + // outbox2 = max ( outbox2, 0 ) + // outbox3 = max ( outbox3, 0 ) + F.Minimum(&outbox0, &one_t, &outbox0); + F.Minimum(&outbox1, &one_t, &outbox1); + F.Maximum(&outbox2, &zero_t, &outbox2); + F.Maximum(&outbox3, &zero_t, &outbox3); + } + + auto out_dim = framework::make_ddim( + {layer_h, layer_w, ratios_size * num_priors_per_ratio, 4}); + boxes->mutable_data(place); + vars->mutable_data(place); + Tensor boxes_share(_type); + Tensor vars_share(_type); + boxes_share.ShareDataWith(*boxes); + boxes_share.Resize(out_dim); + vars_share.ShareDataWith(*vars); + vars_share.Resize(out_dim); + + Tensor box0(_type); + Tensor box1(_type); + Tensor box2(_type); + Tensor box3(_type); + // out_dim = {layer_h, layer_w, ratios_size*num_priors_per_ratio, 1} + out_dim[3] = 1; + box0.mutable_data(out_dim, place); + box1.mutable_data(out_dim, place); + box2.mutable_data(out_dim, place); + box3.mutable_data(out_dim, place); + + std::vector vec_exp_out02 = {layer_h, 1, 1, 1}; + std::vector vec_exp_out13 = {1, layer_w, 1, 1}; + F.Tile(&outbox0, &box0, vec_exp_out02); + F.Tile(&outbox1, &box1, vec_exp_out13); + F.Tile(&outbox2, &box2, vec_exp_out02); + F.Tile(&outbox3, &box3, vec_exp_out13); + F.Concat({box0, box1, box2, box3}, 3, &boxes_share); + + std::vector multiples = {layer_h, layer_w, + ratios_size * num_priors_per_ratio, 1}; + Tensor variances_t(_type); + // variances.size() == 4 + variances_t.mutable_data({4}, place); + F.FloatVec2Tsr(variances, &variances_t); + F.Tile(&variances_t, &vars_share, multiples); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_NPU_KERNEL(density_prior_box, + ops::DensityPriorBoxOpNPUKernel, + ops::DensityPriorBoxOpNPUKernel); diff --git a/python/paddle/fluid/tests/unittests/npu/test_density_prior_box_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_density_prior_box_op_npu.py new file mode 100644 index 0000000000000000000000000000000000000000..a190aa9b6f2be543720ee83422ce261531f2a212 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_density_prior_box_op_npu.py @@ -0,0 +1,196 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import sys +sys.path.append("..") +import math +import paddle +from op_test import OpTest + +paddle.enable_static() + +np.random.seed(2021) + + +class TestNpuDensityPriorBoxOp(OpTest): + def set_data(self): + self.init_test_params() + self.init_test_input() + self.init_test_output() + #self.init_test_output2() + self.inputs = {'Input': self.input, 'Image': self.image} + + self.attrs = { + 'variances': self.variances, + 'clip': self.clip, + 'step_w': self.step_w, + 'step_h': self.step_h, + 'offset': self.offset, + 'densities': self.densities, + 'fixed_sizes': self.fixed_sizes, + 'fixed_ratios': self.fixed_ratios, + 'flatten_to_2d': self.flatten_to_2d + } + self.outputs = {'Boxes': self.out_boxes, 'Variances': self.out_var} + + def test_check_output(self): + self.check_output_with_place(self.place, atol=self.atol) + + def setUp(self): + self.__class__.use_npu = True + self.op_type = 'density_prior_box' + self.place = paddle.NPUPlace(0) + self.init_dtype() + self.set_data() + + def init_dtype(self): + self.dtype = np.float32 + + def set_density(self): + self.densities = [4, 2, 1] + self.fixed_sizes = [32.0, 64.0, 128.0] + self.fixed_ratios = [1.0] + self.layer_w = 17 + self.layer_h = 17 + self.image_w = 533 + self.image_h = 533 + self.flatten_to_2d = False + + def init_test_params(self): + self.set_density() + + self.step_w = float(self.image_w) / float(self.layer_w) + self.step_h = float(self.image_h) / float(self.layer_h) + + self.input_channels = 2 + self.image_channels = 3 + self.batch_size = 10 + + self.variances = [0.1, 0.1, 0.2, 0.2] + self.variances = np.array(self.variances, dtype=np.float).flatten() + + self.clip = True + self.num_priors = 0 + if len(self.fixed_sizes) > 0 and len(self.densities) > 0: + for density in self.densities: + if len(self.fixed_ratios) > 0: + self.num_priors += len(self.fixed_ratios) * (pow(density, + 2)) + self.offset = 0.5 + self.atol = 1e-5 + + def init_test_input(self): + self.image = np.random.random( + (self.batch_size, self.image_channels, self.image_h, + self.image_w)).astype(self.dtype) + + self.input = np.random.random( + (self.batch_size, self.input_channels, self.layer_h, + self.layer_w)).astype(self.dtype) + + def init_test_output(self): + out_dim = (self.layer_h, self.layer_w, self.num_priors, 4) + out_boxes = np.zeros(out_dim).astype(self.dtype) + out_var = np.zeros(out_dim).astype(self.dtype) + + step_average = int((self.step_w + self.step_h) * 0.5) + for h in range(self.layer_h): + for w in range(self.layer_w): + idx = 0 + c_x = (w + self.offset) * self.step_w + c_y = (h + self.offset) * self.step_h + # Generate density prior boxes with fixed size + for density, fixed_size in zip(self.densities, + self.fixed_sizes): + if (len(self.fixed_ratios) > 0): + for ar in self.fixed_ratios: + shift = int(step_average / density) + box_width_ratio = fixed_size * math.sqrt(ar) + box_height_ratio = fixed_size / math.sqrt(ar) + for di in range(density): + for dj in range(density): + c_x_temp = c_x - step_average / 2.0 + shift / 2.0 + dj * shift + c_y_temp = c_y - step_average / 2.0 + shift / 2.0 + di * shift + out_boxes[h, w, idx, :] = [ + max((c_x_temp - box_width_ratio / 2.0) / + self.image_w, 0), + max((c_y_temp - box_height_ratio / 2.0) + / self.image_h, 0), + min((c_x_temp + box_width_ratio / 2.0) / + self.image_w, 1), + min((c_y_temp + box_height_ratio / 2.0) + / self.image_h, 1) + ] + idx += 1 + if self.clip: + out_boxes = np.clip(out_boxes, 0.0, 1.0) + out_var = np.tile(self.variances, + (self.layer_h, self.layer_w, self.num_priors, 1)) + self.out_boxes = out_boxes.astype(self.dtype) + self.out_var = out_var.astype(self.dtype) + if self.flatten_to_2d: + self.out_boxes = self.out_boxes.reshape((-1, 4)) + self.out_var = self.out_var.reshape((-1, 4)) + + +class TestNpuDensityPriorBoxFlatten(TestNpuDensityPriorBoxOp): + def set_density(self): + self.densities = [3, 4] + self.fixed_sizes = [1.0, 2.0] + self.fixed_ratios = [1.0] + self.layer_w = 32 + self.layer_h = 32 + self.image_w = 40 + self.image_h = 40 + self.flatten_to_2d = True + + +class TestNpuDensityPriorBoxOp1(TestNpuDensityPriorBoxOp): + def set_density(self): + super(TestNpuDensityPriorBoxOp1, self).set_density() + self.layer_w = 1 + self.layer_h = 1 + + +class TestNpuDensityPriorBoxOp2(TestNpuDensityPriorBoxOp): + def set_density(self): + super(TestNpuDensityPriorBoxOp2, self).set_density() + self.layer_w = 15 + self.layer_h = 17 + self.image_w = 533 + self.image_h = 532 + + +class TestNpuDensityPriorBoxOp3(TestNpuDensityPriorBoxOp): + def set_density(self): + super(TestNpuDensityPriorBoxOp3, self).set_density() + self.fixed_ratios = [1.0, 4.0] + + +class TestNpuDensityPriorBoxOpFP16(TestNpuDensityPriorBoxOp): + def init_dtype(self): + self.dtype = np.float16 + + def init_test_params(self): + super(TestNpuDensityPriorBoxOpFP16, self).init_test_params() + self.atol = 1e-3 + self.clip = False + + +if __name__ == '__main__': + unittest.main()