提交 32d7e618 编写于 作者: H hedaoyuan

Fix some bugs.

上级 1a615b48
......@@ -24,7 +24,8 @@ namespace math {
* [input_channels, filter_height, filter_width, output_height, output_width]
*/
template <class T>
class Im2ColFunctor<kCFO, platform::CPUPlace, T> {
class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
platform::CPUPlace, T> {
public:
void operator()(const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
......@@ -75,7 +76,8 @@ class Im2ColFunctor<kCFO, platform::CPUPlace, T> {
* [input_channels, filter_height, filter_width, output_height, output_width]
*/
template <class T>
class Col2ImFunctor<kCFO, platform::CPUPlace, T> {
class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
platform::CPUPlace, T> {
public:
void operator()(framework::Tensor& im, const framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
......@@ -117,10 +119,14 @@ class Col2ImFunctor<kCFO, platform::CPUPlace, T> {
}
};
template class Im2ColFunctor<kCFO, platform::CPUPlace, float>;
template class Im2ColFunctor<kCFO, platform::CPUPlace, double>;
template class Col2ImFunctor<kCFO, platform::CPUPlace, float>;
template class Col2ImFunctor<kCFO, platform::CPUPlace, double>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
platform::CPUPlace, float>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
platform::CPUPlace, double>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
platform::CPUPlace, float>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
platform::CPUPlace, double>;
/*
* im = [input_channels, input_height, input_width]
......@@ -128,7 +134,8 @@ template class Col2ImFunctor<kCFO, platform::CPUPlace, double>;
* [output_height, output_width, input_channels, filter_height, filter_width]
*/
template <class T>
class Im2ColFunctor<kOCF, platform::CPUPlace, T> {
class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
platform::CPUPlace, T> {
public:
void operator()(const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
......@@ -187,7 +194,8 @@ class Im2ColFunctor<kOCF, platform::CPUPlace, T> {
* [output_height, output_width, input_channels, filter_height, filter_width]
*/
template <class T>
class Col2ImFunctor<kOCF, platform::CPUPlace, T> {
class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
platform::CPUPlace, T> {
public:
void operator()(framework::Tensor& im, const framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
......@@ -238,10 +246,14 @@ class Col2ImFunctor<kOCF, platform::CPUPlace, T> {
}
};
template class Im2ColFunctor<kOCF, platform::CPUPlace, float>;
template class Im2ColFunctor<kOCF, platform::CPUPlace, double>;
template class Col2ImFunctor<kOCF, platform::CPUPlace, float>;
template class Col2ImFunctor<kOCF, platform::CPUPlace, double>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
platform::CPUPlace, float>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
platform::CPUPlace, double>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
platform::CPUPlace, float>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
platform::CPUPlace, double>;
} // namespace math
} // namespace operators
......
......@@ -61,7 +61,8 @@ __global__ void im2col(const T* data_im, int num_outs, int height, int width,
* [input_channels, filter_height, filter_width, output_height, output_width]
*/
template <class T>
class Im2ColFunctor<kCFO, platform::GPUPlace, T> {
class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
platform::GPUPlace, T> {
public:
void operator()(const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
......@@ -145,7 +146,8 @@ __global__ void col2im(size_t n, const T* data_col, size_t height, size_t width,
* [input_channels, filter_height, filter_width, output_height, output_width]
*/
template <class T>
class Col2ImFunctor<kCFO, platform::GPUPlace, T> {
class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
platform::GPUPlace, T> {
public:
void operator()(framework::Tensor& im, const framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
......@@ -182,10 +184,14 @@ class Col2ImFunctor<kCFO, platform::GPUPlace, T> {
}
};
template class Im2ColFunctor<kCFO, platform::GPUPlace, float>;
template class Im2ColFunctor<kCFO, platform::GPUPlace, double>;
template class Col2ImFunctor<kCFO, platform::GPUPlace, float>;
template class Col2ImFunctor<kCFO, platform::GPUPlace, double>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
platform::GPUPlace, float>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
platform::GPUPlace, double>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
platform::GPUPlace, float>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
platform::GPUPlace, double>;
template <class T>
__global__ void im2colOCF(const T* im_data, T* col_data, int input_channels,
......@@ -226,7 +232,8 @@ __global__ void im2colOCF(const T* im_data, T* col_data, int input_channels,
* [output_height, output_width, input_channels, filter_height, filter_width]
*/
template <class T>
class Im2ColFunctor<kOCF, platform::GPUPlace, T> {
class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
platform::GPUPlace, T> {
public:
void operator()(const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
......@@ -308,7 +315,8 @@ __global__ void col2imOCF(T* im_data, const T* col_data, int input_channels,
* [output_height, output_width, input_channels, filter_height, filter_width]
*/
template <class T>
class Col2ImFunctor<kOCF, platform::GPUPlace, T> {
class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
platform::GPUPlace, T> {
public:
void operator()(framework::Tensor& im, const framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
......@@ -352,10 +360,14 @@ class Col2ImFunctor<kOCF, platform::GPUPlace, T> {
}
};
template class Im2ColFunctor<kOCF, platform::GPUPlace, float>;
template class Im2ColFunctor<kOCF, platform::GPUPlace, double>;
template class Col2ImFunctor<kOCF, platform::GPUPlace, float>;
template class Col2ImFunctor<kOCF, platform::GPUPlace, double>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
platform::GPUPlace, float>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
platform::GPUPlace, double>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
platform::GPUPlace, float>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
platform::GPUPlace, double>;
} // namespace math
} // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册