Commit bd561384 authored by sweetsky0901

format code

Parent d9673cad
@@ -17,15 +17,13 @@ limitations under the License. */
namespace paddle {
namespace operators {
namespace math {
// All tensors are in NCHW format
template <typename T>
class Unpool2dMaxFunctor<platform::CPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices, framework::Tensor* output) {
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
@@ -40,7 +38,7 @@ class Unpool2dMaxFunctor<platform::CPUPlace, T> {
    for (int b = 0; b < batch_size; ++b) {
      for (int c = 0; c < output_channels; ++c) {
        for (int i = 0; i < input_feasize; ++i) {
          int index = indices_data[i];
          PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!");
          output_data[index] = input_data[i];
        }
@@ -51,9 +49,6 @@ class Unpool2dMaxFunctor<platform::CPUPlace, T> {
    }
  }
};
template <class T>
class Unpool2dMaxGradFunctor<platform::CPUPlace, T> {
 public:
@@ -62,7 +57,7 @@ public:
                  const framework::Tensor& indices,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  framework::Tensor* input_grad) {
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
@@ -89,12 +84,10 @@ public:
    }
  }
};
template class Unpool2dMaxGradFunctor<platform::CPUPlace, float>;
template class Unpool2dMaxGradFunctor<platform::CPUPlace, double>;
template class Unpool2dMaxFunctor<platform::CPUPlace, float>;
template class Unpool2dMaxFunctor<platform::CPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
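For readers skimming the diff, here is a minimal standalone sketch of the scatter that the CPU Unpool2dMaxFunctor above performs for a single (batch, channel) slice. This is not part of the commit; all names are illustrative:

```cpp
// Illustrative sketch, not from the Paddle source: max-unpooling scatters
// each input value to the output position recorded by the max-pool indices.
#include <cassert>
#include <vector>

std::vector<float> Unpool2dSlice(const std::vector<float>& input,
                                 const std::vector<int>& indices,
                                 int output_h, int output_w) {
  // Output starts as zeros; only the recorded max positions receive values.
  std::vector<float> output(output_h * output_w, 0.0f);
  for (size_t i = 0; i < input.size(); ++i) {
    assert(indices[i] < output_h * output_w);  // mirrors the PADDLE_ENFORCE
    output[indices[i]] = input[i];             // scatter
  }
  return output;
}
```

The gradient functor is the mirror-image gather, `input_grad[i] = output_grad[indices[i]]`, which the GPU kernels below make explicit.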
@@ -18,36 +18,33 @@ limitations under the License. */
namespace paddle {
namespace operators {
namespace math {
template <typename T>
__global__ void KernelUnpool2dMax(const int nthreads, const T* input_data,
                                  const int* indices_data,
                                  const int input_height,
                                  const int input_width,
                                  const int channels,
                                  T* output_data,
                                  const int output_height,
                                  const int output_width) {
  int in_n_stride = input_height * input_width * channels;
  int in_c_stride = input_height * input_width;
  int out_n_stride = output_height * output_width * channels;
  int out_c_stride = output_height * output_width;
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int offset = blockDim.x * gridDim.x;
  for (int i = index; i < nthreads; i += offset) {
    int bidx = i / in_n_stride;
    int boffset = i % in_n_stride;
    int cidx = boffset / in_c_stride;
    int out_offset = bidx * out_n_stride + cidx * out_c_stride;
    int out_index = indices_data[i];
    PADDLE_ASSERT(out_index < out_c_stride);
    output_data[out_offset + out_index] = input_data[i];
  }
}
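The kernel above uses a grid-stride loop: the thread with global id `index` handles elements `index`, `index + offset`, `index + 2 * offset`, and so on, where `offset = blockDim.x * gridDim.x`, so any launch size covers all `nthreads` elements. A serialized host-side equivalent of one thread's work (an illustrative sketch, not part of this commit) looks like:

```cpp
// Illustrative only: what one CUDA thread (id `tid` of `total_threads`)
// does in KernelUnpool2dMax, written as plain C++.
void UnpoolThreadWork(int tid, int total_threads, int nthreads,
                      const float* input_data, const int* indices_data,
                      int in_n_stride, int in_c_stride,
                      int out_n_stride, int out_c_stride,
                      float* output_data) {
  for (int i = tid; i < nthreads; i += total_threads) {
    int bidx = i / in_n_stride;                  // sample within the batch
    int cidx = (i % in_n_stride) / in_c_stride;  // channel within the sample
    int out_offset = bidx * out_n_stride + cidx * out_c_stride;
    output_data[out_offset + indices_data[i]] = input_data[i];  // scatter
  }
}
```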
template <typename T>
__global__ void KernelUnpool2dMaxGrad(const int nthreads, const T* input_data,
                                      const int* indices_data,
                                      const int input_height,
                                      const int input_width,
@@ -57,32 +54,32 @@ __global__ void KernelUnpool2dMaxGrad(const int nthreads,
                                      const int output_height,
                                      const int output_width,
                                      T* input_grad) {
  int in_n_stride = input_height * input_width * channels;
  int in_c_stride = input_height * input_width;
  int out_n_stride = output_height * output_width * channels;
  int out_c_stride = output_height * output_width;
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int offset = blockDim.x * gridDim.x;
  for (int i = index; i < nthreads; i += offset) {
    int bidx = i / in_n_stride;
    int boffset = i % in_n_stride;
    int cidx = boffset / in_c_stride;
    int out_offset = bidx * out_n_stride + cidx * out_c_stride;
    int out_index = indices_data[i];
    PADDLE_ASSERT(out_index < out_c_stride);
    input_grad[i] = output_grad[out_offset + out_index];
  }
}
/*
 * All tensors are in NCHW format.
 */
template <typename T>
class Unpool2dMaxFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices,
                  framework::Tensor* output) {
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
@@ -93,7 +90,7 @@ class Unpool2dMaxFunctor<platform::GPUPlace, T> {
    const int* indices_data = indices.data<int>();
    T* output_data = output->mutable_data<T>(context.GetPlace());
    int threads = 1024;
    int grid = (input.numel() + threads - 1) / threads;
    KernelUnpool2dMax<
        T><<<grid, threads, 0,
             reinterpret_cast<const platform::CUDADeviceContext&>(context)
@@ -107,13 +104,13 @@ class Unpool2dMaxFunctor<platform::GPUPlace, T> {
 */
template <typename T>
class Unpool2dMaxGradFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  framework::Tensor* input_grad) {
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
@@ -126,24 +123,20 @@ class Unpool2dMaxGradFunctor<platform::GPUPlace, T> {
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
    int threads = 1024;
    int grid = (input.numel() + threads - 1) / threads;
    KernelUnpool2dMaxGrad<
        T><<<grid, threads, 0,
             reinterpret_cast<const platform::CUDADeviceContext&>(context)
                 .stream()>>>(input.numel(), input_data, indices_data,
                              input_height, input_width, output_channels,
                              output_data, output_grad_data,
                              output_height, output_width, input_grad_data);
  }
};
template class Unpool2dMaxGradFunctor<platform::GPUPlace, float>;
template class Unpool2dMaxGradFunctor<platform::GPUPlace, double>;
template class Unpool2dMaxFunctor<platform::GPUPlace, float>;
template class Unpool2dMaxFunctor<platform::GPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
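Both kernel launches use the same configuration: 1024 threads per block and a grid sized by ceiling division over input.numel(), with the grid-stride loop absorbing any mismatch. With hypothetical numbers, numel = 5000 gives grid = (5000 + 1023) / 1024 = 5 blocks, i.e. 5120 threads covering all 5000 elements.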
@@ -22,22 +22,21 @@ namespace math {
template <typename Place, typename T>
class Unpool2dMaxFunctor {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices, framework::Tensor* output);
};
template <typename Place, class T>
class Unpool2dMaxGradFunctor {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  framework::Tensor* input_grad);
};
} // namespace math
} // namespace operators
@@ -21,107 +21,115 @@ class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
  Unpool2dOpMaker(framework::OpProto* proto,
                  framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput(
        "X",
        "(Tensor) The input tensor of unpool operator. "
        "The format of input tensor is NCHW, where N is batch size, C is the "
        "number of channels, and H and W are the height and width of feature.");
    AddInput(
        "Indices",
        "(Tensor) The input tensor of the indices given out by MaxPool2d. "
        "The format of input tensor is NCHW, where N is batch size, C is the "
        "number of channels, and H and W are the height and width of feature.");
    AddOutput(
        "Out",
        "(Tensor) The output tensor of unpool operator. "
        "The format of output tensor is also NCHW, where N is batch size, "
        "C is the number of channels, and H and W are the height and "
        "width of feature.");
    AddAttr<std::vector<int>>(
        "ksize",
        "(vector), the unpooling window size (height, width) "
        "of the unpooling operator.");
    AddAttr<std::vector<int>>(
        "strides",
        "(vector, default: {1, 1}), "
        "strides (height, width) of the unpooling operator.")
        .SetDefault({1, 1});
    AddAttr<std::vector<int>>(
        "paddings",
        "(vector, default: {0, 0}), "
        "paddings (height, width) of the unpooling operator.")
        .SetDefault({0, 0});
    AddAttr<std::string>(
        "unpooling_type",
        "(string), unpooling type, can be \"max\" for max-unpooling.")
        .InEnum({"max"});
    AddComment(R"DOC(
Input shape: $(N, C_{in}, H_{in}, W_{in})$
Output shape: $(N, C_{out}, H_{out}, W_{out})$
Where
$$
H_{out} = (H_{in} - 1) * strides[0] - 2 * paddings[0] + ksize[0] \\
W_{out} = (W_{in} - 1) * strides[1] - 2 * paddings[1] + ksize[1]
$$
Paper: http://www.matthewzeiler.com/wp-content/uploads/2017/07/iccv2011.pdf
)DOC");
  }
};
int OutputSize(int input_size, int ksize, int padding, int stride) {
  int output_size = (input_size - 1) * stride - 2 * padding + ksize;
  return output_size;
}
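As a quick sanity check of OutputSize and the DOC formula above, with illustrative numbers $H_{in} = 6$, stride 2, padding 0, and ksize 2:

$$
H_{out} = (6 - 1) * 2 - 2 * 0 + 2 = 12
$$

That is, unpooling exactly restores the height that a ksize-2, stride-2 max-pool reduces from 12 to 6.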
class UnpoolOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
        ctx.device_context());
  }

 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of UnpoolOp "
                                       "should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Indices"), "Input(Indices) of UnpoolOp "
                                             "should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of UnpoolOp should not be null.");
    auto in_x_dims = ctx->GetInputDim("X");
    auto in_y_dims = ctx->GetInputDim("Indices");
    std::string unpooling_type =
        ctx->Attrs().Get<std::string>("unpooling_type");
    std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize");
    std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
    std::vector<int> paddings =
        ctx->Attrs().Get<std::vector<int>>("paddings");
    PADDLE_ENFORCE(in_x_dims.size() == 4,
                   "Unpooling input must be 4-dimensional.");
    PADDLE_ENFORCE_EQ(in_x_dims, in_y_dims);
    std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]});
    for (size_t i = 0; i < ksize.size(); ++i) {
      output_shape.push_back(
          OutputSize(in_x_dims[i + 2], ksize[i], paddings[i], strides[i]));
    }
    ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
  }
};
class UnpoolOpGrad : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
        ctx.device_context());
  }

 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
                   "Input(X@GRAD) should not be null.");
    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
  }
};
} // namespace operators
} // namespace paddle
@@ -129,10 +137,10 @@ class UnpoolOpGrad : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
REGISTER_OP(unpool, ops::UnpoolOp, ops::Unpool2dOpMaker, unpool_grad,
            ops::UnpoolOpGrad);
REGISTER_OP_CPU_KERNEL(
    unpool, ops::UnpoolKernel<paddle::platform::CPUPlace, float>,
    ops::UnpoolKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
    unpool_grad, ops::UnpoolGradKernel<paddle::platform::CPUPlace, float>,
    ops::UnpoolGradKernel<paddle::platform::CPUPlace, double>);
@@ -27,7 +27,7 @@ class UnpoolKernel : public framework::OpKernel<T> {
  void Compute(const framework::ExecutionContext& context) const override {
    const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
    const framework::Tensor* in_y = context.Input<framework::Tensor>("Indices");
    auto* out = context.Output<framework::Tensor>("Out");
    std::string unpooling_type = context.Attr<std::string>("unpooling_type");
    std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
@@ -52,7 +52,7 @@ class UnpoolGradKernel : public framework::OpKernel<T> {
    const framework::Tensor* out_grad =
        context.Input<framework::Tensor>(framework::GradVarName("Out"));
    framework::Tensor* in_x_grad =
        context.Output<framework::Tensor>(framework::GradVarName("X"));
    std::string unpooling_type = context.Attr<std::string>("unpooling_type");
    std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
@@ -65,8 +65,8 @@ class UnpoolGradKernel : public framework::OpKernel<T> {
      zero(device_ctx, in_x_grad, static_cast<T>(0));
    }
    math::Unpool2dMaxGradFunctor<Place, T> unpool2d_max_backward;
    unpool2d_max_backward(context.device_context(), *in_x, *in_y, *out,
                          *out_grad, in_x_grad);
  }
};
@@ -52,14 +52,16 @@ class TestUnpoolOp(OpTest):
                        c_start + arg % self.ksize[1]
        output = self.unpool2d_forward_naive(input, indices, self.ksize, \
                self.strides, self.paddings).astype("float32")
        self.inputs = {
            'X': input.astype('float32'),
            'Indices': indices.astype('int32')
        }
        self.attrs = {
            'strides': self.strides,
            'paddings': self.paddings,
            'ksize': self.ksize,
            'unpooling_type': self.unpooling_type,
        }
        self.outputs = {'Out': output.astype('float32')}

    def test_check_output(self):
@@ -76,7 +78,5 @@ class TestUnpoolOp(OpTest):
        self.strides = [2, 2]
        self.paddings = [0, 0]

if __name__ == '__main__':
    unittest.main()