未验证 提交 a9dfebb9 编写于 作者: H heliqi 提交者: GitHub

[NPU]add conv2d_transpose npu op (#35232)

* add conv2d_transpose npu op

* CopyRight 2020 to 2021

* add fp32

* delete repeat test case

* delete repeat test case

* fix paddle.NPUPlace
上级 8305ba37
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/conv_transpose_op.h"
#include "paddle/fluid/operators/npu_op_runner.h"
namespace paddle {
namespace operators {
template <typename T>
class Conv2DTransposeNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
// input
const Tensor* input = context.Input<Tensor>("Input");
const Tensor* filter = context.Input<Tensor>("Filter");
// output
Tensor* output = context.Output<Tensor>("Output");
output->mutable_data<T>(context.GetPlace());
// attr
std::vector<int> output_padding =
context.Attr<std::vector<int>>("output_padding");
const std::vector<int> stride = context.Attr<std::vector<int>>("strides");
std::vector<int> padding = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilation = context.Attr<std::vector<int>>("dilations");
const std::string data_format = context.Attr<std::string>("data_format");
int groups = context.Attr<int>("groups");
const std::string padding_algorithm =
context.Attr<std::string>("padding_algorithm");
// npu stream
auto stream =
context.template device_context<platform::NPUDeviceContext>().stream();
// check dimension
const bool channel_last = data_format == "NHWC";
// update padding and dilation
auto in_dims = input->dims();
auto filter_dims = filter->dims();
framework::DDim in_data_dims;
framework::DDim filter_data_dims;
if (channel_last) {
in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
} else {
in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
}
filter_data_dims = framework::slice_ddim(filter_dims, 2, in_dims.size());
std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&padding, &dilation, padding_algorithm,
in_data_dims, stride, ksize);
// construct NPU attr
std::vector<int> strides(4, 1);
std::vector<int> dilations(4, 1);
Tensor input_tensor, output_tensor;
input_tensor.ShareDataWith(*input);
output_tensor.ShareDataWith(*output);
if (channel_last) {
input_tensor.set_layout(DataLayout::kNHWC);
output_tensor.set_layout(DataLayout::kNHWC);
strides[1] = stride[0];
strides[2] = stride[1];
dilations[1] = dilation[0];
dilations[2] = dilation[1];
} else {
strides[2] = stride[0];
strides[3] = stride[1];
dilations[2] = dilation[0];
dilations[3] = dilation[1];
}
for (auto i = output_padding.size(); i < 4; ++i) {
output_padding.insert(output_padding.begin(), 0);
}
auto output_dim_vec = framework::vectorize(output_tensor.dims());
// CANN OP
const auto& runner =
NpuOpRunner("Conv2DTransposeD", {input_tensor, *filter},
{output_tensor}, {{"input_size", output_dim_vec},
{"strides", strides},
{"dilations", dilations},
{"output_padding", output_padding},
{"groups", groups},
{"pads", padding},
{"data_format", data_format}});
runner.Run(stream);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
// conv2d
REGISTER_OP_NPU_KERNEL(conv2d_transpose, ops::Conv2DTransposeNPUKernel<float>,
ops::Conv2DTransposeNPUKernel<plat::float16>);
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
sys.path.append("..")
import paddle
import paddle.nn as nn
import paddle.fluid.core as core
import paddle.fluid as fluid
from op_test import OpTest, skip_check_grad_ci
from test_conv2d_transpose_op import conv2dtranspose_forward_naive
paddle.enable_static()
@skip_check_grad_ci(
reason='''Inference only, it doesn't need to call check_grad.''')
class TestConv2DTransposeOp(OpTest):
def set_npu(self):
self.__class__.use_npu = True
self.place = paddle.NPUPlace(0)
def init_dtype(self):
self.dtype = np.float16
def init_data_format(self):
self.data_format = "NCHW"
def setUp(self):
self.init_op_type()
self.init_dtype()
self.set_npu()
self.init_data_format()
self.output_padding = []
self.pad = [0, 0]
self.padding_algorithm = "EXPLICIT"
self.init_test_case()
self.output_size = None
input_ = np.random.random(self.input_size).astype(self.dtype)
filter_ = np.random.random(self.filter_size).astype(self.dtype)
self.inputs = {'Input': input_, 'Filter': filter_}
self.attrs = {
'strides': self.stride,
'paddings': self.pad,
'padding_algorithm': self.padding_algorithm,
'groups': self.groups,
'dilations': self.dilations,
'use_cudnn': False,
'is_test': False,
'use_mkldnn': False,
'data_format': self.data_format
}
if self.output_size is not None:
self.attrs['output_size'] = self.output_size
if len(self.output_padding) > 0:
self.attrs['output_padding'] = self.output_padding
output = conv2dtranspose_forward_naive(input_, filter_,
self.attrs).astype(self.dtype)
self.outputs = {'Output': output}
def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-2)
def init_test_case(self):
self.pad = [0, 0]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 3, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3]
def init_op_type(self):
self.op_type = "conv2d_transpose"
class TestWithSymmetricPad_FP32(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 3, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3]
def init_dtype(self):
self.dtype = np.float32
class TestWithSymmetricPad(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 3, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3]
class TestWithAsymmetricPad_FP32(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [1, 0, 1, 2]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 3, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3]
def init_dtype(self):
self.dtype = np.float32
class TestWithAsymmetricPad(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [1, 0, 1, 2]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 3, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3]
class TestWithSAMEPad_FP32(TestConv2DTransposeOp):
def init_test_case(self):
self.stride = [2, 1]
self.dilations = [1, 2]
self.groups = 1
self.input_size = [2, 3, 6, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 4, 3]
self.padding_algorithm = 'SAME'
def init_dtype(self):
self.dtype = np.float32
class TestWithSAMEPad(TestConv2DTransposeOp):
def init_test_case(self):
self.stride = [2, 1]
self.dilations = [1, 2]
self.groups = 1
self.input_size = [2, 3, 6, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 4, 3]
self.padding_algorithm = 'SAME'
class TestWithVALIDPad_FP32(TestConv2DTransposeOp):
def init_test_case(self):
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 3, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3]
self.padding_algorithm = 'VALID'
def init_dtype(self):
self.dtype = np.float32
class TestWithVALIDPad(TestConv2DTransposeOp):
def init_test_case(self):
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 3, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3]
self.padding_algorithm = 'VALID'
class TestWithGroups_FP32(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 2
self.input_size = [2, 4, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 3, 3, 3]
def init_dtype(self):
self.dtype = np.float32
class TestWithGroups(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 2
self.input_size = [2, 4, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 3, 3, 3]
class TestWithStride_FP32(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [2, 2]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 3, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3]
def init_dtype(self):
self.dtype = np.float32
class TestWithStride(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [2, 2]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 3, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3]
class TestWithDilation_FP32(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
self.groups = 1
self.dilations = [2, 2]
self.input_size = [2, 3, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3]
def init_dtype(self):
self.dtype = np.float32
class TestWithDilation(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
self.groups = 1
self.dilations = [2, 2]
self.input_size = [2, 3, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3]
class TestWithEvenUpsample_FP32(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [2, 2]
self.stride = [2, 2]
self.groups = 1
self.dilations = [1, 1]
self.output_size = [14, 14]
self.input_size = [2, 3, 7, 7] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 5, 5]
def init_dtype(self):
self.dtype = np.float32
class TestWithEvenUpsample(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [2, 2]
self.stride = [2, 2]
self.groups = 1
self.dilations = [1, 1]
self.output_size = [14, 14]
self.input_size = [2, 3, 7, 7] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 5, 5]
class TestWithEvenUpsampleOutputPadding_FP32(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [2, 2]
self.stride = [2, 2]
self.groups = 1
self.dilations = [1, 1]
self.output_padding = [1, 1]
self.input_size = [2, 3, 7, 7] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 5, 5]
def init_dtype(self):
self.dtype = np.float32
class TestWithEvenUpsampleOutputPadding(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [2, 2]
self.stride = [2, 2]
self.groups = 1
self.dilations = [1, 1]
self.output_padding = [1, 1]
self.input_size = [2, 3, 7, 7] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 5, 5]
class Test_NHWC_FP32(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [0, 0]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 5, 5, 3] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3]
self.data_format = 'NHWC'
def init_dtype(self):
self.dtype = np.float32
class Test_NHWC(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [0, 0]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 5, 5, 3] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3]
self.data_format = 'NHWC'
class TestWithSymmetricPad_NHWC_FP32(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 5, 5, 3] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3]
self.data_format = 'NHWC'
def init_dtype(self):
self.dtype = np.float32
class TestWithSymmetricPad_NHWC(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 5, 5, 3] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3]
self.data_format = 'NHWC'
class TestWithAsymmetricPad_NHWC_FP32(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [1, 0, 1, 2]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 5, 5, 3] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3]
self.data_format = 'NHWC'
def init_dtype(self):
self.dtype = np.float32
class TestWithAsymmetricPad_NHWC(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [1, 0, 1, 2]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 5, 5, 3] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3]
self.data_format = 'NHWC'
class TestWithGroups_NHWC_FP32(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 2
self.input_size = [2, 5, 5, 4] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 3, 3, 3]
self.data_format = 'NHWC'
def init_dtype(self):
self.dtype = np.float32
class TestWithGroups_NHWC(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 2
self.input_size = [2, 5, 5, 4] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 3, 3, 3]
self.data_format = 'NHWC'
class TestWithStride_NHWC_FP32(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [2, 2]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 5, 5, 3] # NCHW
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3]
self.data_format = 'NHWC'
def init_dtype(self):
self.dtype = np.float32
class TestWithStride_NHWC(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [2, 2]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 5, 5, 3] # NCHW
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3]
self.data_format = 'NHWC'
class TestWithDilation_NHWC_FP32(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
self.groups = 1
self.dilations = [2, 2]
self.input_size = [2, 5, 5, 3] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3]
self.data_format = 'NHWC'
def init_dtype(self):
self.dtype = np.float32
class TestWithDilation_NHWC(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
self.groups = 1
self.dilations = [2, 2]
self.input_size = [2, 5, 5, 3] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3]
self.data_format = 'NHWC'
class TestWithEvenUpsample_NHWC_FP32(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [2, 2]
self.stride = [2, 2]
self.groups = 1
self.dilations = [1, 1]
self.output_size = [14, 14]
self.input_size = [2, 7, 7, 3] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 5, 5]
self.data_format = 'NHWC'
def init_dtype(self):
self.dtype = np.float32
class TestWithEvenUpsample_NHWC(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [2, 2]
self.stride = [2, 2]
self.groups = 1
self.dilations = [1, 1]
self.output_size = [14, 14]
self.input_size = [2, 7, 7, 3] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 5, 5]
self.data_format = 'NHWC'
class TestWithEvenUpsample_NHWC_output_padding_FP32(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [2, 2]
self.stride = [2, 2]
self.groups = 1
self.dilations = [1, 1]
self.output_padding = [1, 1]
self.input_size = [2, 7, 7, 3] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 5, 5]
self.data_format = 'NHWC'
def init_dtype(self):
self.dtype = np.float32
class TestWithEvenUpsample_NHWC_output_padding(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [2, 2]
self.stride = [2, 2]
self.groups = 1
self.dilations = [1, 1]
self.output_padding = [1, 1]
self.input_size = [2, 7, 7, 3] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 5, 5]
self.data_format = 'NHWC'
class TestConv2DTransposeAPI(unittest.TestCase):
def test_case1(self):
data1 = fluid.layers.data(
name='data1', shape=[3, 5, 5], dtype='float32')
data2 = fluid.layers.data(
name='data2', shape=[5, 5, 3], dtype='float32')
out1 = fluid.layers.conv2d_transpose(
input=data1,
groups=1,
num_filters=6,
filter_size=3,
data_format='NCHW')
out2 = fluid.layers.conv2d_transpose(
input=data2,
groups=1,
num_filters=6,
filter_size=3,
data_format='NHWC')
out3 = fluid.layers.conv2d_transpose(
input=data1,
groups=1,
num_filters=6,
filter_size=3,
padding=[[0, 0], [1, 1], [1, 1], [0, 0]],
data_format='NHWC')
out4 = fluid.layers.conv2d_transpose(
input=data1,
groups=3,
num_filters=6,
filter_size=3,
padding=[[0, 0], [0, 0], [2, 1], [0, 0]],
data_format='NCHW')
out5 = fluid.layers.conv2d_transpose(
input=data2,
groups=1,
num_filters=6,
filter_size=3,
padding='SAME',
data_format='NCHW')
out6 = fluid.layers.conv2d_transpose(
input=data1,
groups=1,
num_filters=6,
filter_size=3,
padding='VALID',
data_format='NHWC')
out7 = fluid.layers.conv2d_transpose(
input=data1,
groups=1,
num_filters=6,
output_size=[7, 7],
padding=[0, 0],
data_format='NHWC')
data1_np = np.random.random((2, 3, 5, 5)).astype("float32")
data2_np = np.random.random((2, 5, 5, 3)).astype("float32")
place = core.NPUPlace(0)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
results = exe.run(
fluid.default_main_program(),
feed={"data1": data1_np,
"data2": data2_np},
fetch_list=[out1, out2, out3, out4, out5, out6, out7],
return_numpy=True)
self.assertIsNotNone(results[0])
self.assertIsNotNone(results[1])
self.assertIsNotNone(results[2])
self.assertIsNotNone(results[3])
self.assertIsNotNone(results[4])
self.assertIsNotNone(results[5])
self.assertIsNotNone(results[6])
class TestConv2DTransposeOpException(unittest.TestCase):
def test_exception(self):
data = fluid.layers.data(name='data', shape=[3, 5, 5], dtype="float32")
def attr_data_format():
out = fluid.layers.conv2d_transpose(
input=data,
groups=1,
num_filters=6,
filter_size=3,
data_format="NCDHW")
self.assertRaises(ValueError, attr_data_format)
def attr_padding_str():
out = fluid.layers.conv2d_transpose(
input=data,
groups=1,
num_filters=6,
filter_size=3,
padding='Vald')
self.assertRaises(ValueError, attr_padding_str)
def attr_padding_list():
out = fluid.layers.conv2d_transpose(
input=data,
groups=1,
num_filters=6,
filter_size=3,
padding=[[1, 1], [1, 1], [0, 0], [0, 0]])
self.assertRaises(ValueError, attr_padding_list)
def attr_padding_with_data_format():
out = fluid.layers.conv2d_transpose(
input=data,
groups=1,
num_filters=6,
filter_size=3,
padding=[[1, 1], [0, 0], [0, 0], [1, 1]],
data_format='NHWC')
self.assertRaises(ValueError, attr_padding_with_data_format)
error_input = fluid.layers.data(
name='error_data', shape=[1], dtype="float32")
def error_input_size():
out = fluid.layers.conv2d_transpose(
input=error_input, groups=1, num_filters=6, filter_size=3)
self.assertRaises(ValueError, error_input_size)
def error_groups():
out = fluid.layers.conv2d_transpose(
input=data,
groups=0,
num_filters=6,
filter_size=3,
data_format='NHWC')
self.assertRaises(ValueError, error_groups)
class TestConv2DTransposeRepr(unittest.TestCase):
def test_case(self):
paddle.disable_static(paddle.NPUPlace(0))
x_var = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.)
conv = nn.Conv2DTranspose(4, 6, (3, 3), output_padding=1, stride=2)
print(conv)
y_var = conv(x_var)
y_np = y_var.numpy()
self.assertIsNotNone(y_np)
paddle.enable_static()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册