未验证 提交 7e31542c 编写于 作者: A andyjpaddle 提交者: GitHub

Add MaxUnPool3D op and MaxUnPool1D op (#38716)

* add maxunpool3d op

* update doc for maxunpool3d op

* update doc for maxunpool3d op

* update doc for maxunpool3d op

* update sample code for maxunpool3d

* add maxunpool1d op

* update some code for maxunpool1d
上级 2238a535
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -96,10 +96,101 @@ class Unpool2dMaxGradFunctor<platform::CPUDeviceContext, T> {
}
}
};
template <typename T>
class Unpool3dMaxFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& indices, framework::Tensor* output) {
const int batch_size = input.dims()[0];
const int input_depth = input.dims()[2];
const int input_height = input.dims()[3];
const int input_width = input.dims()[4];
const int output_channels = output->dims()[1];
const int output_depth = output->dims()[2];
const int output_height = output->dims()[3];
const int output_width = output->dims()[4];
int input_feasize = input_depth * input_height * input_width;
int output_feasize = output_depth * output_height * output_width;
const T* input_data = input.data<T>();
const int* indices_data = indices.data<int>();
T* output_data = output->mutable_data<T>(context.GetPlace());
for (int b = 0; b < batch_size; ++b) {
for (int c = 0; c < output_channels; ++c) {
for (int i = 0; i < input_feasize; ++i) {
int index = indices_data[i];
PADDLE_ENFORCE_LT(
index, output_feasize,
platform::errors::InvalidArgument(
"index should less than output tensor depth * output tensor "
"height "
"* output tensor width. Expected %ld < %ld, but got "
"%ld >= %ld. Please check input value.",
index, output_feasize, index, output_feasize));
output_data[index] = input_data[i];
}
input_data += input_feasize;
indices_data += input_feasize;
output_data += output_feasize;
}
}
}
};
template <class T>
class Unpool3dMaxGradFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& indices,
const framework::Tensor& output,
const framework::Tensor& output_grad,
framework::Tensor* input_grad) {
const int batch_size = input.dims()[0];
const int input_depth = input.dims()[2];
const int input_height = input.dims()[3];
const int input_width = input.dims()[4];
const int output_channels = output.dims()[1];
const int output_depth = output.dims()[2];
const int output_height = output.dims()[3];
const int output_width = output.dims()[4];
int input_feasize = input_depth * input_height * input_width;
int output_feasize = output_depth * output_height * output_width;
const int* indices_data = indices.data<int>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
for (int b = 0; b < batch_size; ++b) {
for (int c = 0; c < output_channels; ++c) {
for (int i = 0; i < input_feasize; ++i) {
int index = indices_data[i];
PADDLE_ENFORCE_LT(
index, output_feasize,
platform::errors::InvalidArgument(
"index should less than output tensor depth * output tensor "
"height "
"* output tensor width. Expected %ld < %ld, but got "
"%ld >= %ld. Please check input value.",
index, output_feasize, index, output_feasize));
input_grad_data[i] = output_grad_data[index];
}
input_grad_data += input_feasize;
indices_data += input_feasize;
output_grad_data += output_feasize;
}
}
}
};
template class Unpool2dMaxGradFunctor<platform::CPUDeviceContext, float>;
template class Unpool2dMaxGradFunctor<platform::CPUDeviceContext, double>;
template class Unpool2dMaxFunctor<platform::CPUDeviceContext, float>;
template class Unpool2dMaxFunctor<platform::CPUDeviceContext, double>;
template class Unpool3dMaxGradFunctor<platform::CPUDeviceContext, float>;
template class Unpool3dMaxGradFunctor<platform::CPUDeviceContext, double>;
template class Unpool3dMaxFunctor<platform::CPUDeviceContext, float>;
template class Unpool3dMaxFunctor<platform::CPUDeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 paddlepaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -51,6 +51,45 @@ __global__ void KernelUnpool2dMaxGrad(
/*
* All tensors are in NCHW format.
*/
template <typename T>
__global__ void KernelUnpool3dMax(const int nthreads, const T* input_data,
const int* indices_data,
const int input_depth, const int input_height,
const int input_width, const int channels,
T* output_data, const int output_depth,
const int output_height,
const int output_width) {
CUDA_KERNEL_LOOP(linearIndex, nthreads) {
int c = (linearIndex / input_depth / input_width / input_height) % channels;
int n = linearIndex / input_depth / input_width / input_height / channels;
output_data +=
(n * channels + c) * output_depth * output_height * output_width;
int maxind = indices_data[linearIndex];
output_data[maxind] = input_data[linearIndex];
}
}
template <typename T>
__global__ void KernelUnpool3dMaxGrad(
const int nthreads, const T* input_data, const int* indices_data,
const int input_depth, const int input_height, const int input_width,
const int channels, const T* output_data, const T* output_grad,
const int output_depth, const int output_height, const int output_width,
T* input_grad) {
CUDA_KERNEL_LOOP(linearIndex, nthreads) {
int c = (linearIndex / input_depth / input_width / input_height) % channels;
int n = linearIndex / input_depth / input_width / input_height / channels;
output_grad +=
(n * channels + c) * output_depth * output_height * output_width;
int maxind = indices_data[linearIndex];
input_grad[linearIndex] = output_grad[maxind];
}
}
/*
* All tensors are in NCDHW format.
*/
template <typename T>
class Unpool2dMaxFunctor<platform::CUDADeviceContext, T> {
public:
......@@ -112,10 +151,82 @@ class Unpool2dMaxGradFunctor<platform::CUDADeviceContext, T> {
output_width, input_grad_data);
}
};
template <typename T>
class Unpool3dMaxFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& indices, framework::Tensor* output) {
const int batch_size = input.dims()[0];
const int input_depth = input.dims()[2];
const int input_height = input.dims()[3];
const int input_width = input.dims()[4];
const int output_channels = output->dims()[1];
const int output_depth = output->dims()[2];
const int output_height = output->dims()[3];
const int output_width = output->dims()[4];
const T* input_data = input.data<T>();
const int* indices_data = indices.data<int>();
T* output_data = output->mutable_data<T>(context.GetPlace());
#ifdef __HIPCC__
int threads = 256;
#else
int threads = 1024;
#endif
int grid = (input.numel() + threads - 1) / threads;
KernelUnpool3dMax<T><<<grid, threads, 0, context.stream()>>>(
input.numel(), input_data, indices_data, input_depth, input_height,
input_width, output_channels, output_data, output_depth, output_height,
output_width);
}
};
/*
* All tensors are in NCDHW format.
*/
template <typename T>
class Unpool3dMaxGradFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& indices,
const framework::Tensor& output,
const framework::Tensor& output_grad,
framework::Tensor* input_grad) {
const int batch_size = input.dims()[0];
const int input_depth = input.dims()[2];
const int input_height = input.dims()[3];
const int input_width = input.dims()[4];
const int output_channels = output.dims()[1];
const int output_depth = output.dims()[2];
const int output_height = output.dims()[3];
const int output_width = output.dims()[4];
const T* input_data = input.data<T>();
const int* indices_data = indices.data<int>();
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
#ifdef __HIPCC__
int threads = 256;
#else
int threads = 1024;
#endif
int grid = (input.numel() + threads - 1) / threads;
KernelUnpool3dMaxGrad<T><<<grid, threads, 0, context.stream()>>>(
input.numel(), input_data, indices_data, input_depth, input_height,
input_width, output_channels, output_data, output_grad_data,
output_depth, output_height, output_width, input_grad_data);
}
};
template class Unpool2dMaxGradFunctor<platform::CUDADeviceContext, float>;
template class Unpool2dMaxGradFunctor<platform::CUDADeviceContext, double>;
template class Unpool2dMaxFunctor<platform::CUDADeviceContext, float>;
template class Unpool2dMaxFunctor<platform::CUDADeviceContext, double>;
template class Unpool3dMaxGradFunctor<platform::CUDADeviceContext, float>;
template class Unpool3dMaxGradFunctor<platform::CUDADeviceContext, double>;
template class Unpool3dMaxFunctor<platform::CUDADeviceContext, float>;
template class Unpool3dMaxFunctor<platform::CUDADeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -33,6 +33,22 @@ class Unpool2dMaxGradFunctor {
const framework::Tensor& output_grad,
framework::Tensor* input_grad);
};
template <typename DeviceContext, typename T>
class Unpool3dMaxFunctor {
public:
void operator()(const DeviceContext& context, const framework::Tensor& input,
const framework::Tensor& indices, framework::Tensor* output);
};
template <typename DeviceContext, class T>
class Unpool3dMaxGradFunctor {
public:
void operator()(const DeviceContext& context, const framework::Tensor& input,
const framework::Tensor& indices,
const framework::Tensor& output,
const framework::Tensor& output_grad,
framework::Tensor* input_grad);
};
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -76,6 +76,65 @@ Paper: http://www.matthewzeiler.com/wp-content/uploads/2017/07/iccv2011.pdf
}
};
class Unpool3dOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(
"X",
"(Tensor) The input tensor of unpool operator. "
"The format of input tensor is NCDHW. Where N is batch size, C is the "
"number of channels, D, H and W is the depth, height and width of "
"feature.");
AddInput(
"Indices",
"(Tensor) The input tensor of the indices given out by MaxPool3d. "
"The format of input tensor is NCDHW. Where N is batch size, C is the "
"number of channels, D, H and W is the depth, height and width of "
"feature.");
AddOutput("Out",
"(Tensor) The output tensor of unpool operator."
"The format of output tensor is also NCDHW."
"Where N is batch size, C is "
"the number of channels, D, H and W is the depth, height and "
"width of feature.");
AddAttr<std::vector<int>>(
"ksize",
"(vector), the unpooling window size(depth, height, width) "
"of unpooling operator.");
AddAttr<std::vector<int>>(
"strides",
"(vector, default:{1, 1, 1}), "
"strides (depth, height, width) of unpooling operator.")
.SetDefault({1, 1, 1});
AddAttr<std::vector<int>>(
"paddings",
"(vector default:{0, 0,0}), "
"paddings (depth, height, width) of unpooling operator.")
.SetDefault({0, 0, 0});
AddAttr<std::string>(
"unpooling_type",
"(string), unpooling type, can be \"max\" for max-unpooling ")
.InEnum({"max"});
AddAttr<std::vector<int>>("output_size",
"(vector, optional). The shape of output.")
.SetDefault({0, 0, 0});
AddAttr<std::string>(
"data_format",
"(string, default NCDHW)"
"Defaults to \"NCDHW\". Specify the data format of the output data, ")
.SetDefault("NCDHW");
AddComment(R"DOC(
Input shape is: $(N, C_{in}, D_{in}, H_{in}, W_{in})$, Output shape is:
$(N, C_{out}, D_{out}, H_{out}, W_{out})$, where
$$
D_{out} = (D_{in}-1) * strides[0] - 2 * paddings[0] + ksize[0] \\
H_{out} = (H_{in}-1) * strides[1] - 2 * paddings[1] + ksize[1] \\
W_{out} = (W_{in}-1) * strides[2] - 2 * paddings[2] + ksize[2]
$$
)DOC");
}
};
int UnpoolOutputSize(int input_size, int ksize, int padding, int stride) {
int output_size = (input_size - 1) * stride - 2 * padding + ksize;
return output_size;
......@@ -130,6 +189,55 @@ class UnpoolOp : public framework::OperatorWithKernel {
}
};
class Unpool3dOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Unpool3d");
OP_INOUT_CHECK(ctx->HasInput("Indices"), "Input", "Indices", "Unpool3d");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Unpool3d");
auto in_x_dims = ctx->GetInputDim("X");
auto in_y_dims = ctx->GetInputDim("Indices");
std::string unpooling_type =
ctx->Attrs().Get<std::string>("unpooling_type");
std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
std::vector<int> output_size =
ctx->Attrs().Get<std::vector<int>>("output_size");
PADDLE_ENFORCE_EQ(in_x_dims.size() == 5, true,
platform::errors::InvalidArgument(
"Unpool Intput(X) must be of 5-dimensional, but "
"received Input(X)'s dimensions is %d.",
in_x_dims.size()));
PADDLE_ENFORCE_EQ(in_x_dims, in_y_dims,
platform::errors::InvalidArgument(
"The dimensions of Input(X) must equal to be"
"the dimensions of Input(Indices), but received"
"dimensions of Input(X) is [%d], received dimensions"
"of Input(Indices) is [%d]",
in_x_dims, in_y_dims));
std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]});
for (size_t i = 0; i < ksize.size(); ++i) {
if (!ctx->IsRuntime() && in_x_dims[i + 2] <= 0) {
output_shape.push_back(-1);
} else {
output_shape.push_back(output_size[i]);
}
}
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
}
};
template <typename T>
class UnpoolOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
......@@ -145,6 +253,21 @@ class UnpoolOpGradMaker : public framework::SingleGradOpMaker<T> {
}
};
template <typename T>
class Unpool3dOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> op) const override {
op->SetType(this->ForwardOpType() + "_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Indices", this->Input("Indices"));
op->SetInput("Out", this->Output("Out"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
}
};
class UnpoolOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
......@@ -163,6 +286,26 @@ class UnpoolOpGrad : public framework::OperatorWithKernel {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};
class Unpool3dOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Unpool3dGrad");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
framework::GradVarName("X"), "Unpool3dGrad");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};
} // namespace operators
} // namespace paddle
......@@ -179,3 +322,16 @@ REGISTER_OP_CPU_KERNEL(
unpool_grad,
ops::UnpoolGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::UnpoolGradKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OPERATOR(unpool3d, ops::Unpool3dOp, ops::Unpool3dOpMaker,
ops::Unpool3dOpGradMaker<paddle::framework::OpDesc>,
ops::Unpool3dOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(unpool3d_grad, ops::Unpool3dOpGrad);
REGISTER_OP_CPU_KERNEL(
unpool3d, ops::Unpool3dKernel<paddle::platform::CPUDeviceContext, float>,
ops::Unpool3dKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
unpool3d_grad,
ops::Unpool3dGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::Unpool3dGradKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -22,3 +22,10 @@ REGISTER_OP_CUDA_KERNEL(
unpool_grad,
ops::UnpoolGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::UnpoolGradKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
unpool3d, ops::Unpool3dKernel<paddle::platform::CUDADeviceContext, float>,
ops::Unpool3dKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
unpool3d_grad,
ops::Unpool3dGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::Unpool3dGradKernel<paddle::platform::CUDADeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -69,5 +69,54 @@ class UnpoolGradKernel : public framework::OpKernel<T> {
unpool2d_max_backward(device_ctx, *in_x, *in_y, *out, *out_grad, in_x_grad);
}
};
template <typename DeviceContext, typename T>
class Unpool3dKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
const framework::Tensor* in_y = context.Input<framework::Tensor>("Indices");
auto* out = context.Output<framework::Tensor>("Out");
std::string unpooling_type = context.Attr<std::string>("unpooling_type");
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
T* output_data = out->mutable_data<T>(context.GetPlace());
auto& dev_ctx = context.template device_context<DeviceContext>();
if (output_data) {
math::SetConstant<DeviceContext, T> set_zero;
set_zero(dev_ctx, out, static_cast<T>(0));
}
math::Unpool3dMaxFunctor<DeviceContext, T> unpool3d_max_forward;
unpool3d_max_forward(dev_ctx, *in_x, *in_y, out);
}
};
template <typename DeviceContext, typename T>
class Unpool3dGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
const framework::Tensor* in_y = context.Input<framework::Tensor>("Indices");
const framework::Tensor* out = context.Input<framework::Tensor>("Out");
const framework::Tensor* out_grad =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
framework::Tensor* in_x_grad =
context.Output<framework::Tensor>(framework::GradVarName("X"));
std::string unpooling_type = context.Attr<std::string>("unpooling_type");
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
auto& device_ctx = context.template device_context<DeviceContext>();
math::SetConstant<DeviceContext, T> zero;
in_x_grad->mutable_data<T>(context.GetPlace());
zero(device_ctx, in_x_grad, static_cast<T>(0));
math::Unpool3dMaxGradFunctor<DeviceContext, T> unpool3d_max_backward;
unpool3d_max_backward(device_ctx, *in_x, *in_y, *out, *out_grad, in_x_grad);
}
};
} // namespace operators
} // namespace paddle
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.nn.functional as F
paddle.enable_static()
paddle.seed(2022)
def _unpool_output_size(x, kernel_size, stride, padding, output_size):
input_size = x.shape
default_size = []
for d in range(len(kernel_size)):
default_size.append((input_size[-len(kernel_size) + d] - 1) * stride[d]
+ kernel_size[d] - 2 * padding[d])
if output_size is None:
ret = default_size
else:
ret = output_size
return ret
def unpool1dmax_forward_naive(input, indices, ksize, strides, paddings,
output_size):
s0, s1, s2 = input.shape
output_size = _unpool_output_size(input, ksize, strides, paddings,
output_size)
out_lsize = output_size[0]
out = np.zeros((s0, s1, out_lsize))
for nidx in range(s0):
for cidx in range(s1):
for l in range(s2):
index = indices[nidx, cidx, l]
lidx = index % out_lsize
out[nidx, cidx, lidx] = input[nidx, cidx, l]
return out
class TestUnpool1DOpAPI_dygraph(unittest.TestCase):
def test_case(self):
places = [paddle.CPUPlace()]
if paddle.fluid.core.is_compiled_with_cuda():
places.append(paddle.CUDAPlace(0))
for place in places:
paddle.disable_static()
input_data = np.random.rand(1, 3, 16)
input_x = paddle.to_tensor(input_data)
output, indices = F.max_pool1d(
input_x, kernel_size=2, stride=2, return_mask=True)
output_unpool = F.max_unpool1d(
output, indices, kernel_size=2, stride=2)
expected_output_unpool = unpool1dmax_forward_naive(
output.numpy(), indices.numpy(), [2], [2], [0], [16])
self.assertTrue(
np.allclose(output_unpool.numpy(), expected_output_unpool))
paddle.enable_static()
class TestUnpool1DOpAPI_dygraph2(unittest.TestCase):
def test_case(self):
places = [paddle.CPUPlace()]
if paddle.fluid.core.is_compiled_with_cuda():
places.append(paddle.CUDAPlace(0))
for place in places:
paddle.disable_static()
input_data = np.random.rand(1, 3, 16)
input_x = paddle.to_tensor(input_data)
output, indices = F.max_pool1d(
input_x, kernel_size=2, stride=2, return_mask=True)
output_unpool = F.max_unpool1d(
output, indices, kernel_size=2, stride=None)
expected_output_unpool = unpool1dmax_forward_naive(
output.numpy(), indices.numpy(), [2], [2], [0], [16])
self.assertTrue(
np.allclose(output_unpool.numpy(), expected_output_unpool))
paddle.enable_static()
class TestUnpool1DOpAPI_dygraph3(unittest.TestCase):
def test_case(self):
places = [paddle.CPUPlace()]
if paddle.fluid.core.is_compiled_with_cuda():
places.append(paddle.CUDAPlace(0))
for place in places:
paddle.disable_static()
input_data = np.random.rand(1, 3, 16)
input_x = paddle.to_tensor(input_data)
Pool1d = paddle.nn.MaxPool1D(
kernel_size=2, stride=2, return_mask=True)
UnPool1d = paddle.nn.MaxUnPool1D(kernel_size=2, stride=2)
output, indices = Pool1d(input_x)
output_unpool = UnPool1d(output, indices)
expected_output_unpool = unpool1dmax_forward_naive(
output.numpy(), indices.numpy(), [2], [2], [0], [16])
self.assertTrue(
np.allclose(output_unpool.numpy(), expected_output_unpool))
paddle.enable_static()
class TestUnpool1DOpAPI_static(unittest.TestCase):
def test_case(self):
paddle.enable_static()
places = [paddle.CPUPlace()]
if paddle.fluid.core.is_compiled_with_cuda():
places.append(paddle.CUDAPlace(0))
for place in places:
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
input_data = np.array([[[1, 2, 3, 4], [5, 6, 7, 8],
[9, 10, 11, 12]]]).astype("float32")
x = paddle.fluid.data(
name='x', shape=[1, 3, 4], dtype='float32')
output, indices = F.max_pool1d(
x, kernel_size=2, stride=2, return_mask=True)
output_unpool = F.max_unpool1d(
output, indices, kernel_size=2, stride=None)
exe = paddle.fluid.Executor(place)
fetches = exe.run(paddle.fluid.default_main_program(),
feed={"x": input_data},
fetch_list=[output_unpool],
return_numpy=True)
pool1d_out_np = np.array(
[[[2., 4.], [6., 8.], [10., 12.]]]).astype("float32")
indices_np = np.array(
[[[1, 3], [1, 3], [1, 3]]]).astype("int32")
expected_output_unpool = unpool1dmax_forward_naive(
pool1d_out_np, indices_np, [2], [2], [0], [4])
self.assertTrue(np.allclose(fetches[0], expected_output_unpool))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.nn.functional as F
paddle.enable_static()
paddle.seed(2022)
def _unpool_output_size(x, kernel_size, stride, padding, output_size):
input_size = x.shape
default_size = []
for d in range(len(kernel_size)):
default_size.append((input_size[-len(kernel_size) + d] - 1) * stride[d]
+ kernel_size[d] - 2 * padding[d])
if output_size is None:
ret = default_size
else:
ret = output_size
return ret
def unpool3dmax_forward_naive(input, indices, ksize, strides, paddings,
output_size):
s0, s1, s2, s3, s4 = input.shape
output_size = _unpool_output_size(input, ksize, strides, paddings,
output_size)
out_dsize = output_size[0]
out_hsize = output_size[1]
out_wsize = output_size[2]
out = np.zeros((s0, s1, out_dsize, out_hsize, out_wsize))
for nidx in range(s0):
for cidx in range(s1):
for d in range(s2):
for h in range(s3):
for w in range(s4):
index = indices[nidx, cidx, d, h, w]
didx = index // (out_wsize * out_hsize)
hidx = (
index - didx * out_hsize * out_wsize) // out_wsize
widx = (
index - didx * out_hsize * out_wsize) % out_wsize
out[nidx, cidx, didx, hidx, widx] = \
input[nidx, cidx, d, h, w]
return out
class TestUnpool3DOp(OpTest):
def setUp(self):
self.op_type = "unpool3d"
self.init_test_case()
inputs = np.random.randint(0, 100, self.shape)
nsize, csize, dsize, hsize, wsize = inputs.shape
self.output_size = _unpool_output_size(inputs, self.ksize, self.strides,
self.paddings, self.output_size)
indices = np.random.permutation(
np.arange(0, self.output_size[0] * self.output_size[1] *
self.output_size[2]))[:dsize * hsize * wsize]
indices = np.reshape(indices, [dsize, hsize, wsize])
idx_list = []
for n in range(nsize):
c_list = []
for c in range(csize):
c_list.append(indices.tolist())
idx_list.append(c_list)
indices = np.array(idx_list)
output = self.unpool3d_forward_naive(inputs, indices, self.ksize, \
self.strides, self.paddings, self.output_size).astype("float64")
self.inputs = {
'X': inputs.astype('float64'),
'Indices': indices.astype('int32')
}
self.attrs = {
'strides': self.strides,
'paddings': self.paddings,
'ksize': self.ksize,
'unpooling_type': self.unpooling_type,
'output_size': self.output_size,
}
self.outputs = {'Out': output.astype('float64')}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
def init_test_case(self):
self.unpool3d_forward_naive = unpool3dmax_forward_naive
self.unpooling_type = "max"
self.shape = [1, 1, 4, 5, 6]
self.ksize = [2, 2, 2]
self.strides = [2, 2, 2]
self.paddings = [0, 0, 0]
self.output_size = None
class TestUnpool3DOpcase1(TestUnpool3DOp):
def init_test_case(self):
self.unpool3d_forward_naive = unpool3dmax_forward_naive
self.unpooling_type = "max"
self.shape = [1, 3, 4, 5, 6]
self.ksize = [2, 2, 2]
self.strides = [2, 2, 2]
self.paddings = [0, 0, 0]
self.output_size = None
class TestUnpool3DOpOutput(TestUnpool3DOp):
def init_test_case(self):
self.unpool3d_forward_naive = unpool3dmax_forward_naive
self.unpooling_type = "max"
self.shape = [1, 3, 4, 5, 6]
self.ksize = [2, 2, 2]
self.strides = [2, 2, 2]
self.paddings = [0, 0, 0]
self.output_size = [7, 9, 11]
class TestUnpool3DOpException(unittest.TestCase):
def test_exception(self):
def indices_size_error():
data = paddle.randint(shape=[1, 1, 3, 3, 3])
indices = paddle.reshape(
paddle.arange(0, 36), shape=[1, 1, 3, 3, 4])
MaxUnPool3D = F.maxunpool3d(data, indices, kernel_size=2, stride=2)
def indices_value_error():
data = paddle.randint(shape=[1, 1, 3, 3, 3])
indices = paddle.reshape(
paddle.arange(4, 40), shape=[1, 1, 3, 3, 3])
MaxUnPool3D = F.maxunpool3d(data, indices, kernel_size=2, stride=2)
def data_format_error():
data = paddle.randint(shape=[1, 1, 3, 3, 3])
indices = paddle.reshape(
paddle.arange(0, 27), shape=[1, 1, 3, 3, 3])
MaxUnPool3D = F.maxunpool3d(
data, indices, kernel_size=2, stride=2, data_format="NDHWC")
def data_outputsize_error():
data = paddle.randint(shape=[1, 1, 3, 3, 3])
indices = paddle.reshape(
paddle.arange(0, 27), shape=[1, 1, 3, 3, 3])
MaxUnPool3D = F.maxunpool3d(
data,
indices,
kernel_size=2,
stride=2,
output_size=[2, 2, 3, 4, 5])
def data_outputsize_error2():
data = paddle.randint(shape=[1, 1, 3, 3, 3])
indices = paddle.reshape(
paddle.arange(0, 27), shape=[1, 1, 3, 3, 3])
MaxUnPool3D = F.maxunpool3d(
data,
indices,
kernel_size=2,
stride=2,
output_size=[10, 10, 10])
self.assertRaises(ValueError, indices_size_error)
self.assertRaises(ValueError, indices_value_error)
self.assertRaises(ValueError, data_format_error)
self.assertRaises(ValueError, data_outputsize_error)
self.assertRaises(ValueError, data_outputsize_error2)
class TestUnpool3DOpAPI_dygraph(unittest.TestCase):
def test_case(self):
places = [paddle.CPUPlace()]
if paddle.fluid.core.is_compiled_with_cuda():
places.append(paddle.CUDAPlace(0))
for place in places:
paddle.disable_static()
input_data = np.random.rand(1, 3, 4, 4, 6)
input_x = paddle.to_tensor(input_data)
output, indices = F.max_pool3d(
input_x, kernel_size=2, stride=2, return_mask=True)
output_unpool = F.max_unpool3d(
output, indices, kernel_size=2, stride=2)
expected_output_unpool = unpool3dmax_forward_naive(
output.numpy(),
indices.numpy(), [2, 2, 2], [2, 2, 2], [0, 0, 0], [4, 4, 6])
self.assertTrue(
np.allclose(output_unpool.numpy(), expected_output_unpool))
paddle.enable_static()
class TestUnpool3DOpAPI_dygraph2(unittest.TestCase):
def test_case(self):
places = [paddle.CPUPlace()]
if paddle.fluid.core.is_compiled_with_cuda():
places.append(paddle.CUDAPlace(0))
for place in places:
paddle.disable_static()
input_data = np.random.rand(1, 3, 4, 4, 6)
input_x = paddle.to_tensor(input_data)
output, indices = F.max_pool3d(
input_x, kernel_size=2, stride=2, return_mask=True)
output_unpool = F.max_unpool3d(
output, indices, kernel_size=2, stride=None)
expected_output_unpool = unpool3dmax_forward_naive(
output.numpy(),
indices.numpy(), [2, 2, 2], [2, 2, 2], [0, 0, 0], [4, 4, 6])
self.assertTrue(
np.allclose(output_unpool.numpy(), expected_output_unpool))
paddle.enable_static()
class TestUnpool3DOpAPI_dygraph3(unittest.TestCase):
def test_case(self):
places = [paddle.CPUPlace()]
if paddle.fluid.core.is_compiled_with_cuda():
places.append(paddle.CUDAPlace(0))
for place in places:
paddle.disable_static()
input_data = np.random.rand(1, 3, 4, 4, 6)
input_x = paddle.to_tensor(input_data)
Pool3d = paddle.nn.MaxPool3D(
kernel_size=2, stride=2, return_mask=True)
UnPool3d = paddle.nn.MaxUnPool3D(kernel_size=2, stride=2)
output, indices = Pool3d(input_x)
output_unpool = UnPool3d(output, indices)
expected_output_unpool = unpool3dmax_forward_naive(
output.numpy(),
indices.numpy(), [2, 2, 2], [2, 2, 2], [0, 0, 0], [4, 4, 6])
self.assertTrue(
np.allclose(output_unpool.numpy(), expected_output_unpool))
paddle.enable_static()
class TestUnpool3DOpAPI_static(unittest.TestCase):
def test_case(self):
paddle.enable_static()
places = [paddle.CPUPlace()]
if paddle.fluid.core.is_compiled_with_cuda():
places.append(paddle.CUDAPlace(0))
for place in places:
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
input_data = np.array([[[[[1, 2, 3, 4], [5, 6, 7, 8], \
[9, 10, 11, 12], [13, 14, 15, 16]], [[1, 2, 3, 4], [5, 6, 7, 8], \
[9, 10, 11, 12], [13, 14, 15, 16]]]]]).astype("float32")
x = paddle.fluid.data(
name='x', shape=[1, 1, 2, 4, 4], dtype='float32')
output, indices = F.max_pool3d(
x, kernel_size=2, stride=2, return_mask=True)
output_unpool = F.max_unpool3d(
output, indices, kernel_size=2, stride=None)
exe = paddle.fluid.Executor(place)
fetches = exe.run(paddle.fluid.default_main_program(),
feed={"x": input_data},
fetch_list=[output_unpool],
return_numpy=True)
pool3d_out_np = np.array(
[[[[[6., 8.], [14., 16.]]]]]).astype("float32")
indices_np = np.array([[[[[5, 7], [13, 15]]]]]).astype("int32")
expected_output_unpool = unpool3dmax_forward_naive(
pool3d_out_np, indices_np, [2, 2, 2], [2, 2, 2], [0, 0, 0],
[2, 4, 4])
self.assertTrue(np.allclose(fetches[0], expected_output_unpool))
if __name__ == '__main__':
unittest.main()
......@@ -77,7 +77,9 @@ from .layer.pooling import AvgPool3D # noqa: F401
from .layer.pooling import MaxPool1D # noqa: F401
from .layer.pooling import MaxPool2D # noqa: F401
from .layer.pooling import MaxPool3D # noqa: F401
from .layer.pooling import MaxUnPool1D # noqa: F401
from .layer.pooling import MaxUnPool2D # noqa: F401
from .layer.pooling import MaxUnPool3D # noqa: F401
from .layer.pooling import AdaptiveAvgPool1D # noqa: F401
from .layer.pooling import AdaptiveAvgPool2D # noqa: F401
from .layer.pooling import AdaptiveAvgPool3D # noqa: F401
......@@ -301,6 +303,8 @@ __all__ = [ #noqa
'ReLU6',
'LayerDict',
'ZeroPad2D',
'MaxUnPool1D',
'MaxUnPool2D',
'MaxUnPool3D',
'HingeEmbeddingLoss',
]
......@@ -107,7 +107,9 @@ from .pooling import adaptive_max_pool3d # noqa: F401
from .pooling import adaptive_avg_pool1d # noqa: F401
from .pooling import adaptive_avg_pool2d # noqa: F401
from .pooling import adaptive_avg_pool3d # noqa: F401
from .pooling import max_unpool1d # noqa: F401
from .pooling import max_unpool2d # noqa: F401
from .pooling import max_unpool3d # noqa: F401
from .vision import affine_grid # noqa: F401
from .vision import grid_sample # noqa: F401
......@@ -179,7 +181,9 @@ __all__ = [ #noqa
'max_pool1d',
'max_pool2d',
'max_pool3d',
'max_unpool1d',
'max_unpool2d',
'max_unpool3d',
'adaptive_avg_pool1d',
'adaptive_avg_pool2d',
'adaptive_avg_pool3d',
......
......@@ -664,6 +664,115 @@ def _unpool_output_size(x, kernel_size, stride, padding, output_size):
return ret
def max_unpool1d(x,
indices,
kernel_size,
stride=None,
padding=0,
data_format="NCL",
output_size=None,
name=None):
"""
This API implements max unpooling 1d opereation.
`max_unpool1d` accepts the output of `max_pool1d` as input,
including the indices of the maximum value and calculate the partial inverse.
All non-maximum values ​​are set to zero.
- Input: :math:`(N, C, L_{in})`
- Output: :math:`(N, C, L_{out})`, where
.. math::
L_{out} = (L_{in} - 1) * stride - 2 * padding + kernel\_size
or as given by :attr:`output_size` in the call operator.
Args:
x (Tensor): The input tensor of unpooling operator which is a 3-D tensor with
shape [N, C, L]. The format of input tensor is `"NCL"`,
where `N` is batch size, `C` is the number of channels, `L` is
the length of the feature. The data type is float32 or float64.
indices (Tensor): The indices given out by maxpooling1d which is a 3-D tensor with
shape [N, C, L]. The format of input tensor is `"NCL"` ,
where `N` is batch size, `C` is the number of channels, `L` is
the length of the featuree. The data type is float32 or float64.
kernel_size (int|list|tuple): The unpool kernel size. If unpool kernel size is a tuple or list,
it must contain an integer.
stride (int|list|tuple): The unpool stride size. If unpool stride size is a tuple or list,
it must contain an integer.
padding (int | tuple): Padding that was added to the input.
output_size(list|tuple, optional): The target output size. If output_size is not specified,
the actual output shape will be automatically calculated by (input_shape,
kernel_size, stride, padding).
data_format (string): The data format of the input and output data.
The default is `"NCL"`. When it is `"NCL"`, the data is stored in the order of:
`[batch_size, input_channels, input_length]`.
name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name is no need to set and
None by default.
Returns:
Tensor: The output tensor of unpooling result.
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
data = paddle.rand(shape=[1, 3, 16])
pool_out, indices = F.max_pool1d(data, kernel_size=2, stride=2, padding=0, return_mask=True)
# pool_out shape: [1, 3, 8], indices shape: [1, 3, 8]
unpool_out = F.max_unpool1d(pool_out, indices, kernel_size=2, padding=0)
# unpool_out shape: [1, 3, 16]
"""
"""NCL to NCHW"""
if data_format not in ["NCL"]:
raise ValueError("Attr(data_format) should be 'NCL'. Received "
"Attr(data_format): %s." % str(data_format))
data_format = "NCHW"
x = unsqueeze(x, [2])
indices = unsqueeze(indices, [2])
kernel_size = [1] + utils.convert_to_list(kernel_size, 1, 'pool_size')
if stride is None:
stride = kernel_size
else:
stride = [1] + utils.convert_to_list(stride, 1, 'pool_stride')
padding, padding_algorithm = _update_padding_nd(padding, 1)
# use 2d to implenment 1d should expand padding in advance.
padding = _expand_low_nd_padding(padding)
output_size = _unpool_output_size(x, kernel_size, stride, padding,
output_size)
if in_dygraph_mode():
output = _C_ops.unpool(x, indices, 'unpooling_type', 'max', 'ksize',
kernel_size, 'strides', stride, 'paddings',
padding, "output_size", output_size,
"data_format", data_format)
return squeeze(output, [2])
op_type = "unpool"
helper = LayerHelper(op_type, **locals())
dtype = helper.input_dtype(input_param_name="x")
unpool_out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type=op_type,
inputs={"X": x,
"Indices": indices},
outputs={"Out": unpool_out},
attrs={
"unpooling_type": "max",
"ksize": kernel_size,
"strides": stride,
"paddings": padding,
"output_size": output_size
})
return squeeze(unpool_out, [2])
def max_unpool2d(x,
indices,
kernel_size,
......@@ -779,6 +888,118 @@ def max_unpool2d(x,
return unpool_out
def max_unpool3d(x,
indices,
kernel_size,
stride=None,
padding=0,
data_format="NCDHW",
output_size=None,
name=None):
"""
This API implements max unpooling 3d opereation.
`max_unpool3d` accepts the output of `max_pool3d` as input,
including the indices of the maximum value and calculate the partial inverse.
All non-maximum values ​​are set to zero.
- Input: :math:`(N, C, D_{in}, H_{in}, W_{in})`
- Output: :math:`(N, C, D_{out}, H_{out}, W_{out})`, where
.. math::
D_{out} = (D_{in} - 1) * stride[0] - 2 * padding[0] + kernel\_size[0]
.. math::
H_{out} = (H_{in} - 1) * stride[1] - 2 * padding[1] + kernel\_size[1]
.. math::
W_{out} = (W_{in} - 1) * stride[2] - 2 * padding[2] + kernel\_size[2]
or as given by :attr:`output_size` in the call operator
Args:
x (Tensor): The input tensor of unpooling operator which is a 5-D tensor with
shape [N, C, D, H, W]. The format of input tensor is `"NCDHW"`,
where `N` is batch size, `C` is the number of channels, `D` is
the depth of the feature, `H` is the height of the feature,
and `W` is the width of the feature. The data type is float32 or float64.
indices (Tensor): The indices given out by maxpooling3d which is a 5-D tensor with
shape [N, C, D, H, W]. The format of input tensor is `"NCDHW"` ,
where `N` is batch size, `C` is the number of channels, `D` is
the depth of the feature, `H` is the height of the feature,
and `W` is the width of the feature. The data type is float32 or float64.
kernel_size (int|list|tuple): The unpool kernel size. If unpool kernel size is a tuple or list,
it must contain an integer.
stride (int|list|tuple): The unpool stride size. If unpool stride size is a tuple or list,
it must contain an integer.
padding (int | tuple): Padding that was added to the input.
output_size(list|tuple, optional): The target output size. If output_size is not specified,
the actual output shape will be automatically calculated by (input_shape,
kernel_size, stride, padding).
data_format (string): The data format of the input and output data.
The default is `"NCDHW"`. When it is `"NCDHW"`, the data is stored in the order of:
`[batch_size, input_channels, input_depth, input_height, input_width]`.
name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name is no need to set and
None by default.
Returns:
Tensor: The output tensor of unpooling result.
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
data = paddle.rand(shape=[1, 1, 4, 4, 6])
pool_out, indices = F.max_pool3d(data, kernel_size=2, stride=2, padding=0, return_mask=True)
# pool_out shape: [1, 1, 2, 2, 3], indices shape: [1, 1, 2, 2, 3]
unpool_out = F.max_unpool3d(pool_out, indices, kernel_size=2, padding=0)
# unpool_out shape: [1, 1, 4, 4, 6]
"""
kernel_size = utils.convert_to_list(kernel_size, 3, 'pool_size')
if stride is None:
stride = kernel_size
else:
stride = utils.convert_to_list(stride, 3, 'pool_stride')
padding = utils.convert_to_list(padding, 3, 'padding')
if data_format not in ["NCDHW"]:
raise ValueError("Attr(data_format) should be 'NCDHW'. Received "
"Attr(data_format): %s." % str(data_format))
output_size = _unpool_output_size(x, kernel_size, stride, padding,
output_size)
if in_dygraph_mode():
output = _C_ops.unpool3d(x, indices, 'unpooling_type', 'max', 'ksize',
kernel_size, 'strides', stride, 'paddings',
padding, "output_size", output_size,
"data_format", data_format)
return output
op_type = "unpool3d"
helper = LayerHelper(op_type, **locals())
dtype = helper.input_dtype(input_param_name="x")
unpool_out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type=op_type,
inputs={"X": x,
"Indices": indices},
outputs={"Out": unpool_out},
attrs={
"unpooling_type": "max",
"ksize": kernel_size,
"strides": stride,
"paddings": padding,
"output_size": output_size
})
return unpool_out
def max_pool2d(x,
kernel_size,
stride=None,
......
......@@ -57,7 +57,9 @@ from .pooling import AdaptiveAvgPool3D # noqa: F401
from .pooling import AdaptiveMaxPool1D # noqa: F401
from .pooling import AdaptiveMaxPool2D # noqa: F401
from .pooling import AdaptiveMaxPool3D # noqa: F401
from .pooling import MaxUnPool1D # noqa: F401
from .pooling import MaxUnPool2D # noqa: F401
from .pooling import MaxUnPool3D # noqa: F401
from .conv import Conv1D # noqa: F401
from .conv import Conv2D # noqa: F401
from .conv import Conv3D # noqa: F401
......
......@@ -1130,6 +1130,88 @@ class AdaptiveMaxPool3D(Layer):
self._return_mask)
class MaxUnPool1D(Layer):
"""
This API implements max unpooling 1d opereation.
`max_unpool1d` accepts the output of `max_pool1d` as input,
including the indices of the maximum value and calculate the partial inverse.
All non-maximum values ​​are set to zero.
- Input: :math:`(N, C, L_{in})`
- Output: :math:`(N, C, L_{out})`, where
.. math::
L_{out} = (L_{in} - 1) * stride - 2 * padding + kernel\_size
or as given by :attr:`output_size` in the call operator.
Parameters:
kernel_size (int|list|tuple): The unpool kernel size. If unpool kernel size is a tuple or list,
it must contain an integer.
stride (int|list|tuple): The unpool stride size. If unpool stride size is a tuple or list,
it must contain an integer.
padding (int | tuple): Padding that was added to the input.
output_size(list|tuple, optional): The target output size. If output_size is not specified,
the actual output shape will be automatically calculated by (input_shape,
kernel_size, stride, padding).
data_format (string): The data format of the input and output data.
The default is `"NCL"`. When it is `"NCL"`, the data is stored in the order of:
`[batch_size, input_channels, input_length]`.
name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name is no need to set and
None by default.
Returns:
A callable object of MaxUnPool1D.
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
import numpy as np
data = paddle.rand(shape=[1, 3, 16])
pool_out, indices = F.max_pool1d(data, kernel_size=2, stride=2, padding=0, return_mask=True)
# pool_out shape: [1, 3, 8], indices shape: [1, 3, 8]
Unpool1D = paddle.nn.MaxUnPool1D(kernel_size=2, padding=0)
unpool_out = Unpool1D(pool_out, indices)
# unpool_out shape: [1, 3, 16]
"""
def __init__(self,
kernel_size,
stride=None,
padding=0,
data_format="NCL",
output_size=None,
name=None):
super(MaxUnPool1D, self).__init__()
self.ksize = kernel_size
self.stride = stride
self.padding = padding
self.data_format = data_format
self.output_size = output_size
self.name = name
def forward(self, x, indices):
return F.max_unpool1d(
x,
indices,
kernel_size=self.ksize,
stride=self.stride,
padding=self.padding,
data_format=self.data_format,
output_size=self.output_size,
name=self.name)
def extra_repr(self):
return 'output_size={}'.format(self.output_size)
class MaxUnPool2D(Layer):
"""
This API implements max unpooling 2d opereation.
......@@ -1214,3 +1296,92 @@ class MaxUnPool2D(Layer):
def extra_repr(self):
return 'output_size={}'.format(self.output_size)
class MaxUnPool3D(Layer):
"""
This API implements max unpooling 3d opereation.
`max_unpool3d` accepts the output of `max_pool3d` as input,
including the indices of the maximum value and calculate the partial inverse.
All non-maximum values ​​are set to zero.
- Input: :math:`(N, C, D_{in}, H_{in}, W_{in})`
- Output: :math:`(N, C, D_{out}, H_{out}, W_{out})`, where
.. math::
D_{out} = (D_{in} - 1) * stride[0] - 2 * padding[0] + kernel\_size[0]
.. math::
H_{out} = (H_{in} - 1) * stride[1] - 2 * padding[1] + kernel\_size[1]
.. math::
W_{out} = (W_{in} - 1) * stride[2] - 2 * padding[2] + kernel\_size[2]
or as given by :attr:`output_size` in the call operator
Parameters:
kernel_size (int|list|tuple): The unpool kernel size. If unpool kernel size is a tuple or list,
it must contain an integer.
stride (int|list|tuple): The unpool stride size. If unpool stride size is a tuple or list,
it must contain an integer.
padding (int | tuple): Padding that was added to the input.
output_size(list|tuple, optional): The target output size. If output_size is not specified,
the actual output shape will be automatically calculated by (input_shape,
kernel_size, stride, padding).
data_format (string): The data format of the input and output data.
The default is `"NCDHW"`. When it is `"NCDHW"`, the data is stored in the order of:
`[batch_size, input_channels, input_depth, input_height, input_width]`.
name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name is no need to set and
None by default.
Returns:
A callable object of MaxUnPool3D.
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
import numpy as np
data = paddle.rand(shape=[1, 1, 4, 4, 6])
pool_out, indices = F.max_pool3d(data, kernel_size=2, stride=2, padding=0, return_mask=True)
# pool_out shape: [1, 1, 2, 2, 3], indices shape: [1, 1, 2, 2, 3]
Unpool3D = paddle.nn.MaxUnPool3D(kernel_size=2, padding=0)
unpool_out = Unpool3D(pool_out, indices)
# unpool_out shape: [1, 1, 4, 4, 6]
"""
def __init__(self,
kernel_size,
stride=None,
padding=0,
data_format="NCDHW",
output_size=None,
name=None):
super(MaxUnPool3D, self).__init__()
self.ksize = kernel_size
self.stride = stride
self.padding = padding
self.data_format = data_format
self.output_size = output_size
self.name = name
def forward(self, x, indices):
return F.max_unpool3d(
x,
indices,
kernel_size=self.ksize,
stride=self.stride,
padding=self.padding,
data_format=self.data_format,
output_size=self.output_size,
name=self.name)
def extra_repr(self):
return 'output_size={}'.format(self.output_size)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册