未验证 提交 c96f7a29 编写于 作者: C cambriconhsq 提交者: GitHub

[MLU]add mlu kernel for conv2dtransposed op (#43233)

上级 c49f35cf
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/conv_transpose_op.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DataLayout = framework::DataLayout;
template <typename T>
class Conv2DTransposeMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const Tensor* input = ctx.Input<Tensor>("Input");
const Tensor* filter = ctx.Input<Tensor>("Filter");
Tensor* output = ctx.Output<Tensor>("Output");
output->mutable_data<T>(ctx.GetPlace());
std::vector<int> output_padding =
ctx.Attr<std::vector<int>>("output_padding");
const std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
const std::string data_format = ctx.Attr<std::string>("data_format");
int groups = ctx.Attr<int>("groups");
const std::string padding_algorithm =
ctx.Attr<std::string>("padding_algorithm");
// check dimension
const bool channel_last = data_format == "NHWC";
auto in_dims = input->dims();
auto filter_dims = filter->dims();
auto in_dims_size = in_dims.size();
framework::DDim in_data_dims;
framework::DDim filter_data_dims;
if (channel_last) {
in_data_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
} else {
in_data_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
}
filter_data_dims = phi::slice_ddim(filter_dims, 2, in_dims.size());
std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
phi::UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
Tensor input_tensor(input->type());
Tensor output_tensor(output->type());
input_tensor.set_layout(DataLayout::kNHWC);
output_tensor.set_layout(DataLayout::kNHWC);
const std::vector<int> perm_to_nhwc = {0, 2, 3, 1};
if (channel_last) {
input_tensor.ShareDataWith(*input);
output_tensor.ShareDataWith(*output);
} else {
// transpose input from NCHW to NHWC
TransposeFromMLUTensor<T>(ctx, perm_to_nhwc, input, &input_tensor,
true /*need_reshape_or_alloc*/);
auto output_dims = output->dims();
output_tensor.mutable_data<T>(
{output_dims[0], output_dims[2], output_dims[3], output_dims[1]},
ctx.GetPlace());
}
// transpose filter from MCHW to MHWC
Tensor trans_filter(filter->type());
TransposeFromMLUTensor<T>(ctx, perm_to_nhwc, filter, &trans_filter,
true /*need_reshape_or_alloc*/);
// construct MLU attr
cnnlTensorLayout_t data_layout = CNNL_LAYOUT_NHWC;
MLUCnnlTensorDesc input_desc(input_tensor, data_layout,
ToCnnlDataType(input_tensor.dtype()));
MLUCnnlTensorDesc filter_desc(trans_filter, data_layout,
ToCnnlDataType(trans_filter.type()));
MLUCnnlTensorDesc output_desc(output_tensor, data_layout,
ToCnnlDataType(output_tensor.dtype()));
MLUCnnlConvolutionDesc conv_desc(in_dims_size, paddings.data(),
strides.data(), dilations.data(), groups,
ToCnnlDataType<T>());
MLUCnnl::ConvBackpropInput(ctx, conv_desc.get(), filter_desc.get(),
GetBasePtr(&trans_filter), input_desc.get(),
GetBasePtr(&input_tensor), output_desc.get(),
GetBasePtr(&output_tensor));
if (!channel_last) {
// transpose output from NHWC to NCHW
const std::vector<int> perm_to_nchw = {0, 3, 1, 2};
TransposeFromMLUTensor<T>(ctx, perm_to_nchw, &output_tensor, output,
false /*need_reshape_or_alloc*/);
}
}
};
template <typename T>
class Conv2DTransposeGradMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const Tensor* input = ctx.Input<Tensor>("Input");
const Tensor* filter = ctx.Input<Tensor>("Filter");
const Tensor* output_grad =
ctx.Input<Tensor>(framework::GradVarName("Output"));
Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
Tensor* filter_grad = ctx.Output<Tensor>(framework::GradVarName("Filter"));
if ((!input_grad) && (!filter_grad)) return;
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
const int groups = ctx.Attr<int>("groups");
std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");
const std::string data_format = ctx.Attr<std::string>("data_format");
const framework::DataLayout data_layout =
framework::StringToDataLayout(data_format);
auto in_dims = input->dims();
auto filter_dims = filter->dims();
auto in_dims_size = in_dims.size();
const bool channel_last = (data_layout == framework::DataLayout::kNHWC);
framework::DDim in_data_dims;
if (channel_last) {
in_data_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
} else {
in_data_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
}
framework::DDim filter_data_dims =
phi::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
phi::UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
Tensor input_tensor(input->type());
Tensor output_grad_tensor(output_grad->type());
output_grad_tensor.set_layout(DataLayout::kNHWC);
const std::vector<int> perm_to_nhwc = {0, 2, 3, 1};
if (channel_last) {
input_tensor.ShareDataWith(*input);
output_grad_tensor.ShareDataWith(*output_grad);
} else {
// transpose input from NCHW to NHWC
TransposeFromMLUTensor<T>(ctx, perm_to_nhwc, input, &input_tensor,
true /*need_reshape_or_alloc*/);
TransposeFromMLUTensor<T>(ctx, perm_to_nhwc, output_grad,
&output_grad_tensor,
true /*need_reshape_or_alloc*/);
}
// transpose filter from MCHW to MHWC
Tensor trans_filter(filter->type());
TransposeFromMLUTensor<T>(ctx, perm_to_nhwc, filter, &trans_filter,
true /*need_reshape_or_alloc*/);
// MLU descs
cnnlTensorLayout_t data_layout_mlu = CNNL_LAYOUT_NHWC;
MLUCnnlTensorDesc input_desc(input_tensor, data_layout_mlu,
ToCnnlDataType(input_tensor.dtype()));
MLUCnnlTensorDesc trans_filter_desc(trans_filter, data_layout_mlu,
ToCnnlDataType(trans_filter.type()));
MLUCnnlTensorDesc output_grad_desc(
output_grad_tensor, data_layout_mlu,
ToCnnlDataType(output_grad_tensor.dtype()));
MLUCnnlConvolutionDesc conv_desc(in_dims_size, paddings.data(),
strides.data(), dilations.data(), groups,
ToCnnlDataType<T>());
if (filter_grad) {
filter_grad->mutable_data<T>(ctx.GetPlace());
Tensor filter_grad_tensor(filter_grad->type());
// filter_grad always MCHW
// filter_grad_tensor always MHWC
auto filter_grad_dims = filter_grad->dims();
filter_grad_tensor.mutable_data<T>(
{filter_grad_dims[0], filter_grad_dims[2], filter_grad_dims[3],
filter_grad_dims[1]},
ctx.GetPlace());
//}
filter_grad_tensor.set_layout(DataLayout::kNHWC);
MLUCnnlTensorDesc filter_grad_desc(
filter_grad_tensor, data_layout_mlu,
ToCnnlDataType(filter_grad_tensor.dtype()));
MLUCnnl::ConvBackpropFilter(
ctx, conv_desc.get(), output_grad_desc.get(), GetBasePtr(output_grad),
input_desc.get(), GetBasePtr(&input_tensor), filter_grad_desc.get(),
GetBasePtr(&filter_grad_tensor));
// transpose output from MHWC to MCHW
const std::vector<int> perm_to_mchw = {0, 3, 1, 2};
TransposeFromMLUTensor<T>(ctx, perm_to_mchw, &filter_grad_tensor,
filter_grad, false /*need_reshape_or_alloc*/);
}
if (input_grad) {
input_grad->mutable_data<T>(ctx.GetPlace());
Tensor input_grad_tensor(input_grad->type());
input_tensor.set_layout(DataLayout::kNHWC);
if (channel_last) {
input_grad_tensor.ShareDataWith(*input_grad);
} else {
auto input_grad_dims = input_grad->dims();
input_grad_tensor.mutable_data<T>(
{input_grad_dims[0], input_grad_dims[2], input_grad_dims[3],
input_grad_dims[1]},
ctx.GetPlace());
}
MLUCnnlTensorDesc input_grad_desc(
input_grad_tensor, data_layout_mlu,
ToCnnlDataType(input_grad_tensor.dtype()));
cnnlDataType_t tensor_dtype = ToCnnlDataType<T>();
cnnlDataType_t dt_onchip = ToCnnlDataType<T>();
MLUCnnl::Conv2D(ctx, conv_desc.get(), tensor_dtype, dt_onchip,
nullptr /* input_position */, nullptr /* input_scale */,
nullptr /* input_offset */, nullptr /* filter_position */,
nullptr /* filter_scale */, nullptr /* filter_offset */,
output_grad_desc.get(), GetBasePtr(&output_grad_tensor),
trans_filter_desc.get(), GetBasePtr(&trans_filter),
nullptr /* bias_desc*/, nullptr /* bias */,
input_grad_desc.get(), GetBasePtr(&input_grad_tensor));
if (!channel_last) {
// transpose output from NHWC to NCHW
const std::vector<int> perm_to_nchw = {0, 3, 1, 2};
TransposeFromMLUTensor<T>(ctx, perm_to_nchw, &input_grad_tensor,
input_grad, false /*need_reshape_or_alloc*/);
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_MLU_KERNEL(conv2d_transpose, ops::Conv2DTransposeMLUKernel<float>,
ops::Conv2DTransposeMLUKernel<plat::float16>);
REGISTER_OP_MLU_KERNEL(conv2d_transpose_grad,
ops::Conv2DTransposeGradMLUKernel<float>,
ops::Conv2DTransposeGradMLUKernel<plat::float16>);
......@@ -1159,7 +1159,7 @@ class MLUCnnl {
static void ConvBackpropInput(
const ExecutionContext& ctx, const cnnlConvolutionDescriptor_t conv_desc,
const cnnlTensorDescriptor_t input_desc, const void* filter,
const cnnlTensorDescriptor_t filter_desc, const void* filter,
const cnnlTensorDescriptor_t out_backprop_desc, const void* out_backprop,
const cnnlTensorDescriptor_t in_backprop_desc, void* in_backprop);
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.nn as nn
paddle.enable_static()
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid.tests.unittests.op_test import OpTest
def conv2dtranspose_forward_naive(input_, filter_, attrs):
padding_algorithm = attrs['padding_algorithm']
if padding_algorithm not in ["SAME", "VALID", "EXPLICIT"]:
raise ValueError("Unknown Attr(padding_algorithm): '%s'. "
"It can only be 'SAME' or 'VALID'." %
str(padding_algorithm))
if attrs['data_format'] == 'NHWC':
input_ = np.transpose(input_, [0, 3, 1, 2])
in_n, in_c, in_h, in_w = input_.shape
f_c, f_out_c, f_h, f_w = filter_.shape
groups = attrs['groups']
assert in_c == f_c
out_c = f_out_c * groups
sub_in_c = in_c // groups
stride, pad, dilations = attrs['strides'], attrs['paddings'], attrs[
'dilations']
# update pad and dilation
def _get_padding_with_SAME(input_shape, kernel_size, kernel_stride):
padding = []
for input_size, filter_size, stride_size in zip(input_shape,
kernel_size,
kernel_stride):
out_size = int((input_size + stride_size - 1) / stride_size)
pad_sum = np.max(
((out_size - 1) * stride_size + filter_size - input_size, 0))
pad_0 = int(pad_sum / 2)
pad_1 = int(pad_sum - pad_0)
padding.append(pad_0)
padding.append(pad_1)
return padding
ksize = filter_.shape[2:4]
if padding_algorithm == "VALID":
pad = [0, 0, 0, 0]
elif padding_algorithm == "SAME":
dilations = [1, 1]
input_data_shape = input_.shape[2:4]
pad = _get_padding_with_SAME(input_data_shape, ksize, stride)
pad_h_0, pad_h_1 = pad[0], pad[0]
pad_w_0, pad_w_1 = pad[1], pad[1]
if len(pad) == 4:
pad_h_0, pad_h_1 = pad[0], pad[1]
pad_w_0, pad_w_1 = pad[2], pad[3]
d_bolck_h = dilations[0] * (f_h - 1) + 1
d_bolck_w = dilations[1] * (f_w - 1) + 1
out_h = (in_h - 1) * stride[0] + d_bolck_h
out_w = (in_w - 1) * stride[1] + d_bolck_w
if 'output_size' in attrs:
output_size = attrs['output_size']
out_h = output_size[0] + pad_h_0 + pad_h_1
out_w = output_size[1] + pad_w_0 + pad_w_1
out_pad_h = 0
out_pad_w = 0
if 'output_padding' in attrs:
out_pad_h = attrs['output_padding'][0]
out_pad_w = attrs['output_padding'][1]
out = np.zeros((in_n, out_c, out_h + out_pad_h, out_w + out_pad_w),
dtype=input_.dtype)
for n in range(in_n):
for i in range(in_h):
for j in range(in_w):
for g in range(groups):
input_masked = input_[n, g * sub_in_c:(g + 1) * sub_in_c, i,
j] # (c)
input_masked = np.reshape(input_masked, (sub_in_c, 1, 1))
input_masked = np.tile(input_masked, (1, f_h, f_w))
for k in range(f_out_c):
tmp_out = np.sum(
input_masked *
filter_[g * sub_in_c:(g + 1) * sub_in_c, k, :, :],
axis=0)
i1, i2 = i * stride[0], i * stride[0] + d_bolck_h
j1, j2 = j * stride[1], j * stride[1] + d_bolck_w
out[n, g * f_out_c + k, i1:i2:dilations[0],
j1:j2:dilations[1]] += tmp_out
out = out[:, :, pad_h_0:out_h - pad_h_1 + out_pad_h,
pad_w_0:out_w - pad_w_1 + out_pad_w]
if attrs['data_format'] == 'NHWC':
out = np.transpose(out, [0, 2, 3, 1])
return out
class TestConv2DTransposeOp(OpTest):
def setUp(self):
# init as conv transpose
self.dtype = np.float32
self.set_mlu()
self.need_check_grad = True
self.is_test = False
self.use_cudnn = False
self.use_mkldnn = False
self.output_size = None
self.output_padding = []
self.data_format = "NCHW"
self.pad = [0, 0]
self.padding_algorithm = "EXPLICIT"
self.init_op_type()
self.init_test_case()
input_ = np.random.random(self.input_size).astype(self.dtype)
filter_ = np.random.random(self.filter_size).astype(self.dtype)
self.inputs = {'Input': input_, 'Filter': filter_}
self.attrs = {
'strides': self.stride,
'paddings': self.pad,
'padding_algorithm': self.padding_algorithm,
'groups': self.groups,
'dilations': self.dilations,
'use_cudnn': self.use_cudnn,
'is_test': self.is_test,
'use_mkldnn': self.use_mkldnn,
'data_format': self.data_format
}
if self.output_size is not None:
self.attrs['output_size'] = self.output_size
if len(self.output_padding) > 0:
self.attrs['output_padding'] = self.output_padding
output = conv2dtranspose_forward_naive(input_, filter_,
self.attrs).astype(self.dtype)
self.outputs = {'Output': output}
def set_mlu(self):
self.__class__.use_mlu = True
self.place = paddle.device.MLUPlace(0)
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad_no_input(self):
if self.need_check_grad:
self.check_grad_with_place(self.place, ['Filter'],
'Output',
max_relative_error=0.02,
no_grad_set=set(['Input']))
def test_check_grad_no_filter(self):
if self.need_check_grad:
self.check_grad_with_place(self.place, ['Input'],
'Output',
no_grad_set=set(['Filter']))
def test_check_grad(self):
if self.need_check_grad:
self.check_grad_with_place(self.place,
set(['Input', 'Filter']),
'Output',
max_relative_error=0.02)
def init_test_case(self):
self.pad = [0, 0]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 3, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3]
def init_op_type(self):
self.op_type = "conv2d_transpose"
class TestWithSymmetricPad(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 3, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3]
class TestWithAsymmetricPad(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [1, 0, 1, 2]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 3, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3]
class TestWithSAMEPad(TestConv2DTransposeOp):
def init_test_case(self):
self.stride = [2, 1]
self.dilations = [1, 2]
self.groups = 1
self.input_size = [2, 3, 6, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 4, 3]
self.padding_algorithm = 'SAME'
class TestWithVALIDPad(TestConv2DTransposeOp):
def init_test_case(self):
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 3, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3]
self.padding_algorithm = 'VALID'
class TestWithGroups(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 2
self.input_size = [2, 4, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 3, 3, 3]
class TestWithStride(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [2, 2]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 3, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3]
class TestWithDilation(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
self.groups = 1
self.dilations = [2, 2]
self.input_size = [2, 3, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3]
class TestWithEvenUpsample(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [2, 2]
self.stride = [2, 2]
self.groups = 1
self.dilations = [1, 1]
self.output_size = [14, 14]
self.input_size = [2, 3, 7, 7] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 5, 5]
class TestWithEvenUpsampleOutputPadding(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [2, 2]
self.stride = [2, 2]
self.groups = 1
self.dilations = [1, 1]
self.output_padding = [1, 1]
self.input_size = [2, 3, 7, 7] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 5, 5]
class Test_NHWC(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [0, 0]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 5, 5, 3] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3]
self.data_format = 'NHWC'
class TestWithSymmetricPad_NHWC(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 5, 5, 3] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3]
self.data_format = 'NHWC'
class TestWithAsymmetricPad_NHWC(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [1, 0, 1, 2]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 5, 5, 3] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3]
self.data_format = 'NHWC'
class TestWithGroups_NHWC(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 2
self.input_size = [2, 5, 5, 4] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 3, 3, 3]
self.data_format = 'NHWC'
class TestWithStride_NHWC(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [2, 2]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 5, 5, 3] # NCHW
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3]
self.data_format = 'NHWC'
class TestWithDilation_NHWC(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
self.groups = 1
self.dilations = [2, 2]
self.input_size = [2, 5, 5, 3] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3]
self.data_format = 'NHWC'
class TestWithEvenUpsample_NHWC(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [2, 2]
self.stride = [2, 2]
self.groups = 1
self.dilations = [1, 1]
self.output_size = [14, 14]
self.input_size = [2, 7, 7, 3] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 5, 5]
self.data_format = 'NHWC'
class TestWithEvenUpsample_NHWC_output_padding(TestConv2DTransposeOp):
def init_test_case(self):
self.pad = [2, 2]
self.stride = [2, 2]
self.groups = 1
self.dilations = [1, 1]
self.output_padding = [1, 1]
self.input_size = [2, 7, 7, 3] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 5, 5]
self.data_format = 'NHWC'
class TestMLU_FP16(TestConv2DTransposeOp):
def init_test_case(self):
self.dtype = np.float16
self.set_mlu()
self.pad = [1, 1]
self.stride = [1, 1]
self.groups = 1
self.dilations = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3]
def set_mlu(self):
self.__class__.use_mlu = True
self.place = paddle.device.MLUPlace(0)
def init_op_type(self):
self.need_check_grad = False
self.op_type = "conv2d_transpose"
def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-2)
class TestMLU_NHWC_FP16(TestMLU_FP16):
def init_test_case(self):
self.dtype = np.float16
self.pad = [0, 0]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 5, 5, 3] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3]
self.data_format = 'NHWC'
class TestMLUWithGroups_NHWC_FP16(TestMLU_FP16):
def init_test_case(self):
self.dtype = np.float16
self.pad = [1, 1]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 2
self.input_size = [2, 5, 5, 4] # NCHW
f_c = self.input_size[-1]
self.filter_size = [f_c, 3, 3, 3]
self.data_format = 'NHWC'
class TestMLUWithEvenUpsample_NHWC_FP16(TestMLU_FP16):
def init_test_case(self):
self.dtype = np.float16
self.pad = [2, 2]
self.stride = [2, 2]
self.groups = 1
self.dilations = [1, 1]
self.output_size = [14, 14]
self.input_size = [2, 7, 7, 3] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 5, 5]
self.data_format = 'NHWC'
class TestConv2DTransposeAPI(unittest.TestCase):
def setUp(self):
self.set_mlu()
def set_mlu(self):
self.__class__.use_mlu = True
self.place = paddle.device.MLUPlace(0)
def test_case1(self):
data1 = fluid.layers.data(name='data1',
shape=[3, 5, 5],
dtype='float32')
data2 = fluid.layers.data(name='data2',
shape=[5, 5, 3],
dtype='float32')
out1 = fluid.layers.conv2d_transpose(input=data1,
groups=1,
num_filters=6,
filter_size=3,
data_format='NCHW')
out2 = fluid.layers.conv2d_transpose(input=data2,
groups=1,
num_filters=6,
filter_size=3,
data_format='NHWC')
out3 = fluid.layers.conv2d_transpose(input=data1,
groups=1,
num_filters=6,
filter_size=3,
padding=[[0, 0], [1, 1], [1, 1],
[0, 0]],
data_format='NHWC')
out4 = fluid.layers.conv2d_transpose(input=data1,
groups=3,
num_filters=6,
filter_size=3,
padding=[[0, 0], [0, 0], [2, 1],
[0, 0]],
data_format='NCHW')
out5 = fluid.layers.conv2d_transpose(input=data2,
groups=1,
num_filters=6,
filter_size=3,
padding='SAME',
data_format='NCHW')
out6 = fluid.layers.conv2d_transpose(input=data1,
groups=1,
num_filters=6,
filter_size=3,
padding='VALID',
data_format='NHWC')
out7 = fluid.layers.conv2d_transpose(input=data1,
groups=1,
num_filters=6,
output_size=[7, 7],
padding=[0, 0],
data_format='NHWC')
data1_np = np.random.random((2, 3, 5, 5)).astype("float32")
data2_np = np.random.random((2, 5, 5, 3)).astype("float32")
exe = fluid.Executor(self.place)
exe.run(fluid.default_startup_program())
results = exe.run(fluid.default_main_program(),
feed={
"data1": data1_np,
"data2": data2_np
},
fetch_list=[out1, out2, out3, out4, out5, out6, out7],
return_numpy=True)
self.assertIsNotNone(results[0])
self.assertIsNotNone(results[1])
self.assertIsNotNone(results[2])
self.assertIsNotNone(results[3])
self.assertIsNotNone(results[4])
self.assertIsNotNone(results[5])
self.assertIsNotNone(results[6])
class TestConv2DTransposeOpException(unittest.TestCase):
def setUp(self):
self.set_mlu()
def set_mlu(self):
self.__class__.use_mlu = True
self.place = paddle.device.MLUPlace(0)
def test_exception(self):
data = fluid.layers.data(name='data', shape=[3, 5, 5], dtype="float32")
def attr_data_format():
out = fluid.layers.conv2d_transpose(input=data,
groups=1,
num_filters=6,
filter_size=3,
data_format="NCDHW")
self.assertRaises(ValueError, attr_data_format)
def attr_padding_str():
out = fluid.layers.conv2d_transpose(input=data,
groups=1,
num_filters=6,
filter_size=3,
padding='Vald')
self.assertRaises(ValueError, attr_padding_str)
def attr_padding_list():
out = fluid.layers.conv2d_transpose(input=data,
groups=1,
num_filters=6,
filter_size=3,
padding=[[1, 1], [1, 1], [0, 0],
[0, 0]])
self.assertRaises(ValueError, attr_padding_list)
def attr_padding_with_data_format():
out = fluid.layers.conv2d_transpose(input=data,
groups=1,
num_filters=6,
filter_size=3,
padding=[[1, 1], [0, 0], [0, 0],
[1, 1]],
data_format='NHWC')
self.assertRaises(ValueError, attr_padding_with_data_format)
error_input = fluid.layers.data(name='error_data',
shape=[1],
dtype="float32")
def error_input_size():
out = fluid.layers.conv2d_transpose(input=error_input,
groups=1,
num_filters=6,
filter_size=3)
self.assertRaises(ValueError, error_input_size)
def error_groups():
out = fluid.layers.conv2d_transpose(input=data,
groups=0,
num_filters=6,
filter_size=3,
data_format='NHWC')
self.assertRaises(ValueError, error_groups)
class TestConv2DTransposeRepr(unittest.TestCase):
def setUp(self):
self.set_mlu()
def set_mlu(self):
self.__class__.use_mlu = True
self.place = paddle.device.MLUPlace(0)
def test_case(self):
paddle.disable_static()
x_var = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.)
conv = nn.Conv2DTranspose(4, 6, (3, 3), output_padding=1, stride=2)
print(conv)
y_var = conv(x_var)
y_np = y_var.numpy()
self.assertIsNotNone(y_np)
paddle.enable_static()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册