未验证 提交 88cac16b 编写于 作者: H huangjiyi 提交者: GitHub

[PHI decoupling] move im2col from fluid to phi (#48174)

* decouple im2col from fluid

* move im2col to phi

* fix build error

* delete redundant comment
上级 6ea8bfc6
......@@ -22,8 +22,8 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/layout_utils.h"
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/im2col.h"
#include "paddle/phi/kernels/funcs/vol2col.h"
namespace paddle {
......
......@@ -20,7 +20,7 @@
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/phi/kernels/funcs/im2col.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
......@@ -101,7 +101,8 @@ class Im2SequenceKernel : public framework::OpKernel<T> {
kernels[1]});
offset_out += output_height[i] * output_width[i];
math::Im2ColFunctor<math::ColFormat::kOCF, DeviceContext, T> f;
phi::funcs::Im2ColFunctor<phi::funcs::ColFormat::kOCF, DeviceContext, T>
f;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
f(dev_ctx, src, dilations, strides, paddings, &dst);
}
......@@ -135,7 +136,8 @@ class Im2SequenceKernel : public framework::OpKernel<T> {
kernels[0],
kernels[1]});
math::Im2ColFunctor<math::ColFormat::kOCF, DeviceContext, T> f;
phi::funcs::Im2ColFunctor<phi::funcs::ColFormat::kOCF, DeviceContext, T>
f;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
f(dev_ctx, src, dilations, strides, paddings, &dst);
}
......@@ -190,7 +192,8 @@ class Im2SequenceGradKernel : public framework::OpKernel<T> {
d_x->Slice(i, i + 1).Resize({img_channels, img_height, img_width});
const Tensor src = d_out->Slice(i, i + 1).Resize(
{output_height, output_width, img_channels, kernels[0], kernels[1]});
math::Col2ImFunctor<math::ColFormat::kOCF, DeviceContext, T> f;
phi::funcs::Col2ImFunctor<phi::funcs::ColFormat::kOCF, DeviceContext, T>
f;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
f(dev_ctx, src, dilations, strides, paddings, &dst);
}
......
......@@ -23,7 +23,6 @@ endif()
math_library(context_project DEPS im2col math_function)
math_library(cos_sim_functor)
math_library(depthwise_conv)
math_library(im2col)
math_library(sample_prob)
math_library(sampler DEPS generator)
......
......@@ -18,8 +18,8 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/im2col.h"
namespace paddle {
namespace operators {
......@@ -100,7 +100,8 @@ class ContextProjectFunctor {
phi::DenseTensor* col) {
auto lod_level_0 = in.lod()[0];
math::Im2ColFunctor<math::ColFormat::kOCF, DeviceContext, float> im2col_ocf;
phi::funcs::Im2ColFunctor<phi::funcs::ColFormat::kOCF, DeviceContext, float>
im2col_ocf;
std::vector<int> dilation({1, 1});
std::vector<int> padding({up_pad, 0, down_pad, 0});
......@@ -230,7 +231,8 @@ class ContextProjectGradFunctor {
phi::DenseTensor* col) {
auto lod_level_0 = in.lod()[0];
math::Col2ImFunctor<math::ColFormat::kOCF, DeviceContext, float> col2im_ocf;
phi::funcs::Col2ImFunctor<phi::funcs::ColFormat::kOCF, DeviceContext, float>
col2im_ocf;
std::vector<int> dilation({1, 1});
std::vector<int> padding({up_pad, 0, down_pad, 0});
......
......@@ -12,13 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/phi/kernels/funcs/im2col.h"
#include <gtest/gtest.h>
#include "paddle/fluid/operators/math/im2col_cfo_cpu.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/kernels/funcs/im2col_cfo_cpu.h"
template <typename DeviceContext, typename Place>
void testIm2col() {
......@@ -76,15 +77,9 @@ void testIm2col() {
{output_height, output_width, 1, filter_size, filter_size}, *place);
// Im2Col
paddle::operators::math::Im2ColFunctor<
paddle::operators::math::ColFormat::kCFO,
DeviceContext,
float>
phi::funcs::Im2ColFunctor<phi::funcs::ColFormat::kCFO, DeviceContext, float>
im2col;
paddle::operators::math::Im2ColFunctor<
paddle::operators::math::ColFormat::kOCF,
DeviceContext,
float>
phi::funcs::Im2ColFunctor<phi::funcs::ColFormat::kOCF, DeviceContext, float>
im2col_ocf;
im2col(*context, input, dilation, stride, padding, &output_cfo);
......@@ -119,15 +114,9 @@ void testIm2col() {
}
// Col2Im: kCFO
paddle::operators::math::Col2ImFunctor<
paddle::operators::math::ColFormat::kCFO,
DeviceContext,
float>
phi::funcs::Col2ImFunctor<phi::funcs::ColFormat::kCFO, DeviceContext, float>
col2im;
paddle::operators::math::Col2ImFunctor<
paddle::operators::math::ColFormat::kOCF,
DeviceContext,
float>
phi::funcs::Col2ImFunctor<phi::funcs::ColFormat::kOCF, DeviceContext, float>
col2im_ocf;
float col2im_data[] = {0, 2, 2, 3, 8, 5};
......@@ -237,15 +226,9 @@ void testIm2col<phi::GPUContext, paddle::platform::CUDAPlace>() {
{output_height, output_width, 1, filter_size, filter_size}, *place);
// Im2Col
paddle::operators::math::Im2ColFunctor<
paddle::operators::math::ColFormat::kCFO,
phi::GPUContext,
float>
phi::funcs::Im2ColFunctor<phi::funcs::ColFormat::kCFO, phi::GPUContext, float>
im2col;
paddle::operators::math::Im2ColFunctor<
paddle::operators::math::ColFormat::kOCF,
phi::GPUContext,
float>
phi::funcs::Im2ColFunctor<phi::funcs::ColFormat::kOCF, phi::GPUContext, float>
im2col_ocf;
im2col(*context, input, dilation, stride, padding, &output_cfo);
......@@ -280,15 +263,9 @@ void testIm2col<phi::GPUContext, paddle::platform::CUDAPlace>() {
}
// Col2Im: kCFO
paddle::operators::math::Col2ImFunctor<
paddle::operators::math::ColFormat::kCFO,
phi::GPUContext,
float>
phi::funcs::Col2ImFunctor<phi::funcs::ColFormat::kCFO, phi::GPUContext, float>
col2im;
paddle::operators::math::Col2ImFunctor<
paddle::operators::math::ColFormat::kOCF,
phi::GPUContext,
float>
phi::funcs::Col2ImFunctor<phi::funcs::ColFormat::kOCF, phi::GPUContext, float>
col2im_ocf;
float col2im_data[] = {0, 2, 2, 3, 8, 5};
......@@ -363,18 +340,15 @@ TEST(math, im2col) {
int output_width = (iw - fw + padding[1] * 2) / stride[1] + 1; \
out.mutable_data<float>({ic, fh, fw, output_height, output_width}, place); \
ref.mutable_data<float>({ic, fh, fw, output_height, output_width}, place); \
paddle::operators::math::Im2ColFunctor< \
paddle::operators::math::ColFormat::kCFO, \
phi::CPUContext, \
float> \
im2col
phi::funcs:: \
Im2ColFunctor<phi::funcs::ColFormat::kCFO, phi::CPUContext, float> \
im2col
void testIm2colCPU(int ic, int ih, int iw, int fh, int fw, int ph, int pw) {
PREPARE_IM2COL_CPU;
im2col(context, input, dilation, stride, padding, &out);
paddle::operators::math::im2col_common<float>(
input, dilation, stride, padding, &ref);
phi::funcs::im2col_common<float>(input, dilation, stride, padding, &ref);
float* ref_data = ref.data<float>();
float* out_data = out.data<float>();
......@@ -398,8 +372,7 @@ void benchIm2col(int ic, int ih, int iw, int fh, int fw, int ph, int pw) {
auto t2 = GetCurrentMs();
for (int i = 0; i < repeat; ++i) {
paddle::operators::math::im2col_common<float>(
input, dilation, stride, padding, &ref);
phi::funcs::im2col_common<float>(input, dilation, stride, padding, &ref);
}
auto t3 = GetCurrentMs();
......
......@@ -17,6 +17,7 @@ math_library(segment_pooling)
math_library(sequence2batch)
math_library(matrix_solve DEPS dense_tensor eigen3 blas math_function)
math_library(cross_entropy)
math_library(im2col)
math_library(vol2col)
cc_library(
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -12,17 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/im2col_cfo_cpu.h"
#include "paddle/phi/kernels/funcs/im2col.h"
#include "paddle/phi/kernels/funcs/im2col_cfo_cpu.h"
namespace phi {
class CPUContext;
} // namespace phi
namespace paddle {
namespace operators {
namespace math {
namespace phi {
namespace funcs {
/*
* im = [input_channels, input_height, input_width]
......@@ -30,9 +28,7 @@ namespace math {
* [input_channels, filter_height, filter_width, output_height, output_width]
*/
template <class T, typename DeviceContext>
class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
DeviceContext,
T> {
class Im2ColFunctor<phi::funcs::ColFormat::kCFO, DeviceContext, T> {
public:
void operator()(const DeviceContext& context,
const phi::DenseTensor& im,
......@@ -43,13 +39,13 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
const DataLayout data_layout) {
PADDLE_ENFORCE_EQ(im.dims().size(),
3,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The dimension of tensor 'im' should be 3. But got "
"the dims of tensor 'im' is [%s].",
im.dims()));
PADDLE_ENFORCE_EQ(col->dims().size(),
5,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The dimension of tensor 'col' should be 5. But got "
"the dims of tensor 'col' is [%s].",
col->dims()));
......@@ -77,9 +73,7 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
* [input_channels, filter_height, filter_width, output_height, output_width]
*/
template <class T, typename DeviceContext>
class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
DeviceContext,
T> {
class Col2ImFunctor<phi::funcs::ColFormat::kCFO, DeviceContext, T> {
public:
void operator()(const DeviceContext& context,
const phi::DenseTensor& col,
......@@ -90,13 +84,13 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
const DataLayout data_layout) {
PADDLE_ENFORCE_EQ(im->dims().size(),
3,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The dimension of tensor 'im' should be 3. But got "
"the dims of tensor 'im' is [%s].",
im->dims()));
PADDLE_ENFORCE_EQ(col.dims().size(),
5,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The dimension of tensor 'col' should be 5. But got "
"the dims of tensor 'col' is [%s].",
col.dims()));
......@@ -111,22 +105,22 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
int col_height = col.dims()[3];
int col_width = col.dims()[4];
PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
((dilation[0] * (filter_height - 1) + 1))) /
stride[0] +
1,
col_height,
platform::errors::InvalidArgument(
"Output_height and padding(padding_up, "
"padding_down) are inconsistent."));
PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
((dilation[1] * (filter_width - 1) + 1))) /
stride[1] +
1,
col_width,
platform::errors::InvalidArgument(
"Output_height and padding(padding_up, "
"padding_down) are inconsistent."));
PADDLE_ENFORCE_EQ(
(im_height + padding[0] + padding[2] -
((dilation[0] * (filter_height - 1) + 1))) /
stride[0] +
1,
col_height,
phi::errors::InvalidArgument("Output_height and padding(padding_up, "
"padding_down) are inconsistent."));
PADDLE_ENFORCE_EQ(
(im_width + padding[1] + padding[3] -
((dilation[1] * (filter_width - 1) + 1))) /
stride[1] +
1,
col_width,
phi::errors::InvalidArgument("Output_height and padding(padding_up, "
"padding_down) are inconsistent."));
int channels_col = im_channels * filter_height * filter_width;
......@@ -160,16 +154,16 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
}
};
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
template class Im2ColFunctor<phi::funcs::ColFormat::kCFO,
phi::CPUContext,
float>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
template class Im2ColFunctor<phi::funcs::ColFormat::kCFO,
phi::CPUContext,
double>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
template class Col2ImFunctor<phi::funcs::ColFormat::kCFO,
phi::CPUContext,
float>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
template class Col2ImFunctor<phi::funcs::ColFormat::kCFO,
phi::CPUContext,
double>;
......@@ -179,9 +173,7 @@ template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
* [output_height, output_width, input_channels, filter_height, filter_width]
*/
template <class T, typename DeviceContext>
class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
DeviceContext,
T> {
class Im2ColFunctor<phi::funcs::ColFormat::kOCF, DeviceContext, T> {
public:
void operator()(const DeviceContext& context,
const phi::DenseTensor& im,
......@@ -192,13 +184,13 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
const DataLayout data_layout) {
PADDLE_ENFORCE_EQ(im.dims().size(),
3,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The dimension of tensor 'im' should be 3. But got "
"the dims of tensor 'im' is [%s].",
im.dims()));
PADDLE_ENFORCE_EQ(col->dims().size(),
5,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The dimension of tensor 'col' should be 5. But got "
"the dims of tensor 'col' is [%s].",
col->dims()));
......@@ -254,9 +246,7 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
* [output_height, output_width, input_channels, filter_height, filter_width]
*/
template <class T, typename DeviceContext>
class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
DeviceContext,
T> {
class Col2ImFunctor<phi::funcs::ColFormat::kOCF, DeviceContext, T> {
public:
void operator()(const DeviceContext& context,
const phi::DenseTensor& col,
......@@ -267,13 +257,13 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
const DataLayout data_layout) {
PADDLE_ENFORCE_EQ(im->dims().size(),
3,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The dimension of tensor 'im' should be 3. But got "
"the dims of tensor 'im' is [%s].",
im->dims()));
PADDLE_ENFORCE_EQ(col.dims().size(),
5,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The dimension of tensor 'col' should be 5. But got "
"the dims of tensor 'col' is [%s].",
col.dims()));
......@@ -288,14 +278,14 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
PADDLE_ENFORCE_EQ(
(im_height + padding[0] + padding[2] - filter_height) / stride[0] + 1,
col_height,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"Output_height and padding(padding_up, padding_down) "
"are inconsistent."));
PADDLE_ENFORCE_EQ(
(im_width + padding[1] + padding[3] - filter_width) / stride[1] + 1,
col_width,
platform::errors::InvalidArgument("col_width and padding(padding_left, "
"padding_right) are inconsistent."));
phi::errors::InvalidArgument("col_width and padding(padding_left, "
"padding_right) are inconsistent."));
T* im_data = im->data<T>();
const T* col_data = col.data<T>();
......@@ -335,18 +325,17 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
}
};
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
template class Im2ColFunctor<phi::funcs::ColFormat::kOCF,
phi::CPUContext,
float>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
template class Im2ColFunctor<phi::funcs::ColFormat::kOCF,
phi::CPUContext,
double>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
template class Col2ImFunctor<phi::funcs::ColFormat::kOCF,
phi::CPUContext,
float>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
template class Col2ImFunctor<phi::funcs::ColFormat::kOCF,
phi::CPUContext,
double>;
} // namespace math
} // namespace operators
} // namespace paddle
} // namespace funcs
} // namespace phi
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -15,14 +15,13 @@ limitations under the License. */
#include <algorithm>
#include <vector>
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/kernels/funcs/im2col.h"
namespace paddle {
namespace operators {
namespace math {
namespace phi {
namespace funcs {
template <class T>
__global__ void im2col(const T* data_im,
......@@ -86,9 +85,7 @@ __global__ void im2col(const T* data_im,
* [input_channels, filter_height, filter_width, output_height, output_width]
*/
template <class DeviceContext, class T>
class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
DeviceContext,
T> {
class Im2ColFunctor<phi::funcs::ColFormat::kCFO, DeviceContext, T> {
public:
void operator()(const DeviceContext& context,
const phi::DenseTensor& im,
......@@ -99,13 +96,13 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
const DataLayout data_layout) {
PADDLE_ENFORCE_EQ(im.dims().size(),
3,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The dimension of tensor 'im' should be 3. But got "
"the dims of tensor 'im' is [%s].",
im.dims()));
PADDLE_ENFORCE_EQ(col->dims().size(),
5,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The dimension of tensor 'col' should be 5. But got "
"the dims of tensor 'col' is [%s].",
col->dims()));
......@@ -124,7 +121,7 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
int num_outputs = im_channels * col_height * col_width;
int num_thread = 1024;
#ifdef WITH_NV_JETSON
platform::ChangeThreadNum(context, &num_thread);
phi::backends::gpu::ChangeThreadNum(context, &num_thread);
#endif
int blocks = (num_outputs + num_thread - 1) / num_thread;
int block_x = 512;
......@@ -223,9 +220,7 @@ __global__ void col2im(int n,
* [input_channels, filter_height, filter_width, output_height, output_width]
*/
template <class DeviceContext, class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
DeviceContext,
T> {
class Col2ImFunctor<phi::funcs::ColFormat::kCFO, DeviceContext, T> {
public:
void operator()(const DeviceContext& context,
const phi::DenseTensor& col,
......@@ -236,13 +231,13 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
const DataLayout data_layout) {
PADDLE_ENFORCE_EQ(im->dims().size(),
3,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The dimension of tensor 'im' should be 3. But got "
"the dims of tensor 'im' is [%s].",
im->dims()));
PADDLE_ENFORCE_EQ(col.dims().size(),
5,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The dimension of tensor 'col' should be 5. But got "
"the dims of tensor 'col' is [%s].",
col.dims()));
......@@ -258,28 +253,28 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
int col_height = col.dims()[3];
int col_width = col.dims()[4];
PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
(dilation[0] * (filter_height - 1) + 1)) /
stride[0] +
1,
col_height,
platform::errors::InvalidArgument(
"Output_height and padding(padding_up, "
"padding_down) are inconsistent."));
PADDLE_ENFORCE_EQ(
(im_height + padding[0] + padding[2] -
(dilation[0] * (filter_height - 1) + 1)) /
stride[0] +
1,
col_height,
phi::errors::InvalidArgument("Output_height and padding(padding_up, "
"padding_down) are inconsistent."));
PADDLE_ENFORCE_EQ(
(im_width + padding[1] + padding[3] -
(dilation[1] * (filter_width - 1) + 1)) /
stride[1] +
1,
col_width,
platform::errors::InvalidArgument("col_width and padding(padding_left, "
"padding_right) are inconsistent."));
phi::errors::InvalidArgument("col_width and padding(padding_left, "
"padding_right) are inconsistent."));
size_t num_kernels = im_channels * im_height * im_width;
int num_thread = 1024;
#ifdef WITH_NV_JETSON
platform::ChangeThreadNum(context, &num_thread);
phi::backends::gpu::ChangeThreadNum(context, &num_thread);
#endif
size_t blocks = (num_kernels + num_thread - 1) / num_thread;
size_t block_x = 512;
......@@ -308,16 +303,16 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
}
};
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
template class Im2ColFunctor<phi::funcs::ColFormat::kCFO,
phi::GPUContext,
float>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
template class Im2ColFunctor<phi::funcs::ColFormat::kCFO,
phi::GPUContext,
double>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
template class Col2ImFunctor<phi::funcs::ColFormat::kCFO,
phi::GPUContext,
float>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
template class Col2ImFunctor<phi::funcs::ColFormat::kCFO,
phi::GPUContext,
double>;
......@@ -367,9 +362,7 @@ __global__ void im2colOCF(const T* im_data,
* [output_height, output_width, input_channels, filter_height, filter_width]
*/
template <class DeviceContext, class T>
class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
DeviceContext,
T> {
class Im2ColFunctor<phi::funcs::ColFormat::kOCF, DeviceContext, T> {
public:
void operator()(const DeviceContext& context,
const phi::DenseTensor& im,
......@@ -380,13 +373,13 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
const DataLayout data_layout) {
PADDLE_ENFORCE_EQ(im.dims().size(),
3,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The dimension of tensor 'im' should be 3. But got "
"the dims of tensor 'im' is [%s].",
im.dims()));
PADDLE_ENFORCE_EQ(col->dims().size(),
5,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The dimension of tensor 'col' should be 5. But got "
"the dims of tensor 'col' is [%s].",
col->dims()));
......@@ -479,9 +472,7 @@ __global__ void col2imOCF(const T* col_data,
* [output_height, output_width, input_channels, filter_height, filter_width]
*/
template <class DeviceContext, class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
DeviceContext,
T> {
class Col2ImFunctor<phi::funcs::ColFormat::kOCF, DeviceContext, T> {
public:
void operator()(const DeviceContext& context,
const phi::DenseTensor& col,
......@@ -492,13 +483,13 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
const DataLayout data_layout) {
PADDLE_ENFORCE_EQ(im->dims().size(),
3,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The dimension of tensor 'im' should be 3. But got "
"the dims of tensor 'im' is [%s].",
im->dims()));
PADDLE_ENFORCE_EQ(col.dims().size(),
5,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The dimension of tensor 'col' should be 5. But got "
"the dims of tensor 'col' is [%s].",
col.dims()));
......@@ -511,22 +502,22 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
int col_height = col.dims()[0];
int col_width = col.dims()[1];
PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
(dilation[0] * (filter_height - 1) + 1)) /
stride[0] +
1,
col_height,
platform::errors::InvalidArgument(
"Output_height and padding(padding_up, "
"padding_down) are inconsistent."));
PADDLE_ENFORCE_EQ(
(im_height + padding[0] + padding[2] -
(dilation[0] * (filter_height - 1) + 1)) /
stride[0] +
1,
col_height,
phi::errors::InvalidArgument("Output_height and padding(padding_up, "
"padding_down) are inconsistent."));
PADDLE_ENFORCE_EQ(
(im_width + padding[1] + padding[3] -
(dilation[1] * (filter_width - 1) + 1)) /
stride[1] +
1,
col_width,
platform::errors::InvalidArgument("col_width and padding(padding_left, "
"padding_right) are inconsistent."));
phi::errors::InvalidArgument("col_width and padding(padding_left, "
"padding_right) are inconsistent."));
int block_dim_x = 0;
int block_dim_y = 0;
......@@ -563,20 +554,19 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
}
};
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
template class Im2ColFunctor<phi::funcs::ColFormat::kOCF,
phi::GPUContext,
float>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
template class Im2ColFunctor<phi::funcs::ColFormat::kOCF,
phi::GPUContext,
double>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
template class Col2ImFunctor<phi::funcs::ColFormat::kOCF,
phi::GPUContext,
float>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
template class Col2ImFunctor<phi::funcs::ColFormat::kOCF,
phi::GPUContext,
double>;
} // namespace math
} // namespace operators
} // namespace paddle
} // namespace funcs
} // namespace phi
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -16,13 +16,13 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
namespace paddle {
namespace operators {
namespace math {
namespace phi {
namespace funcs {
using DataLayout = phi::DataLayout;
......@@ -107,6 +107,5 @@ class Col2ImFunctor {
const DataLayout data_layout = DataLayout::kNCHW);
};
} // namespace math
} // namespace operators
} // namespace paddle
} // namespace funcs
} // namespace phi
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -16,11 +16,10 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/phi/core/dense_tensor.h"
namespace paddle {
namespace operators {
namespace math {
namespace phi {
namespace funcs {
/**
* The most common im2col algorithm.
......@@ -317,6 +316,5 @@ inline void im2col_sh1sw1dh1dw1ph1pw1(const phi::DenseTensor& im,
}
}
} // namespace math
} // namespace operators
} // namespace paddle
} // namespace funcs
} // namespace phi
......@@ -14,10 +14,10 @@
#pragma once
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
#include "paddle/phi/kernels/funcs/batch_norm_utils.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/im2col.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/vol2col.h"
......@@ -147,10 +147,8 @@ void ConvGradKernel(const Context& dev_ctx,
if (is_expand) {
set_zero(dev_ctx, &transformed_input_grad, static_cast<T>(0));
}
phi::funcs::Col2ImFunctor<phi::funcs::ColFormat::kCFO, Context, T> col2im;
phi::funcs::Col2VolFunctor<Context, T> col2vol;
paddle::operators::math::
Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
col2im;
for (int i = 0; i < batch_size; i++) {
DenseTensor out_grad_batch =
......@@ -203,9 +201,7 @@ void ConvGradKernel(const Context& dev_ctx,
Tensor filter_grad_ = *filter_grad;
filter_grad_.Resize(filter_matrix_shape);
set_zero(dev_ctx, filter_grad, static_cast<T>(0));
paddle::operators::math::
Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
im2col;
phi::funcs::Im2ColFunctor<phi::funcs::ColFormat::kCFO, Context, T> im2col;
phi::funcs::Vol2ColFunctor<Context, T> vol2col;
for (int i = 0; i < batch_size; i++) {
DenseTensor out_grad_batch =
......@@ -381,10 +377,8 @@ void ConvGradGradKernel(const Context& dev_ctx,
if (is_expand) {
set_zero(dev_ctx, &transformed_dX, static_cast<T>(0));
}
phi::funcs::Col2ImFunctor<phi::funcs::ColFormat::kCFO, Context, T> col2im;
phi::funcs::Col2VolFunctor<Context, T> col2vol;
paddle::operators::math::
Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
col2im;
for (int i = 0; i < batch_size; i++) {
DenseTensor dy_batch =
......@@ -428,9 +422,7 @@ void ConvGradGradKernel(const Context& dev_ctx,
set_zero(dev_ctx, dW, static_cast<T>(0));
DenseTensor dW_arr = *dW;
dW_arr.Resize(filter_matrix_shape);
paddle::operators::math::
Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
im2col;
phi::funcs::Im2ColFunctor<phi::funcs::ColFormat::kCFO, Context, T> im2col;
phi::funcs::Vol2ColFunctor<Context, T> vol2col;
for (int i = 0; i < batch_size; ++i) {
DenseTensor dy_batch =
......@@ -477,9 +469,7 @@ void ConvGradGradKernel(const Context& dev_ctx,
}
set_zero(dev_ctx, &transformed_ddY, static_cast<T>(0));
paddle::operators::math::
Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
im2col;
phi::funcs::Im2ColFunctor<phi::funcs::ColFormat::kCFO, Context, T> im2col;
phi::funcs::Vol2ColFunctor<Context, T> vol2col;
for (int i = 0; i < batch_size; ++i) {
DenseTensor ddy_batch =
......
......@@ -14,11 +14,11 @@
#pragma once
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/phi/kernels/conv_kernel.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
#include "paddle/phi/kernels/funcs/batch_norm_utils.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/im2col.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/vol2col.h"
......@@ -133,10 +133,8 @@ void ConvKernelImpl(const Context& dev_ctx,
int in_step = static_cast<int>(transformed_input.dims()[1]) / groups;
int out_step = static_cast<int>(transformed_output.dims()[1]) / groups;
phi::funcs::Im2ColFunctor<phi::funcs::ColFormat::kCFO, Context, T> im2col;
phi::funcs::Vol2ColFunctor<Context, T> vol2col;
paddle::operators::math::
Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
im2col;
auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
for (int i = 0; i < batch_size; i++) {
......
......@@ -14,13 +14,13 @@
#pragma once
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/conv_transpose_grad_kernel.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
#include "paddle/phi/kernels/funcs/im2col.h"
#include "paddle/phi/kernels/funcs/slice.h"
#include "paddle/phi/kernels/funcs/vol2col.h"
......@@ -143,9 +143,7 @@ void ConvTransposeGradRawKernel(const Context& ctx,
DenseTensor dfilter_;
funcs::SetConstant<Context, T> set_zero;
paddle::operators::math::
Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
im2col;
phi::funcs::Im2ColFunctor<phi::funcs::ColFormat::kCFO, Context, T> im2col;
phi::funcs::Vol2ColFunctor<Context, T> vol2col;
funcs::ConcatFunctor<Context, T> concat_functor;
......
......@@ -14,13 +14,13 @@
#pragma once
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/conv_transpose_kernel.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
#include "paddle/phi/kernels/funcs/im2col.h"
#include "paddle/phi/kernels/funcs/slice.h"
#include "paddle/phi/kernels/funcs/vol2col.h"
......@@ -136,9 +136,7 @@ void ConvTransposeRawKernel(const Context& ctx,
(data_layout != DataLayout::kNHWC
? static_cast<int>(out_dims[1]) / groups
: static_cast<int>(out_dims[out_dims.size() - 1]) / groups);
paddle::operators::math::
Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
col2im;
phi::funcs::Col2ImFunctor<phi::funcs::ColFormat::kCFO, Context, T> col2im;
phi::funcs::Col2VolFunctor<Context, T> col2vol;
funcs::ConcatFunctor<Context, T> concat_functor;
......
......@@ -16,8 +16,8 @@
#include <vector>
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/im2col.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/unfold_functor.h"
......@@ -60,9 +60,7 @@ void FoldGradKernel(const Context& ctx,
output_height,
output_width});
paddle::operators::math::
Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
im2col;
phi::funcs::Im2ColFunctor<phi::funcs::ColFormat::kCFO, Context, T> im2col;
for (int i = 0; i < batch_size; i++) {
DenseTensor out_grad_batch = out_grad.Slice(i, i + 1).Resize(out_shape);
......
......@@ -16,9 +16,9 @@
#include <vector>
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/kernels/funcs/im2col.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/unfold_functor.h"
......@@ -36,9 +36,7 @@ void FoldKernel(const Context& ctx,
const int batch_size = static_cast<int>(x.dims()[0]);
ctx.template Alloc<T>(out);
paddle::operators::math::
Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
col2im;
phi::funcs::Col2ImFunctor<phi::funcs::ColFormat::kCFO, Context, T> col2im;
const auto& x_dims = x.dims();
int output_height = (output_sizes[0] + 2 * paddings[0] -
......
......@@ -16,8 +16,8 @@
#include <vector>
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/im2col.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/unfold_functor.h"
......@@ -56,9 +56,7 @@ void UnfoldGradKernel(const Context& ctx,
DDim out_matrix_shape = make_ddim(
{x_dims[1], kernel_sizes[0], kernel_sizes[1], out_height, out_width});
paddle::operators::math::
Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
col2im;
phi::funcs::Col2ImFunctor<phi::funcs::ColFormat::kCFO, Context, T> col2im;
phi::funcs::SetConstant<Context, T> set_zero;
set_zero(ctx, x_grad, static_cast<T>(0));
......
......@@ -16,8 +16,8 @@
#include <vector>
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/im2col.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/unfold_functor.h"
......@@ -34,9 +34,7 @@ void UnfoldKernel(const Context& ctx,
const int batch_size = static_cast<int>(x.dims()[0]);
ctx.template Alloc<T>(out);
paddle::operators::math::
Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
im2col;
phi::funcs::Im2ColFunctor<phi::funcs::ColFormat::kCFO, Context, T> im2col;
const auto& x_dims = x.dims();
int out_height = phi::funcs::CalcOutputSize(x_dims[2],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册