Commit bc45335e authored by: S sweetsky0901

add unpool
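Max-unpooling inverts a max-pool: every pooled value is written back to the position its pooling index recorded, and all other output elements are zero. A minimal standalone sketch of that scatter semantics (plain C++, independent of the Paddle API; the function and variable names are illustrative only):

#include <vector>

// Unpool one (batch, channel) slice: out[indices[i]] = in[i]; positions
// that were not a pooling argmax stay zero.
std::vector<float> UnpoolSlice(const std::vector<float>& in,
                               const std::vector<int>& indices,
                               int out_size) {
  std::vector<float> out(out_size, 0.0f);
  for (size_t i = 0; i < in.size(); ++i) {
    out[indices[i]] = in[i];
  }
  return out;
}

For example, in = {5, 7}, indices = {1, 2}, out_size = 4 yields {0, 5, 7, 0}.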

Parent 6665c492
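paddle/operators/CMakeLists.txt: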
@@ -139,6 +139,7 @@ set(DEPS_OPS
sum_op
pool_op
maxout_op
unpool_op
pool_with_index_op
nccl_op
sequence_conv_op
@@ -151,6 +152,7 @@ op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
op_library(sum_op DEPS net_op selected_rows_functor)
op_library(pool_op DEPS pooling)
op_library(maxout_op DEPS maxouting)
op_library(unpool_op DEPS unpooling)
op_library(pool_with_index_op DEPS pooling)
op_library(lod_rank_table_op SRCS lod_rank_table_op.cc DEPS lod_rank_table)
if(WITH_GPU)
......
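paddle/operators/math/CMakeLists.txt: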
@@ -14,6 +14,7 @@ if(WITH_GPU)
nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context)
nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions)
nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context)
nv_library(unpooling SRCS unpooling.cc unpooling.cu DEPS device_context)
else()
cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator)
cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function)
@@ -26,6 +27,7 @@ else()
cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context)
cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions)
cc_library(maxouting SRCS maxouting.cc DEPS device_context)
cc_library(unpooling SRCS unpooling.cc DEPS device_context)
endif()
cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
......
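paddle/operators/math/unpooling.cc (new file):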
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/maxouting.h"
namespace paddle {
namespace operators {
namespace math {
// All tensors are in NCHW format
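// Forward: scatter each input value to the flat position recorded in
// indices, i.e. output[indices[i]] = input[i] within each (batch, channel)
// slice; the output is assumed to be zero-initialized by the caller.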
template <typename T>
class Unpool2d_MaxFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& indices,
framework::Tensor * output) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output->dims()[1];
const int output_height = output->dims()[2];
const int output_width = output->dims()[3];
int input_feasize = input_height * input_width;
int output_feasize = output_height * output_width;
const T* input_data = input.data<T>();
const T* indices_data = indices.data<T>();
T* output_data = output->mutable_data<T>(context.GetPlace());
for (int b = 0; b < batch_size; ++b) {
for (int c = 0; c < output_channels; ++c) {
for (int i = 0; i < input_feasize; ++i) {
int index = indices_data[i];
PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!");
output_data[index] = input_data[i];
}
input_data += input_feasize;
indices_data += input_feasize;
output_data += output_feasize;
}
}
}
};
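// Backward: gather through the same indices, mirroring the forward scatter:
// input_grad[i] = output_grad[indices[i]] within each (batch, channel) slice.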
template <class T>
class Unpool2d_MaxGradFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& indices,
framework::Tensor * input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output.dims()[1];
const int output_height = output.dims()[2];
const int output_width = output.dims()[3];
int input_feasize = input_height * input_width;
int output_feasize = output_height * output_width;
const T* input_data = input.data<T>();
const T* indices_data = indices.data<T>();
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
for (int b = 0; b < batch_size; ++b) {
for (int c = 0; c < output_channels; ++c) {
for (int i = 0; i < input_feasize; ++i) {
int index = indices_data[i];
PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!");
input_grad_data[i] = output_grad_data[index];
}
input_grad_data += input_feasize;
indices_data += input_feasize;
output_grad_data += output_feasize;
}
}
}
};
template class Unpool2d_MaxGradFunctor<platform::CPUPlace, float>;
template class Unpool2d_MaxGradFunctor<platform::CPUPlace, double>;
template class Unpool2d_MaxFunctor<platform::CPUPlace, float>;
template class Unpool2d_MaxFunctor<platform::CPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
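paddle/operators/math/unpooling.cu (new file):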
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/maxouting.h"
#include "paddle/platform/cuda_helper.h"
namespace paddle {
namespace operators {
namespace math {
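/*
 * Forward kernel: one thread per input element. The grid-stride loop lets
 * any launch configuration cover all nthreads elements; out_offset selects
 * the (batch, channel) slice, and indices_data[i] is the flat position of
 * the pooling argmax inside that slice.
 */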
template <typename T>
__global__ void KernelUnpool2dMax(const int nthreads,
const T* input_data,
const T* indices_data,
const int input_height,
const int input_width,
T* output_data,
const int output_height,
const int output_width) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
int out_offset = i / (input_height * input_width) *
output_height * output_width;
int out_index = indices_data[i];
output_data[out_offset + out_index] = input_data[i];
}
}
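/*
 * Backward kernel: the mirror gather; each thread reads the output gradient
 * at the recorded argmax position and writes it to input_grad[i].
 */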
template <typename T>
__global__ void KernelUnpool2dMaxGrad(const int nthreads,
const T* input_data,
const T* indices_data,
const int input_height,
const int input_width,
const T* output_data,
const T* output_grad,
const int output_height,
const int output_width,
T* input_grad) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
int out_offset = i / (input_height * input_width) *
output_height * output_width;
int out_index = indices_data[i];
input_grad[i] = output_grad[out_offset + out_index];
}
}
/*
* All tensors are in NCHW format.
*/
template <typename T>
class Unpool2d_MaxFunctor<platform::GPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& indices,
framework::Tensor * output) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output->dims()[1];
const int output_height = output->dims()[2];
const int output_width = output->dims()[3];
int input_feasize = input_height * input_width;
int output_feasize = output_height * output_width;
const T* input_data = input.data<T>();
const T* indices_data = indices.data<T>();
T* output_data = output->mutable_data<T>(context.GetPlace());
int nthreads = input.numel();
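// Launch ceil(nthreads / 1024) blocks of 1024 threads; the grid-stride
// loop inside the kernel absorbs any excess threads.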
int blocks = (nthreads + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
KernelUnpool2dMax<
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(nthreads, input_data, indices_data,
input_height, input_width,
output_data, output_height, output_width);
}
};
/*
* All tensors are in NCHW format.
*/
template <typename T>
class Unpool2d_MaxGradFunctor<platform::GPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& indices,
framework::Tensor* input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output.dims()[1];
const int output_height = output.dims()[2];
const int output_width = output.dims()[3];
const T* input_data = input.data<T>();
const T* indices_data = indices.data<T>();
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
int nthreads = input.numel();
int blocks = (nthreads + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
KernelUnpool2dMaxGrad<
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
nthreads, input_data, indices_data,
input_height, input_width,
output_data, output_grad_data,
output_height, output_width,
input_grad_data);
}
};
template class Unpool2d_MaxGradFunctor<platform::GPUPlace, float>;
template class Unpool2d_MaxGradFunctor<platform::GPUPlace, double>;
template class Unpool2d_MaxFunctor<platform::GPUPlace, float>;
template class Unpool2d_MaxFunctor<platform::GPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
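paddle/operators/math/unpooling.h (new file):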
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/hostdevice.h"
namespace paddle {
namespace operators {
namespace math {
#define FLT_MAX __FLT_MAX__
template <typename Place, typename T>
class Unpool2d_MaxFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& indices,
framework::Tensor* output);
};
template <typename Place, typename T>
class Unpool2d_MaxGradFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& indices,
framework::Tensor* input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad);
};
} // namespace math
} // namespace operators
} // namespace paddle
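paddle/operators/unpool_op.cc (new file):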
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/operators/unpool_op.h"
namespace paddle {
namespace operators {
using framework::Tensor;
class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Unpool2dOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(Tensor) The input tensor of unpool operator. "
"The format of input tensor is NCHW. Where N is batch size, C is the "
"number of channels, H and W is the height and width of feature.");
AddInput("Y",
"(Tensor) The input tensor of the indices given out by MaxPool2d. "
"The format of input tensor is NCHW. Where N is batch size, C is the "
"number of channels, H and W is the height and width of feature.");
AddOutput("Out",
"(Tensor) The output tensor of unpool operator."
"The format of output tensor is also NCHW."
"Where N is batch size, C is "
"the number of channels, H and W is the height and "
"width of feature.");
AddAttr<std::vector<int>>("ksize",
"(vector ), the unpooling window size(height, width) "
"of unpooling operator.");
AddAttr<std::vector<int>>("strides", "(vector, default:{1, 1}), "
"strides(height, width) of unpooling operator.")
.SetDefault({1, 1});
AddAttr<std::vector<int>>("paddings", "(vector defalut:{0,0}), "
"paddings(height, width) of unpooling operator.")
.SetDefault({0, 0});
AddAttr<std::string>("unpoolingType",
"(string), unpooling type, can be \"max\" for max-unpooling "
"and \"avg\" for average-unpooling.")
.InEnum({"max", "avg"});
AddComment(R"DOC(
)DOC");
}
};
int OutputSize(int input_size, int ksize, int padding, int stride) {
int output_size = (input_size - 1) * stride - 2 * padding + ksize;
return output_size;
}
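// For example, an input size of 4 with ksize = 2, stride = 2, padding = 0
// gives (4 - 1) * 2 - 2 * 0 + 2 = 8, inverting a 2x2, stride-2 max pool.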
class UnpoolOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of UnpoolOp"
"should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) of UnpoolOp"
"should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of UnpoolOp should not be null.");
auto in_x_dims = ctx->GetInputDim("X");
auto in_y_dims = ctx->GetInputDim("Y");
std::string unpooling_type = ctx->Attrs().Get<std::string>("unpooling_type");
std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
PADDLE_ENFORCE(in_x_dims.size() == 4,
"Unpooling input should be a 4-D tensor in NCHW format.");
PADDLE_ENFORCE_EQ(in_x_dims, in_y_dims,
"Input(X) and Input(Y) must have the same shape.");
std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]});
for (size_t i = 0; i < ksize.size(); ++i) {
output_shape.push_back(
OutputSize(in_x_dims[i + 2], ksize[i], paddings[i], strides[i]));
}
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
}
};
class UnpoolOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(X) must not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Output(X@GRAD) should not be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(unpool2d, ops::UnpoolOp, ops::Unpool2dOpMaker, unpool2d_grad,
ops::UnpoolOpGrad);
REGISTER_OP_CPU_KERNEL(unpool2d,
ops::UnpoolKernel<paddle::platform::CPUPlace, float>,
ops::UnpoolKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(unpool2d_grad,
ops::UnpoolGradKernel<paddle::platform::CPUPlace, float>,
ops::UnpoolGradKernel<paddle::platform::CPUPlace, double>);
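paddle/operators/unpool_op.cu (new file):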
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/unpool_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(unpool2d,
ops::UnpoolKernel<paddle::platform::GPUPlace, float>,
ops::UnpoolKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(unpool2d_grad,
ops::UnpoolGradKernel<paddle::platform::GPUPlace, float>,
ops::UnpoolGradKernel<paddle::platform::GPUPlace, double>);
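paddle/operators/unpool_op.h (new file):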
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/unpooling.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename Place, typename T>
class UnpoolKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* in_x = context.Input<Tensor>("X");
const Tensor* in_y = context.Input<Tensor>("Y");
Tensor* out = context.Output<Tensor>("Out");
std::string unpooling_type = context.Attr<std::string>("unpooling_type");
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
// The functor only writes the argmax positions, so zero the output first.
T* output_data = out->mutable_data<T>(context.GetPlace());
if (output_data) {
math::SetConstant<Place, T> set_zero;
set_zero(context.device_context(), out, static_cast<T>(0));
}
switch (ksize.size()) {
case 2: {
if (unpooling_type == "max") {
math::Unpool2d_MaxFunctor<Place, T> unpool2d_max_forward;
unpool2d_max_forward(context.device_context(), *in_x, *in_y, out);
}
} break;
default: { PADDLE_THROW("Unpool op only supports 2D input."); }
}
}
};
template <typename Place, typename T>
class UnpoolGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* in_x = context.Input<Tensor>("X");
const Tensor* in_y = context.Input<Tensor>("Y");
const Tensor* out = context.Input<Tensor>("Out");
const Tensor* out_grad =
context.Input<Tensor>(framework::GradVarName("Out"));
Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
std::string unpooling_type = context.Attr<std::string>("unpooling_type");
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
auto& device_ctx = context.device_context();
math::SetConstant<Place, T> zero;
if (in_x_grad) {
in_x_grad->mutable_data<T>(context.GetPlace());
zero(device_ctx, in_x_grad, static_cast<T>(0.0));
}
switch (ksize.size()) {
case 2: {
if (pooling_type == "max") {
math::UnpoolGradFunctor<Place, T> maxout_backward;
maxout_backward(context.device_context(), *in_x, *in_y, in_x_grad, *out,
*out_grad, ksize, strides, paddings);
}
} break;
default: { PADDLE_THROW("Pool op only supports 2D input."); }
}
}
};
} // namespace operators
} // namespace paddle