未验证 提交 966a6ce6 编写于 作者: S sweetsky0901 提交者: GitHub

Merge pull request #5826 from sweetsky0901/my_unpool_max_2d

My unpool max 2d
......@@ -191,6 +191,7 @@ set(DEPS_OPS
sum_op
pool_op
maxout_op
unpool_op
pool_with_index_op
conv_op
conv_transpose_op
......@@ -235,6 +236,7 @@ op_library(adagrad_op DEPS selected_rows_functor)
op_library(conv_op DEPS vol2col)
op_library(pool_op DEPS pooling)
op_library(maxout_op DEPS maxouting)
op_library(unpool_op DEPS unpooling)
op_library(pool_with_index_op DEPS pooling)
op_library(lod_rank_table_op SRCS lod_rank_table_op.cc DEPS lod_rank_table)
op_library(lod_tensor_to_array_op SRCS lod_tensor_to_array_op.cc DEPS lod_rank_table_op)
......
......@@ -13,8 +13,9 @@ if(WITH_GPU)
nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context math_function)
nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context)
nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions)
nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function)
nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context)
nv_library(unpooling SRCS unpooling.cc unpooling.cu DEPS device_context)
nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function)
else()
cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context framework_proto)
cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function)
......@@ -26,8 +27,9 @@ else()
cc_library(context_project SRCS context_project.cc DEPS device_context math_function)
cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context)
cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions)
cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function)
cc_library(maxouting SRCS maxouting.cc DEPS device_context)
cc_library(unpooling SRCS unpooling.cc DEPS device_context)
cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function)
endif()
cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/unpooling.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
class Unpool2dMaxFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& indices, framework::Tensor* output) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output->dims()[1];
const int output_height = output->dims()[2];
const int output_width = output->dims()[3];
int input_feasize = input_height * input_width;
int output_feasize = output_height * output_width;
const T* input_data = input.data<T>();
const int* indices_data = indices.data<int>();
T* output_data = output->mutable_data<T>(context.GetPlace());
for (int b = 0; b < batch_size; ++b) {
for (int c = 0; c < output_channels; ++c) {
for (int i = 0; i < input_feasize; ++i) {
int index = indices_data[i];
PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!");
output_data[index] = input_data[i];
}
input_data += input_feasize;
indices_data += input_feasize;
output_data += output_feasize;
}
}
}
};
template <class T>
class Unpool2dMaxGradFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& indices,
const framework::Tensor& output,
const framework::Tensor& output_grad,
framework::Tensor* input_grad) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output.dims()[1];
const int output_height = output.dims()[2];
const int output_width = output.dims()[3];
int input_feasize = input_height * input_width;
int output_feasize = output_height * output_width;
const int* indices_data = indices.data<int>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
for (int b = 0; b < batch_size; ++b) {
for (int c = 0; c < output_channels; ++c) {
for (int i = 0; i < input_feasize; ++i) {
int index = indices_data[i];
PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!");
input_grad_data[i] = output_grad_data[index];
}
input_grad_data += input_feasize;
indices_data += input_feasize;
output_grad_data += output_feasize;
}
}
}
};
template class Unpool2dMaxGradFunctor<platform::CPUPlace, float>;
template class Unpool2dMaxGradFunctor<platform::CPUPlace, double>;
template class Unpool2dMaxFunctor<platform::CPUPlace, float>;
template class Unpool2dMaxFunctor<platform::CPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/unpooling.h"
#include "paddle/platform/cuda_helper.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
__global__ void KernelUnpool2dMax(const int nthreads, const T* input_data,
const int* indices_data,
const int input_height, const int input_width,
const int channels, T* output_data,
const int output_height,
const int output_width) {
int in_n_stride = input_height * input_width * channels;
int in_c_stride = input_height * input_width;
int out_n_stride = output_height * output_width * channels;
int out_c_stride = output_height * output_width;
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
int bidx = i / in_n_stride;
int boffset = i % in_n_stride;
int cidx = boffset / in_c_stride;
int out_offset = bidx * out_n_stride + cidx * out_c_stride;
int out_index = indices_data[i];
PADDLE_ASSERT(out_index < out_c_stride);
output_data[out_offset + out_index] = input_data[i];
}
}
template <typename T>
__global__ void KernelUnpool2dMaxGrad(
const int nthreads, const T* input_data, const int* indices_data,
const int input_height, const int input_width, const int channels,
const T* output_data, const T* output_grad, const int output_height,
const int output_width, T* input_grad) {
int in_n_stride = input_height * input_width * channels;
int in_c_stride = input_height * input_width;
int out_n_stride = output_height * output_width * channels;
int out_c_stride = output_height * output_width;
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
int bidx = i / in_n_stride;
int boffset = i % in_n_stride;
int cidx = boffset / in_c_stride;
int out_offset = bidx * out_n_stride + cidx * out_c_stride;
int out_index = indices_data[i];
PADDLE_ASSERT(out_index < out_c_stride);
input_grad[i] = output_grad[out_offset + out_index];
}
}
/*
* All tensors are in NCHW format.
*/
template <typename T>
class Unpool2dMaxFunctor<platform::GPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& indices, framework::Tensor* output) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output->dims()[1];
const int output_height = output->dims()[2];
const int output_width = output->dims()[3];
const T* input_data = input.data<T>();
const int* indices_data = indices.data<int>();
T* output_data = output->mutable_data<T>(context.GetPlace());
int threads = 1024;
int grid = (input.numel() + threads - 1) / threads;
KernelUnpool2dMax<
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(input.numel(), input_data, indices_data,
input_height, input_width, output_channels,
output_data, output_height, output_width);
}
};
/*
* All tensors are in NCHW format.
*/
template <typename T>
class Unpool2dMaxGradFunctor<platform::GPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& indices,
const framework::Tensor& output,
const framework::Tensor& output_grad,
framework::Tensor* input_grad) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output.dims()[1];
const int output_height = output.dims()[2];
const int output_width = output.dims()[3];
const T* input_data = input.data<T>();
const int* indices_data = indices.data<int>();
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
int threads = 1024;
int grid = (input.numel() + threads - 1) / threads;
KernelUnpool2dMaxGrad<
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(input.numel(), input_data, indices_data,
input_height, input_width, output_channels,
output_data, output_grad_data, output_height,
output_width, input_grad_data);
}
};
template class Unpool2dMaxGradFunctor<platform::GPUPlace, float>;
template class Unpool2dMaxGradFunctor<platform::GPUPlace, double>;
template class Unpool2dMaxFunctor<platform::GPUPlace, float>;
template class Unpool2dMaxFunctor<platform::GPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/tensor.h"
namespace paddle {
namespace operators {
namespace math {
template <typename Place, typename T>
class Unpool2dMaxFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& indices, framework::Tensor* output);
};
template <typename Place, class T>
class Unpool2dMaxGradFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& indices,
const framework::Tensor& output,
const framework::Tensor& output_grad,
framework::Tensor* input_grad);
};
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Indicesou may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/unpool_op.h"
namespace paddle {
namespace operators {
class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Unpool2dOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"X",
"(Tensor) The input tensor of unpool operator. "
"The format of input tensor is NCHW. Where N is batch size, C is the "
"number of channels, H and W is the height and width of feature.");
AddInput(
"Indices",
"(Tensor) The input tensor of the indices given out by MaxPool2d. "
"The format of input tensor is NCHW. Where N is batch size, C is the "
"number of channels, H and W is the height and width of feature.");
AddOutput("Out",
"(Tensor) The output tensor of unpool operator."
"The format of output tensor is also NCHW."
"Where N is batch size, C is "
"the number of channels, H and W is the height and "
"width of feature.");
AddAttr<std::vector<int>>(
"ksize",
"(vector), the unpooling window size(height, width) "
"of unpooling operator.");
AddAttr<std::vector<int>>("strides",
"(vector, default:{1, 1}), "
"strides (height, width) of unpooling operator.")
.SetDefault({1, 1});
AddAttr<std::vector<int>>("paddings",
"(vector defalut:{0,0}), "
"paddings (height, width) of unpooling operator.")
.SetDefault({0, 0});
AddAttr<std::string>(
"unpooling_type",
"(string), unpooling type, can be \"max\" for max-unpooling ")
.InEnum({"max"});
AddComment(R"DOC(
"Input shape: $(N, C_{in}, H_{in}, W_{in})$
Output shape: $(N, C_{out}, H_{out}, W_{out})$
Where
$$
H_{out} = (H_{in}−1) * strides[0] − 2 * paddings[0] + ksize[0] \\
W_{out} = (W_{in}−1) * strides[1] − 2 * paddings[1] + ksize[1]
$$
Paper: http://www.matthewzeiler.com/wp-content/uploads/2017
/07/iccv2011.pdf
)DOC");
}
};
int OutputSize(int input_size, int ksize, int padding, int stride) {
int output_size = (input_size - 1) * stride - 2 * padding + ksize;
return output_size;
}
class UnpoolOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.device_context());
}
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of UnpoolOp"
"should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Indices"),
"Input(Indices) of UnpoolOp"
"should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of UnpoolOp should not be null.");
auto in_x_dims = ctx->GetInputDim("X");
auto in_y_dims = ctx->GetInputDim("Indices");
std::string unpooling_type =
ctx->Attrs().Get<std::string>("unpooling_type");
std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
PADDLE_ENFORCE(in_x_dims.size() == 4,
"Unpooling intput must be of 4-dimensional.");
PADDLE_ENFORCE_EQ(in_x_dims, in_y_dims);
std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]});
for (size_t i = 0; i < ksize.size(); ++i) {
output_shape.push_back(
OutputSize(in_x_dims[i + 2], ksize[i], paddings[i], strides[i]));
}
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
}
};
class UnpoolOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.device_context());
}
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Input(X@GRAD) should not be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(unpool, ops::UnpoolOp, ops::Unpool2dOpMaker, unpool_grad,
ops::UnpoolOpGrad);
REGISTER_OP_CPU_KERNEL(unpool,
ops::UnpoolKernel<paddle::platform::CPUPlace, float>,
ops::UnpoolKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
unpool_grad, ops::UnpoolGradKernel<paddle::platform::CPUPlace, float>,
ops::UnpoolGradKernel<paddle::platform::CPUPlace, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Indicesou may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/unpool_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(unpool,
ops::UnpoolKernel<paddle::platform::GPUPlace, float>,
ops::UnpoolKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(
unpool_grad, ops::UnpoolGradKernel<paddle::platform::GPUPlace, float>,
ops::UnpoolGradKernel<paddle::platform::GPUPlace, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Indicesou may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/unpooling.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class UnpoolKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
const framework::Tensor* in_y = context.Input<framework::Tensor>("Indices");
auto* out = context.Output<framework::Tensor>("Out");
std::string unpooling_type = context.Attr<std::string>("unpooling_type");
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
T* output_data = out->mutable_data<T>(context.GetPlace());
if (output_data) {
math::SetConstant<Place, T> set_zero;
set_zero(context.device_context(), out, static_cast<T>(0));
}
math::Unpool2dMaxFunctor<Place, T> unpool2d_max_forward;
unpool2d_max_forward(context.device_context(), *in_x, *in_y, out);
}
};
template <typename Place, typename T>
class UnpoolGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
const framework::Tensor* in_y = context.Input<framework::Tensor>("Indices");
const framework::Tensor* out = context.Input<framework::Tensor>("Out");
const framework::Tensor* out_grad =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
framework::Tensor* in_x_grad =
context.Output<framework::Tensor>(framework::GradVarName("X"));
std::string unpooling_type = context.Attr<std::string>("unpooling_type");
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
auto& device_ctx = context.device_context();
math::SetConstant<Place, T> zero;
if (in_x_grad) {
in_x_grad->mutable_data<T>(context.GetPlace());
zero(device_ctx, in_x_grad, static_cast<T>(0));
}
math::Unpool2dMaxGradFunctor<Place, T> unpool2d_max_backward;
unpool2d_max_backward(context.device_context(), *in_x, *in_y, *out,
*out_grad, in_x_grad);
}
};
} // namespace operators
} // namespace paddle
import unittest
import numpy as np
from op_test import OpTest
def unpool2dmax_forward_naive(input, indices, ksize, strides, paddings):
s0, s1, s2, s3 = input.shape
out_hsize = (s2 - 1) * strides[0] - 2 * paddings[0] + ksize[0]
out_wsize = (s2 - 1) * strides[1] - 2 * paddings[1] + ksize[1]
out = np.zeros((s0, s1, out_hsize, out_wsize))
for nidx in xrange(s0):
for cidx in xrange(s1):
for h in xrange(s2):
for w in xrange(s3):
index = indices[nidx, cidx, h, w]
hidx = (index - index % out_wsize) / out_wsize
widx = index % out_wsize
out[nidx, cidx, int(hidx), int(widx)] = \
input[nidx, cidx, h, w]
return out
class TestUnpoolOp(OpTest):
def setUp(self):
self.op_type = "unpool"
self.init_test_case()
pre_input = np.random.random(self.shape).astype("float32")
nsize, csize, hsize, wsize = pre_input.shape
hsize_out = (hsize - self.ksize[0] + 2 * self.paddings[0]) / \
self.strides[0] + 1
wsize_out = (wsize - self.ksize[1] + 2 * self.paddings[1]) / \
self.strides[1] + 1
input = np.zeros((nsize, csize, hsize_out, wsize_out))
indices = np.zeros((nsize, csize, hsize_out, wsize_out))
for i in xrange(hsize_out):
for j in xrange(wsize_out):
r_start = np.max((i * self.strides[0] - self.paddings[0], 0))
r_end = np.min((i * self.strides[0] + self.ksize[0] - \
self.paddings[0], hsize))
c_start = np.max((j * self.strides[1] - self.paddings[1], 0))
c_end = np.min((j * self.strides[1] + self.ksize[1] - \
self.paddings[1], wsize))
for nidx in xrange(nsize):
for cidx in xrange(csize):
x_masked = pre_input[nidx, cidx, r_start:r_end, \
c_start:c_end]
input[nidx, cidx, i, j] = x_masked.max()
arg = x_masked.argmax()
indices[nidx, cidx, i, j] = \
(r_start + arg / self.ksize[1]) * wsize + \
c_start + arg % self.ksize[1]
output = self.unpool2d_forward_naive(input, indices, self.ksize, \
self.strides, self.paddings).astype("float32")
self.inputs = {
'X': input.astype('float32'),
'Indices': indices.astype('int32')
}
self.attrs = {
'strides': self.strides,
'paddings': self.paddings,
'ksize': self.ksize,
'unpooling_type': self.unpooling_type,
}
self.outputs = {'Out': output.astype('float32')}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
def init_test_case(self):
self.unpool2d_forward_naive = unpool2dmax_forward_naive
self.unpooling_type = "max"
self.shape = [6, 4, 5, 5]
self.ksize = [3, 3]
self.strides = [2, 2]
self.paddings = [0, 0]
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册