提交 45a8c9dd 编写于 作者: S sweetsky0901

add unpool2d make ok

上级 f638f910
...@@ -80,6 +80,13 @@ function(op_library TARGET) ...@@ -80,6 +80,13 @@ function(op_library TARGET)
file(APPEND ${pybind_file} "USE_OP(pool2d);\n") file(APPEND ${pybind_file} "USE_OP(pool2d);\n")
endif() endif()
# unpool_op contains several operators, so it is special-cased here:
# when TARGET is "unpool_op", set pybind_flag and append one explicit
# USE_OP line to ${pybind_file}.
# NOTE(review): pybind_flag is presumably read later in op_library to skip
# the default registration; that code is not visible here -- confirm.
if ("${TARGET}" STREQUAL "unpool_op")
set(pybind_flag 1)
# Adding just one of its operators to pybind is enough
file(APPEND ${pybind_file} "USE_OP(unpool2d);\n")
endif()
# pool_cudnn_op contains several operators # pool_cudnn_op contains several operators
if ("${TARGET}" STREQUAL "pool_cudnn_op") if ("${TARGET}" STREQUAL "pool_cudnn_op")
set(pybind_flag 1) set(pybind_flag 1)
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/math/maxouting.h" #include "paddle/operators/math/unpooling.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -20,7 +20,7 @@ namespace math { ...@@ -20,7 +20,7 @@ namespace math {
// All tensors are in NCHW format // All tensors are in NCHW format
template <typename T> template <typename T>
class Unpool2d_Max_Functor<platform::CPUPlace, T> { class Unpool2d_MaxFunctor<platform::CPUPlace, T> {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, const framework::Tensor& input,
...@@ -36,16 +36,14 @@ class Unpool2d_Max_Functor<platform::CPUPlace, T> { ...@@ -36,16 +36,14 @@ class Unpool2d_Max_Functor<platform::CPUPlace, T> {
int input_feasize = input_height * input_width; int input_feasize = input_height * input_width;
int output_feasize = output_height * output_width; int output_feasize = output_height * output_width;
const T* input_data = input.data<T>(); const T* input_data = input.data<T>();
const T* indices_data = indices.data<T>(); const int * indices_data = indices.data<int>();
T* output_data = output->mutable_data<T>(context.GetPlace()); T* output_data = output->mutable_data<T>(context.GetPlace());
for (int b = 0; b < batch_size; ++b) { for (int b = 0; b < batch_size; ++b) {
for (int c = 0; c < output_channels; ++c) { for (int c = 0; c < output_channels; ++c) {
for (int i = 0; i < input_feasize; ++i) { for (int i = 0; i < input_feasize; ++i) {
int index = indices_data[i]; int index = indices_data[i];
if(index > output_feasize) { // PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!");
//抛一个异常!
}
output_data[index] = input_data[i]; output_data[index] = input_data[i];
} }
input_data += input_feasize; input_data += input_feasize;
...@@ -70,26 +68,22 @@ public: ...@@ -70,26 +68,22 @@ public:
const int batch_size = input.dims()[0]; const int batch_size = input.dims()[0];
const int input_height = input.dims()[2]; const int input_height = input.dims()[2];
const int input_width = input.dims()[3]; const int input_width = input.dims()[3];
const int output_channels = output->dims()[1]; const int output_channels = output.dims()[1];
const int output_height = output->dims()[2]; const int output_height = output.dims()[2];
const int output_width = output->dims()[3]; const int output_width = output.dims()[3];
int input_feasize = input_height * input_width; int input_feasize = input_height * input_width;
int output_feasize = output_height * output_width; int output_feasize = output_height * output_width;
const T* input_data = input.data<T>(); const int* indices_data = indices.data<int>();
const T* indices_data = indices.data<T>();
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>(); const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace()); T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
for (int b = 0; b < batch_size; ++b) { for (int b = 0; b < batch_size; ++b) {
for (int c = 0; c < output_channels; ++c) { for (int c = 0; c < output_channels; ++c) {
for (int f = 0; f < input_feasize; ++f) { for (int i = 0; i < input_feasize; ++i) {
int index = indices_data[i]; int index = indices_data[i];
if(index > output_feasize) { // PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!");
//抛一个异常!
}
input_grad_data[i] = output_grad_data[index]; input_grad_data[i] = output_grad_data[index];
} }
input_grad_data += input_feasize; input_grad_data += input_feasize;
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/math/maxouting.h" #include "paddle/operators/math/unpooling.h"
#include "paddle/platform/cuda_helper.h" #include "paddle/platform/cuda_helper.h"
namespace paddle { namespace paddle {
...@@ -22,7 +22,7 @@ namespace math { ...@@ -22,7 +22,7 @@ namespace math {
template <typename T> template <typename T>
__global__ void KernelUnpool2dMax(const int nthreads, __global__ void KernelUnpool2dMax(const int nthreads,
const T* input_data, const T* input_data,
const T* indices_data, const int* indices_data,
const int input_height, const int input_height,
const int input_width, const int input_width,
T* output_data, T* output_data,
...@@ -30,16 +30,19 @@ __global__ void KernelUnpool2dMax(const int nthreads, ...@@ -30,16 +30,19 @@ __global__ void KernelUnpool2dMax(const int nthreads,
const int output_width) { const int output_width) {
int index = blockIdx.x * blockDim.x + threadIdx.x; int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x; int offset = blockDim.x * gridDim.x;
// int output_feasize = output_height * output_width;
for (int i = index; i < nthreads; i += offset) { for (int i = index; i < nthreads; i += offset) {
int out_offset = i / (input_height * input_width) \ int out_offset = i / (input_height * input_width) \
* output_height * output_width; * output_height * output_width;
int out_index = indices_data[i]; int out_index = indices_data[i];
// PADDLE_ENFORCE(out_index < output_feasize, "err index in unpooling!");
output_data[out_offset + out_index] = input_data[i]; output_data[out_offset + out_index] = input_data[i];
} }
} }
template <typename T> template <typename T>
__global__ void KernelUnpool2dMaxGrad(const int nthreads, __global__ void KernelUnpool2dMaxGrad(const int nthreads,
const T* input_data, const T* input_data,
const int* indices_data,
const int input_height, const int input_height,
const int input_width, const int input_width,
const T* output_data, const T* output_data,
...@@ -49,10 +52,13 @@ __global__ void KernelUnpool2dMaxGrad(const int nthreads, ...@@ -49,10 +52,13 @@ __global__ void KernelUnpool2dMaxGrad(const int nthreads,
T* input_grad) { T* input_grad) {
int index = blockIdx.x * blockDim.x + threadIdx.x; int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x; int offset = blockDim.x * gridDim.x;
// int output_feasize = output_height * output_width;
for (int i = index; i < nthreads; i += offset) { for (int i = index; i < nthreads; i += offset) {
int out_offset = i / (input_height * input_width) \ int out_offset = i / (input_height * input_width) \
* output_height * output_width; * output_height * output_width;
int out_index = indices_data[i]; int out_index = indices_data[i];
// PADDLE_ENFORCE(out_index < output_feasize,
// "err index in unpooling!");
input_grad[i] = output_grad[out_offset + out_index]; input_grad[i] = output_grad[out_offset + out_index];
} }
} }
...@@ -72,10 +78,8 @@ class Unpool2d_MaxFunctor<platform::GPUPlace, T> { ...@@ -72,10 +78,8 @@ class Unpool2d_MaxFunctor<platform::GPUPlace, T> {
const int output_channels = output->dims()[1]; const int output_channels = output->dims()[1];
const int output_height = output->dims()[2]; const int output_height = output->dims()[2];
const int output_width = output->dims()[3]; const int output_width = output->dims()[3];
int input_feasize = input_height * input_width;
int output_feasize = output_height * output_width;
const T* input_data = input.data<T>(); const T* input_data = input.data<T>();
const T* indices_data = indices.data<T>(); const int* indices_data = indices.data<int>();
T* output_data = output->mutable_data<T>(context.GetPlace()); T* output_data = output->mutable_data<T>(context.GetPlace());
int nthreads = output->numel(); int nthreads = output->numel();
...@@ -99,19 +103,18 @@ class Unpool2d_MaxGradFunctor<platform::GPUPlace, T> { ...@@ -99,19 +103,18 @@ class Unpool2d_MaxGradFunctor<platform::GPUPlace, T> {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, const framework::Tensor& input,
const framework::Tensor& indices,
framework::Tensor * input_grad, framework::Tensor * input_grad,
const framework::Tensor& output, const framework::Tensor& output,
const framework::Tensor& output_grad, const framework::Tensor& output_grad) {
int groups) {
const int batch_size = input.dims()[0]; const int batch_size = input.dims()[0];
const int input_height = input.dims()[2]; const int input_height = input.dims()[2];
const int input_width = input.dims()[3]; const int input_width = input.dims()[3];
const int output_channels = output.dims()[1]; const int output_channels = output.dims()[1];
const int output_height = output.dims()[2]; const int output_height = output.dims()[2];
const int output_width = output.dims()[3]; const int output_width = output.dims()[3];
const T* input_data = input.data<T>(); const T* input_data = input.data<T>();
const T* indices_data = indices.data<T>(); const int* indices_data = indices.data<int>();
const T* output_data = output.data<T>(); const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>(); const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace()); T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
......
...@@ -26,7 +26,7 @@ namespace math { ...@@ -26,7 +26,7 @@ namespace math {
template <typename Place, typename T> template <typename Place, typename T>
class Unpool2d_Max_Functor { class Unpool2d_MaxFunctor {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, const framework::Tensor& input,
...@@ -35,10 +35,11 @@ class Unpool2d_Max_Functor { ...@@ -35,10 +35,11 @@ class Unpool2d_Max_Functor {
}; };
template <typename Place, class T> template <typename Place, class T>
class Unpool2d_Max_GradFunctor { class Unpool2d_MaxGradFunctor {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, const framework::Tensor& input,
const framework::Tensor& indices,
framework::Tensor * input_grad, framework::Tensor * input_grad,
const framework::Tensor& output, const framework::Tensor& output,
const framework::Tensor& output_grad); const framework::Tensor& output_grad);
......
...@@ -20,7 +20,8 @@ using framework::Tensor; ...@@ -20,7 +20,8 @@ using framework::Tensor;
class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker { class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
UnpoolOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) Unpool2dOpMaker(framework::OpProto* proto, \
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(Tensor) The input tensor of unpool operator. " "(Tensor) The input tensor of unpool operator. "
...@@ -39,10 +40,12 @@ class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -39,10 +40,12 @@ class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<std::vector<int>>("ksize", AddAttr<std::vector<int>>("ksize",
"(vector ), the unpooling window size(height, width) " "(vector ), the unpooling window size(height, width) "
"of unpooling operator."); "of unpooling operator.");
AddAttr<std::vector<int>>("strides", "(vector, default:{1, 1}), " AddAttr<std::vector<int>>("strides",
"(vector, default:{1, 1}), "
"strides(height, width) of unpooling operator.") "strides(height, width) of unpooling operator.")
.SetDefault({1, 1}); .SetDefault({1, 1});
AddAttr<std::vector<int>>("paddings", "(vector defalut:{0,0}), " AddAttr<std::vector<int>>("paddings",
"(vector defalut:{0,0}), "
"paddings(height, width) of unpooling operator.") "paddings(height, width) of unpooling operator.")
.SetDefault({0, 0}); .SetDefault({0, 0});
AddAttr<std::string>("unpoolingType", AddAttr<std::string>("unpoolingType",
...@@ -73,7 +76,8 @@ class UnpoolOp : public framework::OperatorWithKernel { ...@@ -73,7 +76,8 @@ class UnpoolOp : public framework::OperatorWithKernel {
auto in_x_dims = ctx->GetInputDim("X"); auto in_x_dims = ctx->GetInputDim("X");
auto in_y_dims = ctx->GetInputDim("Y"); auto in_y_dims = ctx->GetInputDim("Y");
std::string unpooling_type = ctx->Attrs().Get<std::string>("unpooling_type"); std::string unpooling_type = \
ctx->Attrs().Get<std::string>("unpooling_type");
std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize"); std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides"); std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings"); std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
...@@ -95,7 +99,7 @@ class UnpoolOpGrad : public framework::OperatorWithKernel { ...@@ -95,7 +99,7 @@ class UnpoolOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(X) must not be null."); PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) must not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null"); "Input(Out@GRAD) should not be null");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
...@@ -109,8 +113,11 @@ class UnpoolOpGrad : public framework::OperatorWithKernel { ...@@ -109,8 +113,11 @@ class UnpoolOpGrad : public framework::OperatorWithKernel {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(unpool2d, ops::UnpoolOp, ops::Unpool2dOpMaker, unpool2d_grad, REGISTER_OP(unpool2d, ops::UnpoolOp, ops::Unpool2dOpMaker, unpool2d_grad,
ops::UnpoolOpGrad); ops::UnpoolOpGrad);
REGISTER_OP_CPU_KERNEL(unpool2d, ops::UnpoolKernel<paddle::platform::CPUPlace, REGISTER_OP_CPU_KERNEL(unpool2d,
float>); ops::UnpoolKernel<paddle::platform::CPUPlace, float>,
ops::UnpoolKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(unpool2d_grad, REGISTER_OP_CPU_KERNEL(unpool2d_grad,
ops::UnpoolGradKernel<paddle::platform::CPUPlace, ops::UnpoolGradKernel<paddle::platform::CPUPlace,
float>); float>,
ops::UnpoolGradKernel<paddle::platform::CPUPlace,
double>);
...@@ -16,7 +16,10 @@ ...@@ -16,7 +16,10 @@
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(unpool2d, REGISTER_OP_GPU_KERNEL(unpool2d,
ops::UnpoolKernel<paddle::platform::GPUPlace, float>); ops::UnpoolKernel<paddle::platform::GPUPlace, float>,
ops::UnpoolKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(unpool2d_grad, REGISTER_OP_GPU_KERNEL(unpool2d_grad,
ops::UnpoolGradKernel<paddle::platform::GPUPlace, ops::UnpoolGradKernel<paddle::platform::GPUPlace,
float>); float>,
ops::UnpoolGradKernel<paddle::platform::GPUPlace,
double>);
...@@ -37,9 +37,8 @@ class UnpoolKernel : public framework::OpKernel<T> { ...@@ -37,9 +37,8 @@ class UnpoolKernel : public framework::OpKernel<T> {
switch (ksize.size()) { switch (ksize.size()) {
case 2: { case 2: {
if (pooling_type == "max") { if (pooling_type == "max") {
math::Unpool2d_Max_Functor<Place, T> unpool2d_max_forward; math::Unpool2d_MaxFunctor<Place, T> unpool2d_max_forward;
unpool2d_max_forward(context.device_context(), *in_x, *in_y, unpool2d_max_forward(context.device_context(), *in_x, *in_y, out);
ksize, strides, paddings, out);
} }
} break; } break;
default: { PADDLE_THROW("Pool op only supports 2D input."); } default: { PADDLE_THROW("Pool op only supports 2D input."); }
...@@ -71,12 +70,12 @@ class UnpoolGradKernel : public framework::OpKernel<T> { ...@@ -71,12 +70,12 @@ class UnpoolGradKernel : public framework::OpKernel<T> {
switch (ksize.size()) { switch (ksize.size()) {
case 2: { case 2: {
if (pooling_type == "max") { if (pooling_type == "max") {
math::UnpoolGradFunctor<Place, T> maxout_backward; math::Unpool2d_MaxGradFunctor<Place, T> unpool2d_max_backward;
maxout_backward(context.device_context(), *in_x, *in_y, in_x_grad, *out, unpool2d_max_backward(context.device_context(), *in_x, *in_y, in_x_grad,
*out_grad, ksize, strides, paddings); *out, *out_grad);
} }
} break; } break;
default: { PADDLE_THROW("Pool op only supports 2D input."); } default: { PADDLE_THROW("Unpool op only supports 2D input."); }
} }
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册