Commit c52ed8de, authored by sweetsky0901

format code

Parent: bd561384
@@ -13,17 +13,15 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/operators/math/unpooling.h"
 namespace paddle {
 namespace operators {
 namespace math {
-// All tensors are in NCHW format
 template <typename T>
 class Unpool2dMaxFunctor<platform::CPUPlace, T> {
  public:
-  void operator()(const platform::DeviceContext& context,
-                  const framework::Tensor& input,
-                  const framework::Tensor& indices, framework::Tensor* output) {
+  void operator()(
+      const platform::DeviceContext& context, const framework::Tensor& input,
+      const framework::Tensor& indices, framework::Tensor* output) {
     const int batch_size = input.dims()[0];
     const int input_height = input.dims()[2];
     const int input_width = input.dims()[3];
@@ -51,13 +49,11 @@ public:
 };
 template <class T>
 class Unpool2dMaxGradFunctor<platform::CPUPlace, T> {
  public:
-  void operator()(const platform::DeviceContext& context,
-                  const framework::Tensor& input,
-                  const framework::Tensor& indices,
-                  const framework::Tensor& output,
-                  const framework::Tensor& output_grad,
-                  framework::Tensor* input_grad) {
+  void operator()(
+      const platform::DeviceContext& context, const framework::Tensor& input,
+      const framework::Tensor& indices, const framework::Tensor& output,
+      const framework::Tensor& output_grad, framework::Tensor* input_grad) {
     const int batch_size = input.dims()[0];
     const int input_height = input.dims()[2];
     const int input_width = input.dims()[3];
......
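The functor bodies are collapsed in this view. Conceptually, max-unpool forward scatters each pooled value to the flat offset that MaxPool recorded in the index tensor, one H*W plane at a time, in NCHW order. A minimal sketch of that inner step (illustrative only, not the exact elided body; the plane sizes input_feasize/output_feasize are names introduced here for the sketch):

// Hedged sketch of the forward scatter for one (batch, channel) plane.
// input_feasize  = input_height * input_width   (assumed helper name)
// output_feasize = output_height * output_width (assumed helper name)
void Unpool2dMaxPlaneSketch(const float* input, const int* indices,
                            int input_feasize, int output_feasize,
                            float* output /* assumed pre-zeroed */) {
  for (int i = 0; i < input_feasize; ++i) {
    int idx = indices[i];  // offset inside this channel's output plane
    // The real functor would presumably enforce idx < output_feasize.
    output[idx] = input[i];
  }
}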
@@ -19,14 +19,10 @@ namespace paddle {
 namespace operators {
 namespace math {
 template <typename T>
-__global__ void KernelUnpool2dMax(const int nthreads, const T* input_data,
-                                  const int* indices_data,
-                                  const int input_height,
-                                  const int input_width,
-                                  const int channels,
-                                  T* output_data,
-                                  const int output_height,
-                                  const int output_width) {
+__global__ void KernelUnpool2dMax(
+    const int nthreads, const T* input_data, const int* indices_data,
+    const int input_height, const int input_width, const int channels,
+    T* output_data, const int output_height, const int output_width) {
   int in_n_stride = input_height * input_width * channels;
   int in_c_stride = input_height * input_width;
   int out_n_stride = output_height * output_width * channels;
@@ -44,16 +40,11 @@ __global__ void KernelUnpool2dMax(const int nthreads, const T* input_data,
   }
 }
 template <typename T>
-__global__ void KernelUnpool2dMaxGrad(const int nthreads, const T* input_data,
-                                      const int* indices_data,
-                                      const int input_height,
-                                      const int input_width,
-                                      const int channels,
-                                      const T* output_data,
-                                      const T* output_grad,
-                                      const int output_height,
-                                      const int output_width,
-                                      T* input_grad) {
+__global__ void KernelUnpool2dMaxGrad(
+    const int nthreads, const T* input_data, const int* indices_data,
+    const int input_height, const int input_width, const int channels,
+    const T* output_data, const T* output_grad, const int output_height,
+    const int output_width, T* input_grad) {
   int in_n_stride = input_height * input_width * channels;
   int in_c_stride = input_height * input_width;
   int out_n_stride = output_height * output_width * channels;
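Both kernel bodies are elided here, but the stride variables above imply the usual one-thread-per-input-element pattern: decompose the linear index into (batch, channel), then scatter to the offset MaxPool recorded. A hedged CUDA sketch, launched like the call sites below with threads = 1024 and grid = ceil(numel / threads); out_c_stride is assumed to be output_height * output_width, defined just past the lines shown:

// Hedged sketch of the per-thread forward body, built only from the
// strides declared above (illustrative, not the exact elided code).
__global__ void KernelUnpool2dMaxSketch(
    const int nthreads, const float* input_data, const int* indices_data,
    const int in_n_stride, const int in_c_stride, const int out_n_stride,
    const int out_c_stride, float* output_data) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index >= nthreads) return;
  int bidx = index / in_n_stride;                  // which batch
  int cidx = (index % in_n_stride) / in_c_stride;  // which channel
  int out_offset = bidx * out_n_stride + cidx * out_c_stride;
  // indices_data[index] is the offset inside this channel's output plane.
  output_data[out_offset + indices_data[index]] = input_data[index];
}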
@@ -75,11 +66,10 @@ __global__ void KernelUnpool2dMaxGrad(const int nthreads, const T* input_data,
 */
 template <typename T>
 class Unpool2dMaxFunctor<platform::GPUPlace, T> {
  public:
-  void operator()(const platform::DeviceContext& context,
-                  const framework::Tensor& input,
-                  const framework::Tensor& indices,
-                  framework::Tensor* output) {
+  void operator()(
+      const platform::DeviceContext& context, const framework::Tensor& input,
+      const framework::Tensor& indices, framework::Tensor* output) {
     const int batch_size = input.dims()[0];
     const int input_height = input.dims()[2];
     const int input_width = input.dims()[3];
@@ -91,12 +81,11 @@ public:
     T* output_data = output->mutable_data<T>(context.GetPlace());
     int threads = 1024;
     int grid = (input.numel() + threads - 1) / threads;
-    KernelUnpool2dMax<
-        T><<<grid, threads, 0,
+    KernelUnpool2dMax<T><<<grid, threads, 0,
              reinterpret_cast<const platform::CUDADeviceContext&>(context)
                  .stream()>>>(input.numel(), input_data, indices_data,
                               input_height, input_width, output_channels,
                               output_data, output_height, output_width);
   }
 };
 /*
@@ -104,7 +93,7 @@ public:
 */
 template <typename T>
 class Unpool2dMaxGradFunctor<platform::GPUPlace, T> {
-public:
+ public:
   void operator()(const platform::DeviceContext& context,
                   const framework::Tensor& input,
                   const framework::Tensor& indices,
@@ -124,13 +113,11 @@ public:
     T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
     int threads = 1024;
     int grid = (input.numel() + threads - 1) / threads;
-    KernelUnpool2dMaxGrad<
-        T><<<grid, threads, 0,
-             reinterpret_cast<const platform::CUDADeviceContext&>(context)
-                 .stream()>>>(input.numel(), input_data, indices_data,
-                              input_height, input_width, output_channels,
-                              output_data, output_grad_data,
-                              output_height, output_width, input_grad_data);
+    KernelUnpool2dMaxGrad<T><<<grid, threads, 0,
+        reinterpret_cast<const platform::CUDADeviceContext&>(context)
+            .stream()>>>(input.numel(), input_data, indices_data,
+                         input_height, input_width, output_channels, output_data,
+                         output_grad_data, output_height, output_width, input_grad_data);
   }
 };
 template class Unpool2dMaxGradFunctor<platform::GPUPlace, float>;
......
@@ -18,25 +18,20 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 namespace math {
 template <typename Place, typename T>
 class Unpool2dMaxFunctor {
  public:
-  void operator()(const platform::DeviceContext& context,
-                  const framework::Tensor& input,
-                  const framework::Tensor& indices, framework::Tensor* output);
+  void operator()(
+      const platform::DeviceContext& context, const framework::Tensor& input,
+      const framework::Tensor& indices, framework::Tensor* output);
 };
 template <typename Place, class T>
 class Unpool2dMaxGradFunctor {
  public:
-  void operator()(const platform::DeviceContext& context,
-                  const framework::Tensor& input,
-                  const framework::Tensor& indices,
-                  const framework::Tensor& output,
-                  const framework::Tensor& output_grad,
-                  framework::Tensor* input_grad);
+  void operator()(
+      const platform::DeviceContext& context, const framework::Tensor& input,
+      const framework::Tensor& indices, const framework::Tensor& output,
+      const framework::Tensor& output_grad, framework::Tensor* input_grad);
 };
 }  // namespace math
 }  // namespace operators
......
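These two declarations are all the op kernels need. The call pattern, visible in the unpool_op.h hunks at the bottom of this diff, is to default-construct the functor for the target Place and invoke it:

// Usage as it appears in UnpoolKernel below: in_x is the pooled input,
// in_y the MaxPool index tensor, out the unpooled output.
math::Unpool2dMaxFunctor<Place, T> unpool2d_max_forward;
unpool2d_max_forward(context.device_context(), *in_x, *in_y, out);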
@@ -31,13 +31,12 @@ class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
              "(Tensor) The input tensor of the indices given out by MaxPool2d. "
              "The format of input tensor is NCHW. Where N is batch size, C is the "
              "number of channels, H and W is the height and width of feature.");
-    AddOutput(
-        "Out",
-        "(Tensor) The output tensor of unpool operator."
-        "The format of output tensor is also NCHW."
-        "Where N is batch size, C is "
-        "the number of channels, H and W is the height and "
-        "width of feature.");
+    AddOutput("Out",
+              "(Tensor) The output tensor of unpool operator."
+              "The format of output tensor is also NCHW."
+              "Where N is batch size, C is "
+              "the number of channels, H and W is the height and "
+              "width of feature.");
     AddAttr<std::vector<int>>(
         "ksize",
         "(vector), the unpooling window size(height, width) "
@@ -138,7 +137,7 @@ namespace ops = paddle::operators;
 REGISTER_OP(unpool, ops::UnpoolOp, ops::Unpool2dOpMaker, unpool_grad,
             ops::UnpoolOpGrad);
 REGISTER_OP_CPU_KERNEL(
-    unpool,ops::UnpoolKernel<paddle::platform::CPUPlace, float>,
+    unpool, ops::UnpoolKernel<paddle::platform::CPUPlace, float>,
     ops::UnpoolKernel<paddle::platform::CPUPlace, double>);
 REGISTER_OP_CPU_KERNEL(
     unpool_grad, ops::UnpoolGradKernel<paddle::platform::CPUPlace, float>,
......
@@ -15,11 +15,9 @@ limitations under the License. */
 #include "paddle/operators/unpool_op.h"
 namespace ops = paddle::operators;
-REGISTER_OP_GPU_KERNEL(unpool,
-                       ops::UnpoolKernel<paddle::platform::GPUPlace, float>,
-                       ops::UnpoolKernel<paddle::platform::GPUPlace, double>);
-REGISTER_OP_GPU_KERNEL(unpool_grad,
-                       ops::UnpoolGradKernel<paddle::platform::GPUPlace,
-                                             float>,
-                       ops::UnpoolGradKernel<paddle::platform::GPUPlace,
-                                             double>);
+REGISTER_OP_GPU_KERNEL(
+    unpool, ops::UnpoolKernel<paddle::platform::GPUPlace, float>,
+    ops::UnpoolKernel<paddle::platform::GPUPlace, double>);
+REGISTER_OP_GPU_KERNEL(
+    unpool_grad, ops::UnpoolGradKernel<paddle::platform::GPUPlace, float>,
+    ops::UnpoolGradKernel<paddle::platform::GPUPlace, double>);
@@ -20,7 +20,6 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
-
 template <typename Place, typename T>
 class UnpoolKernel : public framework::OpKernel<T> {
  public:
@@ -41,7 +40,6 @@ class UnpoolKernel : public framework::OpKernel<T> {
     unpool2d_max_forward(context.device_context(), *in_x, *in_y, out);
   }
 };
-
 template <typename Place, typename T>
 class UnpoolGradKernel : public framework::OpKernel<T> {
  public:
@@ -69,6 +67,5 @@ class UnpoolGradKernel : public framework::OpKernel<T> {
                          *out_grad, in_x_grad);
   }
 };
-
 }  // namespace operators
 }  // namespace paddle
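One closing note on the backward pass wired up above: max-unpool backward is the mirror of the forward scatter, a pure gather. A hedged plane-level sketch, matching the forward sketch earlier (illustrative only, not the exact elided body):

// Hedged sketch: each input element reads its gradient back from the
// output position it was scattered to. No accumulation is needed because
// every recorded index is written by exactly one input element.
void Unpool2dMaxGradPlaneSketch(const int* indices, const float* output_grad,
                                int input_feasize, float* input_grad) {
  for (int i = 0; i < input_feasize; ++i) {
    input_grad[i] = output_grad[indices[i]];
  }
}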