Commit 6fc9a9fd authored by sweetsky0901

Remove the T2 template parameter (indices are now fixed to int) and update the unpool operator documentation

Parent ee0a794c
@@ -19,8 +19,8 @@ namespace operators {
 namespace math {
 // All tensors are in NCHW format
-template <typename T, typename T2>
-class Unpool2dMaxFunctor<platform::CPUPlace, T, T2> {
+template <typename T>
+class Unpool2dMaxFunctor<platform::CPUPlace, T> {
  public:
   void operator()(const platform::DeviceContext& context,
                   const framework::Tensor& input,
@@ -35,7 +35,7 @@ class Unpool2dMaxFunctor<platform::CPUPlace, T, T2> {
     int input_feasize = input_height * input_width;
     int output_feasize = output_height * output_width;
     const T* input_data = input.data<T>();
-    const T2 * indices_data = indices.data<T2>();
+    const int * indices_data = indices.data<int>();
     T* output_data = output->mutable_data<T>(context.GetPlace());
     for (int b = 0; b < batch_size; ++b) {
       for (int c = 0; c < output_channels; ++c) {
@@ -54,8 +54,8 @@ class Unpool2dMaxFunctor<platform::CPUPlace, T, T2> {
-template <class T, typename T2>
-class Unpool2dMaxGradFunctor<platform::CPUPlace, T, T2> {
+template <class T>
+class Unpool2dMaxGradFunctor<platform::CPUPlace, T> {
  public:
   void operator()(const platform::DeviceContext& context,
                   const framework::Tensor& input,
@@ -71,7 +71,7 @@ public:
     const int output_width = output.dims()[3];
     int input_feasize = input_height * input_width;
     int output_feasize = output_height * output_width;
-    const T2 * indices_data = indices.data<T2>();
+    const int * indices_data = indices.data<int>();
     const T* output_grad_data = output_grad.data<T>();
     T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
@@ -90,10 +90,10 @@ public:
   }
 };
-template class Unpool2dMaxGradFunctor<platform::CPUPlace, float, int>;
-template class Unpool2dMaxGradFunctor<platform::CPUPlace, double, int>;
-template class Unpool2dMaxFunctor<platform::CPUPlace, float, int>;
-template class Unpool2dMaxFunctor<platform::CPUPlace, double, int>;
+template class Unpool2dMaxGradFunctor<platform::CPUPlace, float>;
+template class Unpool2dMaxGradFunctor<platform::CPUPlace, double>;
+template class Unpool2dMaxFunctor<platform::CPUPlace, float>;
+template class Unpool2dMaxFunctor<platform::CPUPlace, double>;
 }  // namespace math
 }  // namespace operators
...
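For readers new to this op, here is a minimal standalone sketch of what the CPU forward functor above computes: each pooled value is scattered to the location recorded in `indices` within its output plane. This is plain C++ for illustration only, not Paddle code; the flat-vector layout and the function name are assumptions of the sketch.

```cpp
#include <cstddef>
#include <vector>

// Max-unpooling forward (illustration only): scatter each input value to the
// position its index records inside the corresponding output H*W plane.
// Assumes row-major NCHW buffers and per-plane flattened indices.
void Unpool2dMaxForward(const std::vector<float>& input,
                        const std::vector<int>& indices,
                        int batch, int channels, int in_h, int in_w,
                        int out_h, int out_w, std::vector<float>* output) {
  const int in_plane = in_h * in_w;
  const int out_plane = out_h * out_w;
  output->assign(static_cast<std::size_t>(batch) * channels * out_plane, 0.0f);
  for (int b = 0; b < batch; ++b) {
    for (int c = 0; c < channels; ++c) {
      const int in_off = (b * channels + c) * in_plane;
      const int out_off = (b * channels + c) * out_plane;
      for (int i = 0; i < in_plane; ++i) {
        // Every output position that was not a pooling max stays zero.
        (*output)[out_off + indices[in_off + i]] = input[in_off + i];
      }
    }
  }
}
```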
@@ -19,10 +19,10 @@ namespace paddle {
 namespace operators {
 namespace math {
-template <typename T, typename T2>
+template <typename T>
 __global__ void KernelUnpool2dMax(const int nthreads,
                                   const T* input_data,
-                                  const T2 * indices_data,
+                                  const int * indices_data,
                                   const int input_height,
                                   const int input_width,
                                   const int channels,
@@ -45,10 +45,10 @@ __global__ void KernelUnpool2dMax(const int nthreads,
     output_data[out_offset + out_index] = input_data[i];
   }
 }
-template <typename T, typename T2>
+template <typename T>
 __global__ void KernelUnpool2dMaxGrad(const int nthreads,
                                       const T* input_data,
-                                      const T2* indices_data,
+                                      const int* indices_data,
                                       const int input_height,
                                       const int input_width,
                                       const int channels,
@@ -76,8 +76,8 @@ __global__ void KernelUnpool2dMaxGrad(const int nthreads,
 /*
  * All tensors are in NCHW format.
  */
-template <typename T, typename T2>
-class Unpool2dMaxFunctor<platform::GPUPlace, T, T2> {
+template <typename T>
+class Unpool2dMaxFunctor<platform::GPUPlace, T> {
  public:
   void operator()(const platform::DeviceContext& context,
                   const framework::Tensor& input,
@@ -90,15 +90,14 @@ class Unpool2dMaxFunctor<platform::GPUPlace, T, T2> {
     const int output_height = output->dims()[2];
     const int output_width = output->dims()[3];
     const T* input_data = input.data<T>();
-    const T2 * indices_data = indices.data<T2>();
+    const int * indices_data = indices.data<int>();
     T* output_data = output->mutable_data<T>(context.GetPlace());
-    int nthreads = batch_size * output_channels * input_height * input_width;
     int threads = 1024;
     int grid = (input.numel() + threads - 1) / threads;
     KernelUnpool2dMax<
-        T, T2><<<grid, threads, 0,
+        T><<<grid, threads, 0,
              reinterpret_cast<const platform::CUDADeviceContext&>(context)
-                 .stream()>>>(nthreads, input_data, indices_data,
+                 .stream()>>>(input.numel(), input_data, indices_data,
                               input_height, input_width, output_channels,
                               output_data, output_height, output_width);
   }
@@ -106,8 +105,8 @@ class Unpool2dMaxFunctor<platform::GPUPlace, T, T2> {
 /*
  * All tensors are in NCHW format.
  */
-template <typename T, typename T2>
-class Unpool2dMaxGradFunctor<platform::GPUPlace, T, T2> {
+template <typename T>
+class Unpool2dMaxGradFunctor<platform::GPUPlace, T> {
  public:
   void operator()(const platform::DeviceContext& context,
                   const framework::Tensor& input,
@@ -122,18 +121,16 @@ class Unpool2dMaxGradFunctor<platform::GPUPlace, T, T2> {
     const int output_height = output.dims()[2];
     const int output_width = output.dims()[3];
     const T* input_data = input.data<T>();
-    const T2 * indices_data = indices.data<T2>();
+    const int * indices_data = indices.data<int>();
     const T* output_data = output.data<T>();
     const T* output_grad_data = output_grad.data<T>();
     T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
-    int nthreads = batch_size * output_channels * input_height * input_width;
     int threads = 1024;
     int grid = (input.numel() + threads - 1) / threads;
     KernelUnpool2dMaxGrad<
-        T, T2><<<grid, threads, 0,
+        T><<<grid, threads, 0,
              reinterpret_cast<const platform::CUDADeviceContext&>(context)
-                 .stream()>>>(
-                     nthreads, input_data, indices_data,
+                 .stream()>>>(input.numel(), input_data, indices_data,
                               input_height, input_width, output_channels,
                               output_data, output_grad_data,
                               output_height, output_width,
@@ -141,11 +138,11 @@ class Unpool2dMaxGradFunctor<platform::GPUPlace, T, T2> {
   }
 };
-template class Unpool2dMaxGradFunctor<platform::GPUPlace, float, int>;
-template class Unpool2dMaxGradFunctor<platform::GPUPlace, double, int>;
-template class Unpool2dMaxFunctor<platform::GPUPlace, float, int>;
-template class Unpool2dMaxFunctor<platform::GPUPlace, double, int>;
+template class Unpool2dMaxGradFunctor<platform::GPUPlace, float>;
+template class Unpool2dMaxGradFunctor<platform::GPUPlace, double>;
+template class Unpool2dMaxFunctor<platform::GPUPlace, float>;
+template class Unpool2dMaxFunctor<platform::GPUPlace, double>;
 }  // namespace math
 }  // namespace operators
...
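A note on the deleted `nthreads` locals above: the launch runs one thread per input element, so `input.numel()` can be passed straight through, and the block count is the usual ceiling division. A tiny self-contained check of that arithmetic (names are illustrative, not Paddle's):

```cpp
#include <cassert>

// Ceiling division used for the <<<grid, threads>>> launch configuration:
// enough blocks so grid * threads >= numel, with at most one partial block.
inline int NumBlocks(int numel, int threads = 1024) {
  return (numel + threads - 1) / threads;
}

int main() {
  assert(NumBlocks(1) == 1);     // a single element still needs one block
  assert(NumBlocks(1024) == 1);  // exact fit
  assert(NumBlocks(1025) == 2);  // one extra element spills into a 2nd block
  return 0;
}
```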
@@ -19,7 +19,7 @@ namespace paddle {
 namespace operators {
 namespace math {
-template <typename Place, typename T, typename T2>
+template <typename Place, typename T>
 class Unpool2dMaxFunctor {
  public:
@@ -29,7 +29,7 @@ class Unpool2dMaxFunctor {
                   framework::Tensor * output);
 };
-template <typename Place, class T, typename T2>
+template <typename Place, class T>
 class Unpool2dMaxGradFunctor {
  public:
   void operator()(const platform::DeviceContext& context,
...
@@ -50,10 +50,15 @@ class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
         "(string), unpooling type, can be \"max\" for max-unpooling ")
         .InEnum({"max"});
     AddComment(R"DOC(
-        "Paper: http://www.matthewzeiler.com/wp-content/uploads/2017
+        "Input shape: $(N, C_{in}, H_{in}, W_{in})$
+        Output shape: $(N, C_{out}, H_{out}, W_{out})$
+        Where
+        $$
+        H_{out} = (H_{in}-1) * strides[0] - 2 * paddings[0] + ksize[0] \\
+        W_{out} = (W_{in}-1) * strides[1] - 2 * paddings[1] + ksize[1]
+        $$
+        Paper: http://www.matthewzeiler.com/wp-content/uploads/2017
         /07/iccv2011.pdf
-        PyTorch: http://pytorch.org/docs/master/nn.html?highlight=unpool#
-        torch.nn.MaxUnpool2d"
         )DOC");
   }
 };
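To make the documented shape formula concrete, a small worked example (values chosen for illustration):

```cpp
#include <cstdio>

// H_out = (H_in - 1) * stride - 2 * padding + ksize, per the doc above.
int UnpoolOutSize(int in, int stride, int padding, int ksize) {
  return (in - 1) * stride - 2 * padding + ksize;
}

int main() {
  // A 5x5 map produced by 3x3 max-pooling with stride 2 and no padding
  // unpools back to 11x11: (5 - 1) * 2 - 2 * 0 + 3 = 11.
  std::printf("%d\n", UnpoolOutSize(5, 2, 0, 3));  // prints 11
  return 0;
}
```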
@@ -125,9 +130,9 @@ namespace ops = paddle::operators;
 REGISTER_OP(unpool, ops::UnpoolOp, ops::Unpool2dOpMaker, unpool_grad,
             ops::UnpoolOpGrad);
 REGISTER_OP_CPU_KERNEL(unpool,
-    ops::UnpoolKernel<paddle::platform::CPUPlace, float, int>,
-    ops::UnpoolKernel<paddle::platform::CPUPlace, double, int>);
+    ops::UnpoolKernel<paddle::platform::CPUPlace, float>,
+    ops::UnpoolKernel<paddle::platform::CPUPlace, double>);
 REGISTER_OP_CPU_KERNEL(unpool_grad,
-    ops::UnpoolGradKernel<paddle::platform::CPUPlace, float, int>,
-    ops::UnpoolGradKernel<paddle::platform::CPUPlace, double, int>);
+    ops::UnpoolGradKernel<paddle::platform::CPUPlace, float>,
+    ops::UnpoolGradKernel<paddle::platform::CPUPlace, double>);
@@ -16,10 +16,10 @@ limitations under the License. */
 namespace ops = paddle::operators;
 REGISTER_OP_GPU_KERNEL(unpool,
-    ops::UnpoolKernel<paddle::platform::GPUPlace, float, int>,
-    ops::UnpoolKernel<paddle::platform::GPUPlace, double, int>);
+    ops::UnpoolKernel<paddle::platform::GPUPlace, float>,
+    ops::UnpoolKernel<paddle::platform::GPUPlace, double>);
 REGISTER_OP_GPU_KERNEL(unpool_grad,
     ops::UnpoolGradKernel<paddle::platform::GPUPlace,
-                          float, int>,
+                          float>,
     ops::UnpoolGradKernel<paddle::platform::GPUPlace,
-                          double, int>);
+                          double>);
@@ -21,7 +21,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
-template <typename Place, typename T, typename T2>
+template <typename Place, typename T>
 class UnpoolKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
@@ -37,12 +37,12 @@ class UnpoolKernel : public framework::OpKernel<T> {
       math::SetConstant<Place, T> set_zero;
       set_zero(context.device_context(), out, static_cast<T>(0));
     }
-    math::Unpool2dMaxFunctor<Place, T, T2> unpool2d_max_forward;
+    math::Unpool2dMaxFunctor<Place, T> unpool2d_max_forward;
     unpool2d_max_forward(context.device_context(), *in_x, *in_y, out);
   }
 };
-template <typename Place, typename T, typename T2>
+template <typename Place, typename T>
 class UnpoolGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
@@ -64,7 +64,7 @@ class UnpoolGradKernel : public framework::OpKernel<T> {
       in_x_grad->mutable_data<T>(context.GetPlace());
       zero(device_ctx, in_x_grad, static_cast<T>(0));
     }
-    math::Unpool2dMaxGradFunctor<Place, T, T2> unpool2d_max_backward;
+    math::Unpool2dMaxGradFunctor<Place, T> unpool2d_max_backward;
     unpool2d_max_backward(context.device_context(), *in_x, *in_y,
                           *out, *out_grad, in_x_grad);
   }
...
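For completeness, a standalone sketch of the gradient rule these kernels delegate to: the forward pass scatters through `indices`, so the backward pass gathers the output gradient from those same positions. Plain C++ illustration; the layout and names are assumptions of the sketch, not Paddle's implementation.

```cpp
#include <cstddef>
#include <vector>

// Max-unpooling backward (illustration only): each input position receives
// the output gradient at the index it scattered to in the forward pass.
void Unpool2dMaxBackward(const std::vector<float>& output_grad,
                         const std::vector<int>& indices,
                         int batch, int channels, int in_h, int in_w,
                         int out_h, int out_w, std::vector<float>* input_grad) {
  const int in_plane = in_h * in_w;
  const int out_plane = out_h * out_w;
  input_grad->assign(static_cast<std::size_t>(batch) * channels * in_plane, 0.0f);
  for (int b = 0; b < batch; ++b) {
    for (int c = 0; c < channels; ++c) {
      const int in_off = (b * channels + c) * in_plane;
      const int out_off = (b * channels + c) * out_plane;
      for (int i = 0; i < in_plane; ++i) {
        (*input_grad)[in_off + i] = output_grad[out_off + indices[in_off + i]];
      }
    }
  }
}
```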
@@ -50,7 +50,7 @@ class TestUnpoolOp(OpTest):
                     indices[nidx, cidx, i, j] = \
                         (r_start + arg / self.ksize[1]) * wsize + \
                         c_start + arg % self.ksize[1]
-        output = self.Unpool2d_forward_naive(input, indices, self.ksize, \
+        output = self.unpool2d_forward_naive(input, indices, self.ksize, \
                 self.strides, self.paddings).astype("float32")
         self.inputs = {'X': input.astype('float32'),
                        'Indices': indices.astype('int32')}
@@ -69,7 +69,7 @@ class TestUnpoolOp(OpTest):
         self.check_grad(['X'], 'Out')

     def init_test_case(self):
-        self.Unpool2d_forward_naive = unpool2dmax_forward_naive
+        self.unpool2d_forward_naive = unpool2dmax_forward_naive
         self.unpooling_type = "max"
         self.shape = [6, 4, 5, 5]
         self.ksize = [3, 3]
...
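The index expression in the test above turns a within-window argmax offset into a flattened H*W position. Transcribed into C++ for clarity (variable names follow the test):

```cpp
// arg is the argmax offset inside a ksize[0] x ksize[1] window (row-major);
// wsize is the width of the full output plane.
int FlatIndex(int r_start, int c_start, int arg, int window_w, int wsize) {
  int row = r_start + arg / window_w;  // window row -> plane row
  int col = c_start + arg % window_w;  // window column -> plane column
  return row * wsize + col;            // row-major flattening
}
```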