提交 36da8255 编写于 作者: C chengduoZH

Add code comments

上级 e21e5646
...@@ -18,6 +18,11 @@ namespace paddle { ...@@ -18,6 +18,11 @@ namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively.
*/
template <typename PoolProcess, typename T> template <typename PoolProcess, typename T>
class Pool2dFunctor<platform::CPUPlace, PoolProcess, T> { class Pool2dFunctor<platform::CPUPlace, PoolProcess, T> {
public: public:
...@@ -73,6 +78,11 @@ class Pool2dFunctor<platform::CPUPlace, PoolProcess, T> { ...@@ -73,6 +78,11 @@ class Pool2dFunctor<platform::CPUPlace, PoolProcess, T> {
} }
}; };
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent height
* and width, respectively.
*/
template <typename PoolProcess, class T> template <typename PoolProcess, class T>
class Pool2dGradFunctor<platform::CPUPlace, PoolProcess, T> { class Pool2dGradFunctor<platform::CPUPlace, PoolProcess, T> {
public: public:
...@@ -135,6 +145,11 @@ class Pool2dGradFunctor<platform::CPUPlace, PoolProcess, T> { ...@@ -135,6 +145,11 @@ class Pool2dGradFunctor<platform::CPUPlace, PoolProcess, T> {
} }
}; };
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively.
*/
template <class T> template <class T>
class MaxPool2dGradFunctor<platform::CPUPlace, T> { class MaxPool2dGradFunctor<platform::CPUPlace, T> {
public: public:
...@@ -197,7 +212,7 @@ class MaxPool2dGradFunctor<platform::CPUPlace, T> { ...@@ -197,7 +212,7 @@ class MaxPool2dGradFunctor<platform::CPUPlace, T> {
}; };
template class MaxPool2dGradFunctor<platform::CPUPlace, float>; template class MaxPool2dGradFunctor<platform::CPUPlace, float>;
// template class MaxPool2dGradFunctor<platform::CPUPlace, double>; template class MaxPool2dGradFunctor<platform::CPUPlace, double>;
template class Pool2dFunctor<platform::CPUPlace, template class Pool2dFunctor<platform::CPUPlace,
paddle::operators::math::MaxPool<float>, float>; paddle::operators::math::MaxPool<float>, float>;
...@@ -216,6 +231,11 @@ template class Pool2dGradFunctor< ...@@ -216,6 +231,11 @@ template class Pool2dGradFunctor<
template class Pool2dGradFunctor< template class Pool2dGradFunctor<
platform::CPUPlace, paddle::operators::math::AvgPoolGrad<double>, double>; platform::CPUPlace, paddle::operators::math::AvgPoolGrad<double>, double>;
/*
* All tensors are in NCDHW format.
* Ksize, strides, paddings are three elements. These three elements represent
* depth, height and width, respectively.
*/
template <typename PoolProcess, class T> template <typename PoolProcess, class T>
class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> { class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> {
public: public:
...@@ -286,6 +306,11 @@ class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> { ...@@ -286,6 +306,11 @@ class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> {
} }
}; };
/*
* All tensors are in NCDHW format.
* Ksize, strides, paddings are three elements. These three elements represent
* depth, height and width, respectively.
*/
template <typename PoolProcess, class T> template <typename PoolProcess, class T>
class Pool3dGradFunctor<platform::CPUPlace, PoolProcess, T> { class Pool3dGradFunctor<platform::CPUPlace, PoolProcess, T> {
public: public:
...@@ -364,6 +389,11 @@ class Pool3dGradFunctor<platform::CPUPlace, PoolProcess, T> { ...@@ -364,6 +389,11 @@ class Pool3dGradFunctor<platform::CPUPlace, PoolProcess, T> {
} }
}; };
/*
* All tensors are in NCDHW format.
* Ksize, strides, paddings are three elements. These three elements represent
* depth, height and width, respectively.
*/
template <class T> template <class T>
class MaxPool3dGradFunctor<platform::CPUPlace, T> { class MaxPool3dGradFunctor<platform::CPUPlace, T> {
public: public:
...@@ -440,7 +470,7 @@ class MaxPool3dGradFunctor<platform::CPUPlace, T> { ...@@ -440,7 +470,7 @@ class MaxPool3dGradFunctor<platform::CPUPlace, T> {
}; };
template class MaxPool3dGradFunctor<platform::CPUPlace, float>; template class MaxPool3dGradFunctor<platform::CPUPlace, float>;
// template class MaxPool3dGradFunctor<platform::CPUPlace, double>; template class MaxPool3dGradFunctor<platform::CPUPlace, double>;
template class Pool3dFunctor<platform::CPUPlace, template class Pool3dFunctor<platform::CPUPlace,
paddle::operators::math::MaxPool<float>, float>; paddle::operators::math::MaxPool<float>, float>;
...@@ -459,6 +489,11 @@ template class Pool3dGradFunctor< ...@@ -459,6 +489,11 @@ template class Pool3dGradFunctor<
template class Pool3dGradFunctor< template class Pool3dGradFunctor<
platform::CPUPlace, paddle::operators::math::AvgPoolGrad<double>, double>; platform::CPUPlace, paddle::operators::math::AvgPoolGrad<double>, double>;
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively.
*/
template <typename T> template <typename T>
class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> { class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
public: public:
...@@ -519,6 +554,11 @@ class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> { ...@@ -519,6 +554,11 @@ class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
} }
}; };
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively.
*/
template <typename T> template <typename T>
class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> { class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> {
public: public:
...@@ -563,6 +603,11 @@ template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, float>; ...@@ -563,6 +603,11 @@ template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, float>;
template class MaxPool2dWithIndexFunctor<platform::CPUPlace, double>; template class MaxPool2dWithIndexFunctor<platform::CPUPlace, double>;
template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, double>; template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, double>;
/*
* All tensors are in NCDHW format.
* Ksize, strides, paddings are three elements. These three elements represent
* depth, height and width, respectively.
*/
template <typename T> template <typename T>
class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> { class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
public: public:
...@@ -637,6 +682,11 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> { ...@@ -637,6 +682,11 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
} }
}; };
/*
* All tensors are in NCDHW format.
* Ksize, strides, paddings are three elements. These three elements represent
* depth, height and width, respectively.
*/
template <typename T> template <typename T>
class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> { class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> {
public: public:
......
...@@ -149,6 +149,11 @@ __global__ void KernelMaxPool2DGrad( ...@@ -149,6 +149,11 @@ __global__ void KernelMaxPool2DGrad(
} }
} }
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively.
*/
template <typename PoolProcess, typename T> template <typename PoolProcess, typename T>
class Pool2dFunctor<platform::GPUPlace, PoolProcess, T> { class Pool2dFunctor<platform::GPUPlace, PoolProcess, T> {
public: public:
...@@ -190,6 +195,11 @@ class Pool2dFunctor<platform::GPUPlace, PoolProcess, T> { ...@@ -190,6 +195,11 @@ class Pool2dFunctor<platform::GPUPlace, PoolProcess, T> {
} }
}; };
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively.
*/
template <typename PoolProcess, typename T> template <typename PoolProcess, typename T>
class Pool2dGradFunctor<platform::GPUPlace, PoolProcess, T> { class Pool2dGradFunctor<platform::GPUPlace, PoolProcess, T> {
public: public:
...@@ -234,6 +244,11 @@ class Pool2dGradFunctor<platform::GPUPlace, PoolProcess, T> { ...@@ -234,6 +244,11 @@ class Pool2dGradFunctor<platform::GPUPlace, PoolProcess, T> {
} }
}; };
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively.
*/
template <typename T> template <typename T>
class MaxPool2dGradFunctor<platform::GPUPlace, T> { class MaxPool2dGradFunctor<platform::GPUPlace, T> {
public: public:
...@@ -456,6 +471,11 @@ __global__ void KernelMaxPool3DGrad( ...@@ -456,6 +471,11 @@ __global__ void KernelMaxPool3DGrad(
} }
} }
/*
* All tensors are in NCDHW format.
* Ksize, strides, paddings are three elements. These three elements represent
* depth, height and width, respectively.
*/
template <typename PoolProcess, class T> template <typename PoolProcess, class T>
class Pool3dFunctor<platform::GPUPlace, PoolProcess, T> { class Pool3dFunctor<platform::GPUPlace, PoolProcess, T> {
public: public:
...@@ -504,6 +524,11 @@ class Pool3dFunctor<platform::GPUPlace, PoolProcess, T> { ...@@ -504,6 +524,11 @@ class Pool3dFunctor<platform::GPUPlace, PoolProcess, T> {
} }
}; };
/*
* All tensors are in NCDHW format.
* Ksize, strides, paddings are three elements. These three elements represent
* depth, height and width, respectively.
*/
template <typename PoolProcess, class T> template <typename PoolProcess, class T>
class Pool3dGradFunctor<platform::GPUPlace, PoolProcess, T> { class Pool3dGradFunctor<platform::GPUPlace, PoolProcess, T> {
public: public:
...@@ -556,6 +581,11 @@ class Pool3dGradFunctor<platform::GPUPlace, PoolProcess, T> { ...@@ -556,6 +581,11 @@ class Pool3dGradFunctor<platform::GPUPlace, PoolProcess, T> {
} }
}; };
/*
* All tensors are in NCDHW format.
* Ksize, strides, paddings are three elements. These three elements represent
* depth, height and width, respectively.
*/
template <class T> template <class T>
class MaxPool3dGradFunctor<platform::GPUPlace, T> { class MaxPool3dGradFunctor<platform::GPUPlace, T> {
public: public:
...@@ -709,6 +739,11 @@ __global__ void KernelMaxPool2DWithIdxGrad( ...@@ -709,6 +739,11 @@ __global__ void KernelMaxPool2DWithIdxGrad(
} }
} }
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively.
*/
template <typename T> template <typename T>
class MaxPool2dWithIndexFunctor<platform::GPUPlace, T> { class MaxPool2dWithIndexFunctor<platform::GPUPlace, T> {
public: public:
...@@ -750,6 +785,11 @@ class MaxPool2dWithIndexFunctor<platform::GPUPlace, T> { ...@@ -750,6 +785,11 @@ class MaxPool2dWithIndexFunctor<platform::GPUPlace, T> {
} }
}; };
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively.
*/
template <typename T> template <typename T>
class MaxPool2dWithIndexGradFunctor<platform::GPUPlace, T> { class MaxPool2dWithIndexGradFunctor<platform::GPUPlace, T> {
public: public:
...@@ -903,6 +943,11 @@ __global__ void KernelMaxPool3DWithIdxGrad( ...@@ -903,6 +943,11 @@ __global__ void KernelMaxPool3DWithIdxGrad(
} }
} }
/*
* All tensors are in NCDHW format.
* Ksize, strides, paddings are three elements. These three elements represent
* depth, height and width, respectively.
*/
template <typename T> template <typename T>
class MaxPool3dWithIndexFunctor<platform::GPUPlace, T> { class MaxPool3dWithIndexFunctor<platform::GPUPlace, T> {
public: public:
...@@ -951,6 +996,11 @@ class MaxPool3dWithIndexFunctor<platform::GPUPlace, T> { ...@@ -951,6 +996,11 @@ class MaxPool3dWithIndexFunctor<platform::GPUPlace, T> {
} }
}; };
/*
* All tensors are in NCDHW format.
* Ksize, strides, paddings are three elements. These three elements represent
* depth, height and width, respectively.
*/
template <typename T> template <typename T>
class MaxPool3dWithIndexGradFunctor<platform::GPUPlace, T> { class MaxPool3dWithIndexGradFunctor<platform::GPUPlace, T> {
public: public:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册