Unverified commit b97af7d0, authored by Qi Li, committed by GitHub

[NPU] conv2d_transpose_grad, test=develop (#35553)

* [NPU] conv2d_transpose_grad, test=develop

* remove debug files, test=develop

* fix bug, test=develop

* [NPU] fix test_conv2d_transpose_op_npu, test=develop
Parent 8528dd9f
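For reference, a minimal dygraph sketch of how the kernels in this change are exercised. This snippet is illustrative only and is not part of the commit; it assumes a Paddle build with Ascend NPU support so that paddle.NPUPlace(0) is available.

import numpy as np
import paddle
import paddle.nn.functional as F

# Illustrative only: on an NPU build, the forward below dispatches to
# Conv2DTransposeNPUKernel and the backward to the Conv2DTransposeGradNPUKernel
# added by this commit.
paddle.disable_static(paddle.NPUPlace(0))

x = paddle.to_tensor(
    np.random.random([2, 3, 5, 5]).astype("float32"), stop_gradient=False)
w = paddle.to_tensor(
    np.random.random([3, 6, 3, 3]).astype("float32"), stop_gradient=False)

y = F.conv2d_transpose(x, w, stride=1, padding=0)  # conv2d_transpose forward
y.sum().backward()                                 # conv2d_transpose_grad backward
print(x.grad.shape, w.grad.shape)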
@@ -145,7 +145,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
              "output_size of Op(ConvTransposeOp) should not be "
              "less than the infered output size. But received output_size = "
              "[%s], whose dim %d is less than the infered output size [%s]",
-             framework::make_ddim(output_size), i, infer_shape));
+             framework::make_ddim(output_size).to_str(), i, infer_shape));
      PADDLE_ENFORCE_LT(
          output_size[i], infer_shape + strides[i],
          platform::errors::InvalidArgument(
@@ -153,8 +153,8 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
              "than infered size + stride. But received output_size = [%s], "
              "whose dim %d is not less than the infered output size (%d) + "
              "stride (%d) = %d",
-             framework::make_ddim(output_size), i, infer_shape, strides[i],
-             infer_shape + strides[i]));
+             framework::make_ddim(output_size).to_str(), i, infer_shape,
+             strides[i], infer_shape + strides[i]));
        }
        output_shape.push_back(output_size[i]);
      } else if (output_padding.size()) {
@@ -165,7 +165,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
              "output_padding of Op(ConvTransposeOp) should not be "
              "less than the 0. But received output_padding = "
              "[%s], whose dim %d is less than 0",
-             framework::make_ddim(output_padding), i));
+             framework::make_ddim(output_padding).to_str(), i));
      PADDLE_ENFORCE_LT(
          output_padding[i], std::max(strides[i], dilations[i]),
          platform::errors::InvalidArgument(
@@ -174,7 +174,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
              "[%s], "
              "whose dim %d is not less than either stride (%d) or "
              "dilation (%d)",
-             framework::make_ddim(output_size), i, strides[i],
+             framework::make_ddim(output_size).to_str(), i, strides[i],
              dilations[i]));
      }
      output_shape.push_back((infer_shape + output_padding[i]));
...
@@ -18,30 +18,25 @@ limitations under the License. */
namespace paddle {
namespace operators {

+using NPUDeviceContext = platform::NPUDeviceContext;
+
template <typename T>
class Conv2DTransposeNPUKernel : public framework::OpKernel<T> {
 public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    // input
-    const Tensor* input = context.Input<Tensor>("Input");
-    const Tensor* filter = context.Input<Tensor>("Filter");
-    // output
-    Tensor* output = context.Output<Tensor>("Output");
-    output->mutable_data<T>(context.GetPlace());
-    // attr
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const Tensor* input = ctx.Input<Tensor>("Input");
+    const Tensor* filter = ctx.Input<Tensor>("Filter");
+    Tensor* output = ctx.Output<Tensor>("Output");
+    output->mutable_data<T>(ctx.GetPlace());
+
    std::vector<int> output_padding =
-        context.Attr<std::vector<int>>("output_padding");
-    const std::vector<int> stride = context.Attr<std::vector<int>>("strides");
-    std::vector<int> padding = context.Attr<std::vector<int>>("paddings");
-    std::vector<int> dilation = context.Attr<std::vector<int>>("dilations");
-    const std::string data_format = context.Attr<std::string>("data_format");
-    int groups = context.Attr<int>("groups");
+        ctx.Attr<std::vector<int>>("output_padding");
+    const std::vector<int> stride = ctx.Attr<std::vector<int>>("strides");
+    std::vector<int> padding = ctx.Attr<std::vector<int>>("paddings");
+    std::vector<int> dilation = ctx.Attr<std::vector<int>>("dilations");
+    const std::string data_format = ctx.Attr<std::string>("data_format");
+    int groups = ctx.Attr<int>("groups");
    const std::string padding_algorithm =
-        context.Attr<std::string>("padding_algorithm");
-    // npu stream
-    auto stream =
-        context.template device_context<platform::NPUDeviceContext>().stream();
+        ctx.Attr<std::string>("padding_algorithm");

    // check dimension
    const bool channel_last = data_format == "NHWC";
@@ -89,7 +84,8 @@ class Conv2DTransposeNPUKernel : public framework::OpKernel<T> {
      output_padding.insert(output_padding.begin(), 0);
    }
    auto output_dim_vec = framework::vectorize(output_tensor.dims());
-    // CANN OP
+
+    auto stream = ctx.template device_context<NPUDeviceContext>().stream();
    const auto& runner =
        NpuOpRunner("Conv2DTransposeD", {input_tensor, *filter},
                    {output_tensor}, {{"input_size", output_dim_vec},
@@ -103,12 +99,109 @@ class Conv2DTransposeNPUKernel : public framework::OpKernel<T> {
  }
};
+template <typename T>
+class Conv2DTransposeGradNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const Tensor* input = ctx.Input<Tensor>("Input");
+    const Tensor* filter = ctx.Input<Tensor>("Filter");
+    const Tensor* output_grad =
+        ctx.Input<Tensor>(framework::GradVarName("Output"));
+    Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
+    Tensor* filter_grad = ctx.Output<Tensor>(framework::GradVarName("Filter"));
+
+    if ((!input_grad) && (!filter_grad)) return;
+
+    std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
+    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
+    std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
+    const int groups = ctx.Attr<int>("groups");
+    std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");
+    const std::string data_format = ctx.Attr<std::string>("data_format");
+    const framework::DataLayout data_layout =
+        framework::StringToDataLayout(data_format);
+
+    auto in_dims = input->dims();
+    auto filter_dims = filter->dims();
+    // auto out_grad_dims = output_grad->dims();
+    // const int batch_size = static_cast<int>(input->dims()[0]);
+
+    const bool channel_last = (data_layout == framework::DataLayout::kNHWC);
+
+    framework::DDim in_data_dims;
+    if (channel_last) {
+      in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
+    } else {
+      in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
+    }
+    framework::DDim filter_data_dims =
+        framework::slice_ddim(filter_dims, 2, filter_dims.size());
+    std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
+    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
+                             in_data_dims, strides, ksize);
+
+    std::vector<int> strides_vec(4, 1);
+    std::vector<int> dilations_vec(4, 1);
+
+    Tensor input_tensor, output_grad_tensor;
+    input_tensor.ShareDataWith(*input);
+    output_grad_tensor.ShareDataWith(*output_grad);
+    if (channel_last) {
+      input_tensor.set_layout(DataLayout::kNHWC);
+      output_grad_tensor.set_layout(DataLayout::kNHWC);
+      strides_vec[1] = strides[0];
+      strides_vec[2] = strides[1];
+      dilations_vec[1] = dilations[0];
+      dilations_vec[2] = dilations[1];
+    } else {
+      strides_vec[2] = strides[0];
+      strides_vec[3] = strides[1];
+      dilations_vec[2] = dilations[0];
+      dilations_vec[3] = dilations[1];
+    }
+
+    auto stream = ctx.template device_context<NPUDeviceContext>().stream();
+    if (filter_grad) {
+      filter_grad->mutable_data<T>(ctx.GetPlace());
+      const auto& runner =
+          NpuOpRunner("Conv2DBackpropFilterD",
+                      {output_grad_tensor, input_tensor}, {*filter_grad},
+                      {{"filter_size", framework::vectorize<int>(filter_dims)},
+                       {"strides", strides_vec},
+                       {"pads", paddings},
+                       {"dilations", dilations_vec},
+                       {"groups", groups},
+                       {"data_format", data_format}});
+      runner.Run(stream);
+    }
+    if (input_grad) {
+      input_grad->mutable_data<T>(ctx.GetPlace());
+      Tensor input_grad_tensor;
+      input_grad_tensor.ShareDataWith(*input_grad);
+      if (channel_last) {
+        input_grad_tensor.set_layout(DataLayout::kNHWC);
+      }
+      const auto& runner =
+          NpuOpRunner("Conv2D", {output_grad_tensor, *filter},
+                      {input_grad_tensor}, {{"strides", strides_vec},
+                                            {"pads", paddings},
+                                            {"dilations", dilations_vec},
+                                            {"groups", groups},
+                                            {"data_format", data_format}});
+      runner.Run(stream);
+    }
+  }
+};
+
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

-// conv2d
REGISTER_OP_NPU_KERNEL(conv2d_transpose, ops::Conv2DTransposeNPUKernel<float>,
                       ops::Conv2DTransposeNPUKernel<plat::float16>);
+
+REGISTER_OP_NPU_KERNEL(conv2d_transpose_grad,
+                       ops::Conv2DTransposeGradNPUKernel<float>,
+                       ops::Conv2DTransposeGradNPUKernel<plat::float16>);
@@ -18,4 +18,6 @@ if (WITH_ASCEND_CL)
    set_tests_properties(test_nearest_interp_op_npu PROPERTIES TIMEOUT 200)
    set_tests_properties(test_nearest_interp_v2_op_npu PROPERTIES TIMEOUT 200)
    set_tests_properties(test_stack_op_npu PROPERTIES TIMEOUT 300)
+   set_tests_properties(test_conv2d_transpose_op_npu PROPERTIES TIMEOUT 200)
+   set_tests_properties(test_conv2d_op_npu PROPERTIES TIMEOUT 300)
endif()
@@ -15,42 +15,38 @@ from __future__ import print_function

import unittest
import numpy as np
-import sys
-sys.path.append("..")
import paddle
import paddle.nn as nn
import paddle.fluid.core as core
import paddle.fluid as fluid
-from op_test import OpTest, skip_check_grad_ci
+import sys
+sys.path.append("..")
+from op_test import OpTest
from test_conv2d_transpose_op import conv2dtranspose_forward_naive

paddle.enable_static()


-@skip_check_grad_ci(
-    reason='''Inference only, it doesn't need to call check_grad.''')
class TestConv2DTransposeOp(OpTest):
    def set_npu(self):
        self.__class__.use_npu = True
        self.place = paddle.NPUPlace(0)

-    def init_dtype(self):
-        self.dtype = np.float16
-
-    def init_data_format(self):
-        self.data_format = "NCHW"
-
    def setUp(self):
-        self.init_op_type()
-        self.init_dtype()
+        # init as conv transpose
        self.set_npu()
-        self.init_data_format()
+        self.dtype = np.float32
+        self.need_check_grad = True
+        self.is_test = False
+        self.output_size = None
        self.output_padding = []
+        self.data_format = "NCHW"
        self.pad = [0, 0]
        self.padding_algorithm = "EXPLICIT"
+        self.init_op_type()
        self.init_test_case()
-        self.output_size = None
+        self.init_dtype()

        input_ = np.random.random(self.input_size).astype(self.dtype)
        filter_ = np.random.random(self.filter_size).astype(self.dtype)
@@ -80,6 +76,32 @@ class TestConv2DTransposeOp(OpTest):
    def test_check_output(self):
        self.check_output_with_place(self.place, atol=1e-2)

+    def test_check_grad_no_input(self):
+        if self.need_check_grad:
+            self.check_grad_with_place(
+                self.place, ['Filter'],
+                'Output',
+                no_grad_set=set(['Input']),
+                numeric_place=paddle.CPUPlace())
+
+    def test_check_grad_no_filter(self):
+        if self.need_check_grad:
+            self.check_grad_with_place(
+                self.place, ['Input'],
+                'Output',
+                no_grad_set=set(['Filter']),
+                max_relative_error=0.006,
+                numeric_place=paddle.CPUPlace())
+
+    def test_check_grad(self):
+        if self.need_check_grad:
+            self.check_grad_with_place(
+                self.place,
+                set(['Input', 'Filter']),
+                'Output',
+                max_relative_error=0.02,
+                numeric_place=paddle.CPUPlace())
+
    def init_test_case(self):
        self.pad = [0, 0]
        self.stride = [1, 1]
@@ -92,17 +114,6 @@ class TestConv2DTransposeOp(OpTest):
    def init_op_type(self):
        self.op_type = "conv2d_transpose"

-
-class TestWithSymmetricPad_FP32(TestConv2DTransposeOp):
-    def init_test_case(self):
-        self.pad = [1, 1]
-        self.stride = [1, 1]
-        self.dilations = [1, 1]
-        self.groups = 1
-        self.input_size = [2, 3, 5, 5]  # NCHW
-        f_c = self.input_size[1]
-        self.filter_size = [f_c, 6, 3, 3]
-
    def init_dtype(self):
        self.dtype = np.float32
@@ -118,18 +129,10 @@ class TestWithSymmetricPad(TestConv2DTransposeOp):
        self.filter_size = [f_c, 6, 3, 3]


-class TestWithAsymmetricPad_FP32(TestConv2DTransposeOp):
-    def init_test_case(self):
-        self.pad = [1, 0, 1, 2]
-        self.stride = [1, 1]
-        self.dilations = [1, 1]
-        self.groups = 1
-        self.input_size = [2, 3, 5, 5]  # NCHW
-        f_c = self.input_size[1]
-        self.filter_size = [f_c, 6, 3, 3]
-
+class TestWithSymmetricPad_FP16(TestWithSymmetricPad):
    def init_dtype(self):
-        self.dtype = np.float32
+        self.dtype = np.float16
+        self.need_check_grad = False


class TestWithAsymmetricPad(TestConv2DTransposeOp):
@@ -143,18 +146,10 @@ class TestWithAsymmetricPad(TestConv2DTransposeOp):
        self.filter_size = [f_c, 6, 3, 3]


-class TestWithSAMEPad_FP32(TestConv2DTransposeOp):
-    def init_test_case(self):
-        self.stride = [2, 1]
-        self.dilations = [1, 2]
-        self.groups = 1
-        self.input_size = [2, 3, 6, 5]  # NCHW
-        f_c = self.input_size[1]
-        self.filter_size = [f_c, 6, 4, 3]
-        self.padding_algorithm = 'SAME'
-
+class TestWithAsymmetricPad_FP16(TestWithAsymmetricPad):
    def init_dtype(self):
-        self.dtype = np.float32
+        self.dtype = np.float16
+        self.need_check_grad = False


class TestWithSAMEPad(TestConv2DTransposeOp):
@@ -168,18 +163,10 @@ class TestWithSAMEPad(TestConv2DTransposeOp):
        self.padding_algorithm = 'SAME'


-class TestWithVALIDPad_FP32(TestConv2DTransposeOp):
-    def init_test_case(self):
-        self.stride = [1, 1]
-        self.dilations = [1, 1]
-        self.groups = 1
-        self.input_size = [2, 3, 5, 5]  # NCHW
-        f_c = self.input_size[1]
-        self.filter_size = [f_c, 6, 3, 3]
-        self.padding_algorithm = 'VALID'
-
+class TestWithSAMEPad_FP16(TestWithSAMEPad):
    def init_dtype(self):
-        self.dtype = np.float32
+        self.dtype = np.float16
+        self.need_check_grad = False


class TestWithVALIDPad(TestConv2DTransposeOp):
@@ -193,18 +180,10 @@ class TestWithVALIDPad(TestConv2DTransposeOp):
        self.padding_algorithm = 'VALID'


-class TestWithGroups_FP32(TestConv2DTransposeOp):
-    def init_test_case(self):
-        self.pad = [1, 1]
-        self.stride = [1, 1]
-        self.dilations = [1, 1]
-        self.groups = 2
-        self.input_size = [2, 4, 5, 5]  # NCHW
-        f_c = self.input_size[1]
-        self.filter_size = [f_c, 3, 3, 3]
-
+class TestWithVALIDPad_FP16(TestWithVALIDPad):
    def init_dtype(self):
-        self.dtype = np.float32
+        self.dtype = np.float16
+        self.need_check_grad = False


class TestWithGroups(TestConv2DTransposeOp):
@@ -218,18 +197,10 @@ class TestWithGroups(TestConv2DTransposeOp):
        self.filter_size = [f_c, 3, 3, 3]


-class TestWithStride_FP32(TestConv2DTransposeOp):
-    def init_test_case(self):
-        self.pad = [1, 1]
-        self.stride = [2, 2]
-        self.dilations = [1, 1]
-        self.groups = 1
-        self.input_size = [2, 3, 5, 5]  # NCHW
-        f_c = self.input_size[1]
-        self.filter_size = [f_c, 6, 3, 3]
-
+class TestWithGroups_FP16(TestWithGroups):
    def init_dtype(self):
-        self.dtype = np.float32
+        self.dtype = np.float16
+        self.need_check_grad = False


class TestWithStride(TestConv2DTransposeOp):
@@ -243,18 +214,10 @@ class TestWithStride(TestConv2DTransposeOp):
        self.filter_size = [f_c, 6, 3, 3]


-class TestWithDilation_FP32(TestConv2DTransposeOp):
-    def init_test_case(self):
-        self.pad = [1, 1]
-        self.stride = [1, 1]
-        self.groups = 1
-        self.dilations = [2, 2]
-        self.input_size = [2, 3, 5, 5]  # NCHW
-        f_c = self.input_size[1]
-        self.filter_size = [f_c, 6, 3, 3]
-
+class TestWithStride_FP16(TestWithStride):
    def init_dtype(self):
-        self.dtype = np.float32
+        self.dtype = np.float16
+        self.need_check_grad = False


class TestWithDilation(TestConv2DTransposeOp):
@@ -268,19 +231,10 @@ class TestWithDilation(TestConv2DTransposeOp):
        self.filter_size = [f_c, 6, 3, 3]


-class TestWithEvenUpsample_FP32(TestConv2DTransposeOp):
-    def init_test_case(self):
-        self.pad = [2, 2]
-        self.stride = [2, 2]
-        self.groups = 1
-        self.dilations = [1, 1]
-        self.output_size = [14, 14]
-        self.input_size = [2, 3, 7, 7]  # NCHW
-        f_c = self.input_size[1]
-        self.filter_size = [f_c, 6, 5, 5]
-
+class TestWithDilation_FP16(TestWithDilation):
    def init_dtype(self):
-        self.dtype = np.float32
+        self.dtype = np.float16
+        self.need_check_grad = False


class TestWithEvenUpsample(TestConv2DTransposeOp):
@@ -295,19 +249,10 @@ class TestWithEvenUpsample(TestConv2DTransposeOp):
        self.filter_size = [f_c, 6, 5, 5]


-class TestWithEvenUpsampleOutputPadding_FP32(TestConv2DTransposeOp):
-    def init_test_case(self):
-        self.pad = [2, 2]
-        self.stride = [2, 2]
-        self.groups = 1
-        self.dilations = [1, 1]
-        self.output_padding = [1, 1]
-        self.input_size = [2, 3, 7, 7]  # NCHW
-        f_c = self.input_size[1]
-        self.filter_size = [f_c, 6, 5, 5]
-
+class TestWithEvenUpsample_FP16(TestWithEvenUpsample):
    def init_dtype(self):
-        self.dtype = np.float32
+        self.dtype = np.float16
+        self.need_check_grad = False


class TestWithEvenUpsampleOutputPadding(TestConv2DTransposeOp):
@@ -322,19 +267,10 @@ class TestWithEvenUpsampleOutputPadding(TestConv2DTransposeOp):
        self.filter_size = [f_c, 6, 5, 5]


-class Test_NHWC_FP32(TestConv2DTransposeOp):
-    def init_test_case(self):
-        self.pad = [0, 0]
-        self.stride = [1, 1]
-        self.dilations = [1, 1]
-        self.groups = 1
-        self.input_size = [2, 5, 5, 3]  # NHWC
-        f_c = self.input_size[-1]
-        self.filter_size = [f_c, 6, 3, 3]
-        self.data_format = 'NHWC'
-
+class TestWithEvenUpsampleOutputPadding_FP16(TestWithEvenUpsampleOutputPadding):
    def init_dtype(self):
-        self.dtype = np.float32
+        self.dtype = np.float16
+        self.need_check_grad = False


class Test_NHWC(TestConv2DTransposeOp):
@@ -349,19 +285,10 @@ class Test_NHWC(TestConv2DTransposeOp):
        self.data_format = 'NHWC'


-class TestWithSymmetricPad_NHWC_FP32(TestConv2DTransposeOp):
-    def init_test_case(self):
-        self.pad = [1, 1]
-        self.stride = [1, 1]
-        self.dilations = [1, 1]
-        self.groups = 1
-        self.input_size = [2, 5, 5, 3]  # NHWC
-        f_c = self.input_size[-1]
-        self.filter_size = [f_c, 6, 3, 3]
-        self.data_format = 'NHWC'
-
+class Test_NHWC_FP16(Test_NHWC):
    def init_dtype(self):
-        self.dtype = np.float32
+        self.dtype = np.float16
+        self.need_check_grad = False


class TestWithSymmetricPad_NHWC(TestConv2DTransposeOp):
@@ -376,19 +303,10 @@ class TestWithSymmetricPad_NHWC(TestConv2DTransposeOp):
        self.data_format = 'NHWC'


-class TestWithAsymmetricPad_NHWC_FP32(TestConv2DTransposeOp):
-    def init_test_case(self):
-        self.pad = [1, 0, 1, 2]
-        self.stride = [1, 1]
-        self.dilations = [1, 1]
-        self.groups = 1
-        self.input_size = [2, 5, 5, 3]  # NHWC
-        f_c = self.input_size[-1]
-        self.filter_size = [f_c, 6, 3, 3]
-        self.data_format = 'NHWC'
-
+class TestWithSymmetricPad_NHWC_FP16(TestWithSymmetricPad_NHWC):
    def init_dtype(self):
-        self.dtype = np.float32
+        self.dtype = np.float16
+        self.need_check_grad = False


class TestWithAsymmetricPad_NHWC(TestConv2DTransposeOp):
@@ -403,19 +321,10 @@ class TestWithAsymmetricPad_NHWC(TestConv2DTransposeOp):
        self.data_format = 'NHWC'


-class TestWithGroups_NHWC_FP32(TestConv2DTransposeOp):
-    def init_test_case(self):
-        self.pad = [1, 1]
-        self.stride = [1, 1]
-        self.dilations = [1, 1]
-        self.groups = 2
-        self.input_size = [2, 5, 5, 4]  # NHWC
-        f_c = self.input_size[-1]
-        self.filter_size = [f_c, 3, 3, 3]
-        self.data_format = 'NHWC'
-
+class TestWithAsymmetricPad_NHWC_FP16(TestWithAsymmetricPad_NHWC):
    def init_dtype(self):
-        self.dtype = np.float32
+        self.dtype = np.float16
+        self.need_check_grad = False


class TestWithGroups_NHWC(TestConv2DTransposeOp):
@@ -430,19 +339,10 @@ class TestWithGroups_NHWC(TestConv2DTransposeOp):
        self.data_format = 'NHWC'


-class TestWithStride_NHWC_FP32(TestConv2DTransposeOp):
-    def init_test_case(self):
-        self.pad = [1, 1]
-        self.stride = [2, 2]
-        self.dilations = [1, 1]
-        self.groups = 1
-        self.input_size = [2, 5, 5, 3]  # NCHW
-        f_c = self.input_size[-1]
-        self.filter_size = [f_c, 6, 3, 3]
-        self.data_format = 'NHWC'
-
+class TestWithGroups_NHWC_FP16(TestWithGroups_NHWC):
    def init_dtype(self):
-        self.dtype = np.float32
+        self.dtype = np.float16
+        self.need_check_grad = False


class TestWithStride_NHWC(TestConv2DTransposeOp):
@@ -457,19 +357,10 @@ class TestWithStride_NHWC(TestConv2DTransposeOp):
        self.data_format = 'NHWC'


-class TestWithDilation_NHWC_FP32(TestConv2DTransposeOp):
-    def init_test_case(self):
-        self.pad = [1, 1]
-        self.stride = [1, 1]
-        self.groups = 1
-        self.dilations = [2, 2]
-        self.input_size = [2, 5, 5, 3]  # NHWC
-        f_c = self.input_size[-1]
-        self.filter_size = [f_c, 6, 3, 3]
-        self.data_format = 'NHWC'
-
+class TestWithStride_NHWC_FP16(TestWithStride_NHWC):
    def init_dtype(self):
-        self.dtype = np.float32
+        self.dtype = np.float16
+        self.need_check_grad = False


class TestWithDilation_NHWC(TestConv2DTransposeOp):
@@ -484,20 +375,10 @@ class TestWithDilation_NHWC(TestConv2DTransposeOp):
        self.data_format = 'NHWC'


-class TestWithEvenUpsample_NHWC_FP32(TestConv2DTransposeOp):
-    def init_test_case(self):
-        self.pad = [2, 2]
-        self.stride = [2, 2]
-        self.groups = 1
-        self.dilations = [1, 1]
-        self.output_size = [14, 14]
-        self.input_size = [2, 7, 7, 3]  # NHWC
-        f_c = self.input_size[-1]
-        self.filter_size = [f_c, 6, 5, 5]
-        self.data_format = 'NHWC'
-
+class TestWithDilation_NHWC_FP16(TestWithDilation_NHWC):
    def init_dtype(self):
-        self.dtype = np.float32
+        self.dtype = np.float16
+        self.need_check_grad = False


class TestWithEvenUpsample_NHWC(TestConv2DTransposeOp):
@@ -513,20 +394,10 @@ class TestWithEvenUpsample_NHWC(TestConv2DTransposeOp):
        self.data_format = 'NHWC'


-class TestWithEvenUpsample_NHWC_output_padding_FP32(TestConv2DTransposeOp):
-    def init_test_case(self):
-        self.pad = [2, 2]
-        self.stride = [2, 2]
-        self.groups = 1
-        self.dilations = [1, 1]
-        self.output_padding = [1, 1]
-        self.input_size = [2, 7, 7, 3]  # NHWC
-        f_c = self.input_size[-1]
-        self.filter_size = [f_c, 6, 5, 5]
-        self.data_format = 'NHWC'
-
+class TestWithEvenUpsample_NHWC_FP16(TestWithEvenUpsample_NHWC):
    def init_dtype(self):
-        self.dtype = np.float32
+        self.dtype = np.float16
+        self.need_check_grad = False


class TestWithEvenUpsample_NHWC_output_padding(TestConv2DTransposeOp):
@@ -542,6 +413,13 @@ class TestWithEvenUpsample_NHWC_output_padding(TestConv2DTransposeOp):
        self.data_format = 'NHWC'


+class TestWithEvenUpsample_NHWC_output_padding_FP16(
+        TestWithEvenUpsample_NHWC_output_padding):
+    def init_dtype(self):
+        self.dtype = np.float16
+        self.need_check_grad = False
+
+
class TestConv2DTransposeAPI(unittest.TestCase):
    def test_case1(self):
        data1 = fluid.layers.data(
@@ -617,71 +495,6 @@ class TestConv2DTransposeAPI(unittest.TestCase):
        self.assertIsNotNone(results[6])


-class TestConv2DTransposeOpException(unittest.TestCase):
-    def test_exception(self):
-        data = fluid.layers.data(name='data', shape=[3, 5, 5], dtype="float32")
-
-        def attr_data_format():
-            out = fluid.layers.conv2d_transpose(
-                input=data,
-                groups=1,
-                num_filters=6,
-                filter_size=3,
-                data_format="NCDHW")
-
-        self.assertRaises(ValueError, attr_data_format)
-
-        def attr_padding_str():
-            out = fluid.layers.conv2d_transpose(
-                input=data,
-                groups=1,
-                num_filters=6,
-                filter_size=3,
-                padding='Vald')
-
-        self.assertRaises(ValueError, attr_padding_str)
-
-        def attr_padding_list():
-            out = fluid.layers.conv2d_transpose(
-                input=data,
-                groups=1,
-                num_filters=6,
-                filter_size=3,
-                padding=[[1, 1], [1, 1], [0, 0], [0, 0]])
-
-        self.assertRaises(ValueError, attr_padding_list)
-
-        def attr_padding_with_data_format():
-            out = fluid.layers.conv2d_transpose(
-                input=data,
-                groups=1,
-                num_filters=6,
-                filter_size=3,
-                padding=[[1, 1], [0, 0], [0, 0], [1, 1]],
-                data_format='NHWC')
-
-        self.assertRaises(ValueError, attr_padding_with_data_format)
-
-        error_input = fluid.layers.data(
-            name='error_data', shape=[1], dtype="float32")
-
-        def error_input_size():
-            out = fluid.layers.conv2d_transpose(
-                input=error_input, groups=1, num_filters=6, filter_size=3)
-
-        self.assertRaises(ValueError, error_input_size)
-
-        def error_groups():
-            out = fluid.layers.conv2d_transpose(
-                input=data,
-                groups=0,
-                num_filters=6,
-                filter_size=3,
-                data_format='NHWC')
-
-        self.assertRaises(ValueError, error_groups)
-
-
class TestConv2DTransposeRepr(unittest.TestCase):
    def test_case(self):
        paddle.disable_static(paddle.NPUPlace(0))
...
@@ -1435,7 +1435,8 @@ class OpTest(unittest.TestCase):
                              max_relative_error=0.005,
                              user_defined_grads=None,
                              user_defined_grad_outputs=None,
-                              check_dygraph=True):
+                              check_dygraph=True,
+                              numeric_place=None):
        self.scope = core.Scope()
        op_inputs = self.inputs if hasattr(self, "inputs") else dict()
        op_outputs = self.outputs if hasattr(self, "outputs") else dict()
@@ -1492,9 +1493,12 @@ class OpTest(unittest.TestCase):
        if not type(output_names) is list:
            output_names = [output_names]

+        if numeric_place is None:
+            numeric_place = place
+
        numeric_grads = user_defined_grads or [
            get_numeric_gradient(
-                place,
+                numeric_place,
                self.scope,
                self.op,
                self.inputs,
...
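The numeric_place argument introduced above lets check_grad_with_place run the operator on one place while computing the finite-difference reference gradient on another; the NPU tests in this commit pass numeric_place=paddle.CPUPlace() so the numeric baseline is computed on CPU. A hedged usage sketch, reusing the names from the test above:

self.check_grad_with_place(
    paddle.NPUPlace(0), ['Input', 'Filter'], 'Output',
    max_relative_error=0.02,
    numeric_place=paddle.CPUPlace())  # analytic gradients on NPU, numeric reference on CPU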