Commit 200f07c2 authored by S sweetsky0901

add test

Parent ab03daa4
......@@ -20,7 +20,7 @@ namespace math {
// All tensors are in NCHW format
template <typename T>
class Unpool2d_MaxFunctor<platform::CPUPlace, T> {
class Unpool2dMaxFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
......@@ -43,7 +43,7 @@ class Unpool2d_MaxFunctor<platform::CPUPlace, T> {
for (int c = 0; c < output_channels; ++c) {
for (int i = 0; i < input_feasize; ++i) {
int index = indices_data[i];
// PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!");
PADDLE_ENFORCE(index < output_feasize, "invalid index in unpooling!");
output_data[index] = input_data[i];
}
input_data += input_feasize;
......@@ -57,7 +57,7 @@ class Unpool2d_MaxFunctor<platform::CPUPlace, T> {
template <class T>
class Unpool2d_MaxGradFunctor<platform::CPUPlace, T> {
class Unpool2dMaxGradFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
......@@ -83,7 +83,7 @@ public:
for (int c = 0; c < output_channels; ++c) {
for (int i = 0; i < input_feasize; ++i) {
int index = indices_data[i];
// PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!");
PADDLE_ENFORCE(index < output_feasize, "invalid index in unpooling!");
input_grad_data[i] = output_grad_data[index];
}
input_grad_data += input_feasize;
......@@ -94,10 +94,10 @@ public:
}
};
template class Unpool2d_MaxGradFunctor<platform::CPUPlace, float>;
template class Unpool2d_MaxGradFunctor<platform::CPUPlace, double>;
template class Unpool2d_MaxFunctor<platform::CPUPlace, float>;
template class Unpool2d_MaxFunctor<platform::CPUPlace, double>;
template class Unpool2dMaxGradFunctor<platform::CPUPlace, float>;
template class Unpool2dMaxGradFunctor<platform::CPUPlace, double>;
template class Unpool2dMaxFunctor<platform::CPUPlace, float>;
template class Unpool2dMaxFunctor<platform::CPUPlace, double>;
} // namespace math
} // namespace operators
......
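For orientation, the scatter that the CPU forward functor performs on each (batch, channel) plane can be sketched in numpy. This only illustrates the loop above (`output_data[index] = input_data[i]`); it is not code from the patch, and all names here are hypothetical:

import numpy as np

def unpool2d_max_plane(input_plane, indices_plane, out_h, out_w):
    # indices_plane holds flat offsets into this channel's
    # (out_h, out_w) output plane, one per input element.
    out = np.zeros(out_h * out_w, dtype=input_plane.dtype)
    for i, index in enumerate(indices_plane.ravel()):
        assert index < out_h * out_w, "invalid index in unpooling!"
        out[index] = input_plane.ravel()[i]
    return out.reshape(out_h, out_w)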
......@@ -30,12 +30,11 @@ __global__ void KernelUnpool2dMax(const int nthreads,
const int output_width) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
// int output_feasize = output_height * output_width;
for (int i = index; i < nthreads; i += offset) {
int out_offset = i / (input_height * input_width) *
                 output_height * output_width;
int out_index = indices_data[i];
// PADDLE_ENFORCE(out_index < output_feasize, "err index in unpooling!");
PADDLE_ASSERT(out_index < (output_height * output_width));
output_data[out_offset + out_index] = input_data[i];
}
}
......@@ -52,13 +51,11 @@ __global__ void KernelUnpool2dMaxGrad(const int nthreads,
T* input_grad) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
// int output_feasize = output_height * output_width;
for (int i = index; i < nthreads; i += offset) {
int out_offset = i / (input_height * input_width) *
                 output_height * output_width;
int out_index = indices_data[i];
// PADDLE_ENFORCE(out_index < output_feasize,
// "err index in unpooling!");
PADDLE_ASSERT(out_index < (output_height * output_width));
input_grad[i] = output_grad[out_offset + out_index];
}
}
......@@ -66,7 +63,7 @@ __global__ void KernelUnpool2dMaxGrad(const int nthreads,
* All tensors are in NCHW format.
*/
template <typename T>
class Unpool2d_MaxFunctor<platform::GPUPlace, T> {
class Unpool2dMaxFunctor<platform::GPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
......@@ -99,7 +96,7 @@ class Unpool2d_MaxFunctor<platform::GPUPlace, T> {
* All tensors are in NCHW format.
*/
template <typename T>
class Unpool2d_MaxGradFunctor<platform::GPUPlace, T> {
class Unpool2dMaxGradFunctor<platform::GPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
......@@ -135,11 +132,11 @@ class Unpool2d_MaxGradFunctor<platform::GPUPlace, T> {
}
};
template class Unpool2d_MaxGradFunctor<platform::GPUPlace, float>;
template class Unpool2d_MaxGradFunctor<platform::GPUPlace, double>;
template class Unpool2dMaxGradFunctor<platform::GPUPlace, float>;
template class Unpool2dMaxGradFunctor<platform::GPUPlace, double>;
template class Unpool2d_MaxFunctor<platform::GPUPlace, float>;
template class Unpool2d_MaxFunctor<platform::GPUPlace, double>;
template class Unpool2dMaxFunctor<platform::GPUPlace, float>;
template class Unpool2dMaxFunctor<platform::GPUPlace, double>;
} // namespace math
} // namespace operators
......
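The backward pass mirrors the forward scatter as a gather: each input-gradient element reads the output-gradient slot its value was scattered to, matching `input_grad[i] = output_grad[out_offset + out_index]` in the kernel above. A per-plane numpy sketch with hypothetical names:

import numpy as np

def unpool2d_max_grad_plane(output_grad_plane, indices_plane):
    # Gather: one output-gradient value per recorded index.
    flat = output_grad_plane.ravel()
    return flat[indices_plane.ravel()].reshape(indices_plane.shape)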
......@@ -26,7 +26,7 @@ namespace math {
template <typename Place, typename T>
class Unpool2d_MaxFunctor {
class Unpool2dMaxFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
......@@ -35,7 +35,7 @@ class Unpool2d_MaxFunctor {
};
template <typename Place, class T>
class Unpool2d_MaxGradFunctor {
class Unpool2dMaxGradFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
......
......@@ -49,11 +49,15 @@ class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
"paddings(height, width) of unpooling operator.")
.SetDefault({0, 0});
AddAttr<std::string>("unpoolingType",
"(string), unpooling type, can be \"max\" for max-unpooling "
"and \"avg\" for average-unpooling.")
.InEnum({"max", "avg"});
"(string), unpooling type, can be \"max\" for max-unpooling ")
.InEnum({"max"});
AddComment(R"DOC(
"input: the input Tensor to invert"
"indices: the indices given out by MaxPool2d"
"ksize – Size of the max pooling window."
"stride – Stride of the max pooling window."
"It is set to kernel_size by default."
"padding – Padding that was added to the input"
)DOC");
}
};
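To make the `indices` contract concrete: each entry is the flat offset, within one channel's pre-pooling plane, of the element that won a max-pooling window. A hedged numpy sketch of producing such indices (purely illustrative, not the Paddle API):

import numpy as np

def maxpool2d_with_indices(x, ksize=2, stride=2):
    in_h, in_w = x.shape
    out_h = (in_h - ksize) // stride + 1
    out_w = (in_w - ksize) // stride + 1
    pooled = np.zeros((out_h, out_w), dtype=x.dtype)
    indices = np.zeros((out_h, out_w), dtype=np.int32)
    for i in range(out_h):
        for j in range(out_w):
            window = x[i * stride:i * stride + ksize,
                       j * stride:j * stride + ksize]
            r, c = np.unravel_index(window.argmax(), window.shape)
            # Flat offset into the original (in_h, in_w) plane.
            indices[i, j] = (i * stride + r) * in_w + (j * stride + c)
            pooled[i, j] = window[r, c]
    return pooled, indices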
......@@ -82,8 +86,13 @@ class UnpoolOp : public framework::OperatorWithKernel {
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
PADDLE_ENFORCE(in_x_dims.size() == 4 || in_x_dims.size() == 5,
"Unpooling intput should be 4-D or 5-D tensor.");
PADDLE_ENFORCE(in_x_dims.size() == 4,
               "Unpooling input should be 4-D.");
for (int i = 0; i < 4; ++i) {
PADDLE_ENFORCE(in_x_dims[i] == in_y_dims[i],
               "X size must be equal to Y size!");
}
std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]});
for (size_t i = 0; i < ksize.size(); ++i) {
......
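The elided loop presumably appends one spatial dimension per ksize entry using the inverse of the pooling output-size formula. Assuming that (an inference, since the hunk is truncated), each dimension would be computed as in this sketch:

def unpool_out_size(in_size, ksize, stride, padding):
    # Inverse of the pooling size formula: restores the
    # pre-pooling extent of one spatial dimension.
    return (in_size - 1) * stride - 2 * padding + ksize

For example, a 7-wide input with ksize 2, stride 2, and padding 0 maps back to a 14-wide plane.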
......@@ -37,7 +37,7 @@ class UnpoolKernel : public framework::OpKernel<T> {
switch (ksize.size()) {
case 2: {
if (pooling_type == "max") {
math::Unpool2d_MaxFunctor<Place, T> unpool2d_max_forward;
math::Unpool2dMaxFunctor<Place, T> unpool2d_max_forward;
unpool2d_max_forward(context.device_context(), *in_x, *in_y, out);
}
} break;
......@@ -70,7 +70,7 @@ class UnpoolGradKernel : public framework::OpKernel<T> {
switch (ksize.size()) {
case 2: {
if (pooling_type == "max") {
math::Unpool2d_MaxGradFunctor<Place, T> unpool2d_max_backward;
math::Unpool2dMaxGradFunctor<Place, T> unpool2d_max_backward;
unpool2d_max_backward(context.device_context(), *in_x, *in_y, in_x_grad,
*out, *out_grad);
}
......
import unittest
import numpy as np
from op_test import OpTest


def unpool2dmax_forward_naive(input, indices, ksize, strides, paddings):
    # Scatter each input element to the flat position recorded in
    # `indices` inside an all-zero (N, C, out_h, out_w) output.
    s0, s1, s2, s3 = input.shape
    out_h = (s2 - 1) * strides[0] - 2 * paddings[0] + ksize[0]
    out_w = (s3 - 1) * strides[1] - 2 * paddings[1] + ksize[1]
    out = np.zeros((s0, s1, out_h, out_w), dtype=input.dtype)
    for n in range(s0):
        for c in range(s1):
            for i in range(s2):
                for j in range(s3):
                    index = indices[n, c, i, j]
                    out[n, c, index // out_w, index % out_w] = \
                        input[n, c, i, j]
    return out


class TestUnpool2dOp(OpTest):
    def setUp(self):
        self.op_type = "unpool2d"
        self.init_test_case()
        input = np.random.random(self.shape).astype("float32")
        s0, s1, s2, s3 = self.shape
        out_w = (s3 - 1) * self.strides[1] \
            - 2 * self.paddings[1] + self.ksize[1]
        # Deterministic, collision-free indices: the top-left corner
        # of each pooling window, as a flat offset per channel plane.
        plane = (np.arange(s2)[:, None] * self.strides[0] * out_w +
                 np.arange(s3)[None, :] * self.strides[1])
        indices = np.tile(plane, (s0, s1, 1, 1)).astype("int32")
        output = self.unpool2d_forward_naive(
            input, indices, self.ksize, self.strides,
            self.paddings).astype("float32")
        self.inputs = {'X': input, 'Y': indices}
        self.attrs = {
            'strides': self.strides,
            'paddings': self.paddings,
            'ksize': self.ksize,
            'unpoolingType': self.unpooling_type,
        }
        self.outputs = {'Out': output}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')

    def init_test_case(self):
        self.unpool2d_forward_naive = unpool2dmax_forward_naive
        self.unpooling_type = "max"
        self.shape = [2, 4, 7, 8]
        self.ksize = [2, 2]
        self.strides = [2, 2]
        self.paddings = [0, 0]


if __name__ == '__main__':
    unittest.main()