提交 bd561384 编写于 作者: S sweetsky0901

format code

上级 d9673cad
......@@ -17,15 +17,13 @@ limitations under the License. */
namespace paddle {
namespace operators {
namespace math {
// All tensors are in NCHW format
template <typename T>
class Unpool2dMaxFunctor<platform::CPUPlace, T> {
public:
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& indices,
framework::Tensor * output) {
const framework::Tensor& indices, framework::Tensor* output) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
......@@ -51,9 +49,6 @@ class Unpool2dMaxFunctor<platform::CPUPlace, T> {
}
}
};
template <class T>
class Unpool2dMaxGradFunctor<platform::CPUPlace, T> {
public:
......@@ -62,7 +57,7 @@ public:
const framework::Tensor& indices,
const framework::Tensor& output,
const framework::Tensor& output_grad,
framework::Tensor * input_grad) {
framework::Tensor* input_grad) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
......@@ -89,12 +84,10 @@ public:
}
}
};
template class Unpool2dMaxGradFunctor<platform::CPUPlace, float>;
template class Unpool2dMaxGradFunctor<platform::CPUPlace, double>;
template class Unpool2dMaxFunctor<platform::CPUPlace, float>;
template class Unpool2dMaxFunctor<platform::CPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -18,11 +18,9 @@ limitations under the License. */
namespace paddle {
namespace operators {
namespace math {
template <typename T>
__global__ void KernelUnpool2dMax(const int nthreads,
const T* input_data,
const int * indices_data,
__global__ void KernelUnpool2dMax(const int nthreads, const T* input_data,
const int* indices_data,
const int input_height,
const int input_width,
const int channels,
......@@ -46,8 +44,7 @@ __global__ void KernelUnpool2dMax(const int nthreads,
}
}
template <typename T>
__global__ void KernelUnpool2dMaxGrad(const int nthreads,
const T* input_data,
__global__ void KernelUnpool2dMaxGrad(const int nthreads, const T* input_data,
const int* indices_data,
const int input_height,
const int input_width,
......@@ -78,11 +75,11 @@ __global__ void KernelUnpool2dMaxGrad(const int nthreads,
*/
template <typename T>
class Unpool2dMaxFunctor<platform::GPUPlace, T> {
public:
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& indices,
framework::Tensor * output) {
framework::Tensor* output) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
......@@ -107,13 +104,13 @@ class Unpool2dMaxFunctor<platform::GPUPlace, T> {
*/
template <typename T>
class Unpool2dMaxGradFunctor<platform::GPUPlace, T> {
public:
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& indices,
const framework::Tensor& output,
const framework::Tensor& output_grad,
framework::Tensor * input_grad) {
framework::Tensor* input_grad) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
......@@ -133,17 +130,13 @@ class Unpool2dMaxGradFunctor<platform::GPUPlace, T> {
.stream()>>>(input.numel(), input_data, indices_data,
input_height, input_width, output_channels,
output_data, output_grad_data,
output_height, output_width,
input_grad_data);
output_height, output_width, input_grad_data);
}
};
template class Unpool2dMaxGradFunctor<platform::GPUPlace, float>;
template class Unpool2dMaxGradFunctor<platform::GPUPlace, double>;
template class Unpool2dMaxFunctor<platform::GPUPlace, float>;
template class Unpool2dMaxFunctor<platform::GPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -22,22 +22,21 @@ namespace math {
template <typename Place, typename T>
class Unpool2dMaxFunctor {
public:
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& indices,
framework::Tensor * output);
const framework::Tensor& indices, framework::Tensor* output);
};
template <typename Place, class T>
class Unpool2dMaxGradFunctor {
public:
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& indices,
const framework::Tensor& output,
const framework::Tensor& output_grad,
framework::Tensor * input_grad);
framework::Tensor* input_grad);
};
} // namespace math
} // namespace operators
......
......@@ -21,32 +21,39 @@ class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
Unpool2dOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
AddInput(
"X",
"(Tensor) The input tensor of unpool operator. "
"The format of input tensor is NCHW. Where N is batch size, C is the "
"number of channels, H and W is the height and width of feature.");
AddInput("Indices",
AddInput(
"Indices",
"(Tensor) The input tensor of the indices given out by MaxPool2d. "
"The format of input tensor is NCHW. Where N is batch size, C is the "
"number of channels, H and W is the height and width of feature.");
AddOutput("Out",
AddOutput(
"Out",
"(Tensor) The output tensor of unpool operator."
"The format of output tensor is also NCHW."
"Where N is batch size, C is "
"the number of channels, H and W is the height and "
"width of feature.");
AddAttr<std::vector<int>>("ksize",
AddAttr<std::vector<int>>(
"ksize",
"(vector), the unpooling window size(height, width) "
"of unpooling operator.");
AddAttr<std::vector<int>>("strides",
AddAttr<std::vector<int>>(
"strides",
"(vector, default:{1, 1}), "
"strides (height, width) of unpooling operator.")
.SetDefault({1, 1});
AddAttr<std::vector<int>>("paddings",
AddAttr<std::vector<int>>(
"paddings",
"(vector defalut:{0,0}), "
"paddings (height, width) of unpooling operator.")
.SetDefault({0, 0});
AddAttr<std::string>("unpooling_type",
AddAttr<std::string>(
"unpooling_type",
"(string), unpooling type, can be \"max\" for max-unpooling ")
.InEnum({"max"});
AddComment(R"DOC(
......@@ -64,12 +71,12 @@ class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
};
int OutputSize(int input_size, int ksize, int padding, int stride) {
int output_size = (input_size -1) * stride - 2 * padding + ksize;
int output_size = (input_size - 1) * stride - 2 * padding + ksize;
return output_size;
}
class UnpoolOp : public framework::OperatorWithKernel {
protected:
protected:
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
......@@ -77,7 +84,7 @@ protected:
ctx.device_context());
}
public:
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of UnpoolOp"
......@@ -92,7 +99,8 @@ public:
ctx->Attrs().Get<std::string>("unpooling_type");
std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
std::vector<int> paddings =
ctx->Attrs().Get<std::vector<int>>("paddings");
PADDLE_ENFORCE(in_x_dims.size() == 4,
"Unpooling intput must be of 4-dimensional.");
PADDLE_ENFORCE_EQ(in_x_dims, in_y_dims);
......@@ -129,10 +137,10 @@ class UnpoolOpGrad : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
REGISTER_OP(unpool, ops::UnpoolOp, ops::Unpool2dOpMaker, unpool_grad,
ops::UnpoolOpGrad);
REGISTER_OP_CPU_KERNEL(unpool,
ops::UnpoolKernel<paddle::platform::CPUPlace, float>,
REGISTER_OP_CPU_KERNEL(
unpool,ops::UnpoolKernel<paddle::platform::CPUPlace, float>,
ops::UnpoolKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(unpool_grad,
ops::UnpoolGradKernel<paddle::platform::CPUPlace, float>,
REGISTER_OP_CPU_KERNEL(
unpool_grad, ops::UnpoolGradKernel<paddle::platform::CPUPlace, float>,
ops::UnpoolGradKernel<paddle::platform::CPUPlace, double>);
......@@ -27,7 +27,7 @@ class UnpoolKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override {
const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
const framework::Tensor* in_y = context.Input<framework::Tensor>("Indices");
auto * out = context.Output<framework::Tensor>("Out");
auto* out = context.Output<framework::Tensor>("Out");
std::string unpooling_type = context.Attr<std::string>("unpooling_type");
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
......@@ -65,8 +65,8 @@ class UnpoolGradKernel : public framework::OpKernel<T> {
zero(device_ctx, in_x_grad, static_cast<T>(0));
}
math::Unpool2dMaxGradFunctor<Place, T> unpool2d_max_backward;
unpool2d_max_backward(context.device_context(), *in_x, *in_y,
*out, *out_grad, in_x_grad);
unpool2d_max_backward(context.device_context(), *in_x, *in_y, *out,
*out_grad, in_x_grad);
}
};
......
......@@ -52,8 +52,10 @@ class TestUnpoolOp(OpTest):
c_start + arg % self.ksize[1]
output = self.unpool2d_forward_naive(input, indices, self.ksize, \
self.strides, self.paddings).astype("float32")
self.inputs = {'X': input.astype('float32'),
'Indices': indices.astype('int32')}
self.inputs = {
'X': input.astype('float32'),
'Indices': indices.astype('int32')
}
self.attrs = {
'strides': self.strides,
'paddings': self.paddings,
......@@ -76,7 +78,5 @@ class TestUnpoolOp(OpTest):
self.strides = [2, 2]
self.paddings = [0, 0]
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册