未验证 提交 aa36c6aa 编写于 作者: H huangjiyi 提交者: GitHub

[PHI decoupling] move vol2col from fluid to phi (#48175)

* move vol2col from fluid to phi

* update copyright year
上级 48d5c36b
......@@ -23,8 +23,8 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/layout_utils.h"
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/vol2col.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/vol2col.h"
namespace paddle {
namespace operators {
......
......@@ -44,7 +44,6 @@ endif()
math_library(matrix_bit_code)
math_library(unpooling)
math_library(vol2col)
math_library(prelu)
math_library(bert_encoder_functor)
math_library(tree2col DEPS math_function)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/vol2col.h"
#include "paddle/phi/kernels/funcs/vol2col.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
......@@ -84,7 +85,7 @@ void testVol2col() {
output_width},
*place);
paddle::operators::math::Vol2ColFunctor<DeviceContext, float> vol2col;
phi::funcs::Vol2ColFunctor<DeviceContext, float> vol2col;
vol2col(*context, input, dilations, strides, paddings, &output);
float vol_2_col[] = {0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 7, 8, 9, 10, 10, 11};
......@@ -110,7 +111,7 @@ void testVol2col() {
paddle::framework::TensorCopySync(input_tmp, *place, &input);
}
paddle::operators::math::Col2VolFunctor<DeviceContext, float> col2vol;
phi::funcs::Col2VolFunctor<DeviceContext, float> col2vol;
col2vol(*context, output, dilations, strides, paddings, &input);
float* in_ptr;
......@@ -201,7 +202,7 @@ void testVol2col<phi::GPUContext, paddle::platform::CUDAPlace>() {
output_width},
*place);
paddle::operators::math::Vol2ColFunctor<phi::GPUContext, float> vol2col;
phi::funcs::Vol2ColFunctor<phi::GPUContext, float> vol2col;
vol2col(*context, input, dilations, strides, paddings, &output);
float vol_2_col[] = {0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 7, 8, 9, 10, 10, 11};
......@@ -227,7 +228,7 @@ void testVol2col<phi::GPUContext, paddle::platform::CUDAPlace>() {
paddle::framework::TensorCopySync(input_tmp, *place, &input);
}
paddle::operators::math::Col2VolFunctor<phi::GPUContext, float> col2vol;
phi::funcs::Col2VolFunctor<phi::GPUContext, float> col2vol;
col2vol(*context, output, dilations, strides, paddings, &input);
float* in_ptr;
......
......@@ -17,6 +17,7 @@ math_library(segment_pooling)
math_library(sequence2batch)
math_library(matrix_solve DEPS dense_tensor eigen3 blas math_function)
math_library(cross_entropy)
math_library(vol2col)
cc_library(
phi_data_layout_transform
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -12,13 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/vol2col.h"
#include "paddle/phi/kernels/funcs/vol2col.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
namespace paddle {
namespace operators {
namespace math {
namespace phi {
namespace funcs {
/*
* vol = [input_channels, input_depth, input_height, input_width]
......@@ -38,13 +37,13 @@ class Vol2ColFunctor<phi::CPUContext, T> {
const DataLayout data_layout) const {
PADDLE_ENFORCE_EQ(vol.dims().size(),
4,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The dimension of vol should be 4, but received %d.",
vol.dims().size()));
PADDLE_ENFORCE_EQ(col->dims().size(),
7,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The dimension of col should be 7, but received %d.",
col->dims().size()));
......@@ -81,7 +80,7 @@ class Vol2ColFunctor<phi::CPUContext, T> {
PADDLE_ENFORCE_EQ(
input_depth_tmp,
output_depth,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"input_depth(%d) and output_depth(%d) are mismatching.",
input_depth_tmp,
output_depth));
......@@ -92,7 +91,7 @@ class Vol2ColFunctor<phi::CPUContext, T> {
PADDLE_ENFORCE_EQ(
input_height_tmp,
output_height,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"input_height(%d) and output_height(%d) are mismatching.",
input_height_tmp,
output_height));
......@@ -103,7 +102,7 @@ class Vol2ColFunctor<phi::CPUContext, T> {
PADDLE_ENFORCE_EQ(
input_width_tmp,
output_width,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"input_width(%d) and output_width(%d) are mismatching.",
input_width_tmp,
output_width));
......@@ -164,13 +163,13 @@ class Col2VolFunctor<phi::CPUContext, T> {
const DataLayout data_layout) const {
PADDLE_ENFORCE_EQ(vol->dims().size(),
4,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The dimension of vol should be 4, but received %d.",
vol->dims().size()));
PADDLE_ENFORCE_EQ(col.dims().size(),
7,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The dimension of col should be 7, but received %d.",
col.dims().size()));
......@@ -206,7 +205,7 @@ class Col2VolFunctor<phi::CPUContext, T> {
PADDLE_ENFORCE_EQ(
input_depth_tmp,
output_depth,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"input_depth(%d) and output_depth(%d) are mismatching.",
input_depth_tmp,
output_depth));
......@@ -217,7 +216,7 @@ class Col2VolFunctor<phi::CPUContext, T> {
PADDLE_ENFORCE_EQ(
input_height_tmp,
output_height,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"input_height(%d) and output_height(%d) are mismatching.",
input_height_tmp,
output_height));
......@@ -228,7 +227,7 @@ class Col2VolFunctor<phi::CPUContext, T> {
PADDLE_ENFORCE_EQ(
input_width_tmp,
output_width,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"input_width(%d) and output_width(%d) are mismatching.",
input_width_tmp,
output_width));
......@@ -278,6 +277,5 @@ template class Vol2ColFunctor<phi::CPUContext, double>;
template class Col2VolFunctor<phi::CPUContext, float>;
template class Col2VolFunctor<phi::CPUContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
} // namespace funcs
} // namespace phi
......@@ -15,14 +15,13 @@ limitations under the License. */
#include <algorithm>
#include <vector>
#include "paddle/fluid/operators/math/vol2col.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/kernels/funcs/vol2col.h"
namespace paddle {
namespace operators {
namespace math {
namespace phi {
namespace funcs {
template <class T>
__global__ void vol2col(int num_kernels,
......@@ -112,12 +111,12 @@ void Vol2ColFunctor<DeviceContext, T>::operator()(
const DataLayout data_layout) const {
PADDLE_ENFORCE_EQ(vol.dims().size(),
4,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The dimension of vol should be 4, but received %d.",
vol.dims().size()));
PADDLE_ENFORCE_EQ(col->dims().size(),
7,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The dimension of col should be 7, but received %d.",
col->dims().size()));
......@@ -149,7 +148,7 @@ void Vol2ColFunctor<DeviceContext, T>::operator()(
1;
PADDLE_ENFORCE_EQ(input_depth_tmp,
output_depth,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"input_depth(%d) and output_depth(%d) are mismatching.",
input_depth_tmp,
output_depth));
......@@ -160,7 +159,7 @@ void Vol2ColFunctor<DeviceContext, T>::operator()(
PADDLE_ENFORCE_EQ(
input_height_tmp,
output_height,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"input_height(%d) and output_height(%d) are mismatching.",
input_height_tmp,
output_height));
......@@ -170,7 +169,7 @@ void Vol2ColFunctor<DeviceContext, T>::operator()(
1;
PADDLE_ENFORCE_EQ(input_width_tmp,
output_width,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"input_width(%d) and output_width(%d) are mismatching.",
input_width_tmp,
output_width));
......@@ -180,7 +179,7 @@ void Vol2ColFunctor<DeviceContext, T>::operator()(
int max_threads = 1024;
#ifdef WITH_NV_JETSON
platform::ChangeThreadNum(context, &max_threads);
phi::backends::gpu::ChangeThreadNum(context, &max_threads);
#endif
const int threads = max_threads;
......@@ -318,12 +317,12 @@ void Col2VolFunctor<DeviceContext, T>::operator()(
const DataLayout data_layout) const {
PADDLE_ENFORCE_EQ(vol->dims().size(),
4,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The dimension of vol should be 4, but received %d.",
vol->dims().size()));
PADDLE_ENFORCE_EQ(col.dims().size(),
7,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The dimension of col should be 7, but received %d.",
col.dims().size()));
......@@ -356,7 +355,7 @@ void Col2VolFunctor<DeviceContext, T>::operator()(
1;
PADDLE_ENFORCE_EQ(input_depth_tmp,
output_depth,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"input_depth(%d) and output_depth(%d) are mismatching.",
input_depth_tmp,
output_depth));
......@@ -367,7 +366,7 @@ void Col2VolFunctor<DeviceContext, T>::operator()(
PADDLE_ENFORCE_EQ(
input_height_tmp,
output_height,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"input_height(%d) and output_height(%d) are mismatching.",
input_height_tmp,
output_height));
......@@ -377,7 +376,7 @@ void Col2VolFunctor<DeviceContext, T>::operator()(
1;
PADDLE_ENFORCE_EQ(input_width_tmp,
output_width,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"input_width(%d) and output_width(%d) are mismatching.",
input_width_tmp,
output_width));
......@@ -386,7 +385,7 @@ void Col2VolFunctor<DeviceContext, T>::operator()(
int max_threads = 1024;
#ifdef WITH_NV_JETSON
platform::ChangeThreadNum(context, &max_threads);
phi::backends::gpu::ChangeThreadNum(context, &max_threads);
#endif
const int threads = max_threads;
......@@ -423,6 +422,5 @@ template class Vol2ColFunctor<phi::GPUContext, double>;
template class Col2VolFunctor<phi::GPUContext, float>;
template class Col2VolFunctor<phi::GPUContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
} // namespace funcs
} // namespace phi
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -16,13 +16,11 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/errors.h"
namespace paddle {
namespace operators {
namespace math {
namespace phi {
namespace funcs {
using DataLayout = phi::DataLayout;
......@@ -92,6 +90,5 @@ class Col2VolFunctor {
const DataLayout data_layout = DataLayout::kNCHW) const;
};
} // namespace math
} // namespace operators
} // namespace paddle
} // namespace funcs
} // namespace phi
......@@ -15,11 +15,11 @@
#pragma once
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/vol2col.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
#include "paddle/phi/kernels/funcs/batch_norm_utils.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/vol2col.h"
namespace phi {
......@@ -147,7 +147,7 @@ void ConvGradKernel(const Context& dev_ctx,
if (is_expand) {
set_zero(dev_ctx, &transformed_input_grad, static_cast<T>(0));
}
paddle::operators::math::Col2VolFunctor<Context, T> col2vol;
phi::funcs::Col2VolFunctor<Context, T> col2vol;
paddle::operators::math::
Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
col2im;
......@@ -206,7 +206,7 @@ void ConvGradKernel(const Context& dev_ctx,
paddle::operators::math::
Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
im2col;
paddle::operators::math::Vol2ColFunctor<Context, T> vol2col;
phi::funcs::Vol2ColFunctor<Context, T> vol2col;
for (int i = 0; i < batch_size; i++) {
DenseTensor out_grad_batch =
transformed_output_grad.Slice(i, i + 1).Resize(output_matrix_shape);
......@@ -381,7 +381,7 @@ void ConvGradGradKernel(const Context& dev_ctx,
if (is_expand) {
set_zero(dev_ctx, &transformed_dX, static_cast<T>(0));
}
paddle::operators::math::Col2VolFunctor<Context, T> col2vol;
phi::funcs::Col2VolFunctor<Context, T> col2vol;
paddle::operators::math::
Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
col2im;
......@@ -431,7 +431,7 @@ void ConvGradGradKernel(const Context& dev_ctx,
paddle::operators::math::
Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
im2col;
paddle::operators::math::Vol2ColFunctor<Context, T> vol2col;
phi::funcs::Vol2ColFunctor<Context, T> vol2col;
for (int i = 0; i < batch_size; ++i) {
DenseTensor dy_batch =
transformed_dY.Slice(i, i + 1).Resize(output_matrix_shape);
......@@ -480,7 +480,7 @@ void ConvGradGradKernel(const Context& dev_ctx,
paddle::operators::math::
Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
im2col;
paddle::operators::math::Vol2ColFunctor<Context, T> vol2col;
phi::funcs::Vol2ColFunctor<Context, T> vol2col;
for (int i = 0; i < batch_size; ++i) {
DenseTensor ddy_batch =
transformed_ddY.Slice(i, i + 1).Resize(output_matrix_shape);
......
......@@ -15,12 +15,12 @@
#pragma once
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/vol2col.h"
#include "paddle/phi/kernels/conv_kernel.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
#include "paddle/phi/kernels/funcs/batch_norm_utils.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/vol2col.h"
namespace phi {
......@@ -133,7 +133,7 @@ void ConvKernelImpl(const Context& dev_ctx,
int in_step = static_cast<int>(transformed_input.dims()[1]) / groups;
int out_step = static_cast<int>(transformed_output.dims()[1]) / groups;
paddle::operators::math::Vol2ColFunctor<Context, T> vol2col;
phi::funcs::Vol2ColFunctor<Context, T> vol2col;
paddle::operators::math::
Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
im2col;
......
......@@ -15,7 +15,6 @@
#pragma once
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/vol2col.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/conv_transpose_grad_kernel.h"
......@@ -23,6 +22,7 @@
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
#include "paddle/phi/kernels/funcs/slice.h"
#include "paddle/phi/kernels/funcs/vol2col.h"
namespace phi {
......@@ -146,7 +146,7 @@ void ConvTransposeGradRawKernel(const Context& ctx,
paddle::operators::math::
Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
im2col;
paddle::operators::math::Vol2ColFunctor<Context, T> vol2col;
phi::funcs::Vol2ColFunctor<Context, T> vol2col;
funcs::ConcatFunctor<Context, T> concat_functor;
if (dx) {
......
......@@ -15,7 +15,6 @@
#pragma once
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/vol2col.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/conv_transpose_kernel.h"
......@@ -23,6 +22,7 @@
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
#include "paddle/phi/kernels/funcs/slice.h"
#include "paddle/phi/kernels/funcs/vol2col.h"
namespace phi {
......@@ -139,7 +139,7 @@ void ConvTransposeRawKernel(const Context& ctx,
paddle::operators::math::
Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
col2im;
paddle::operators::math::Col2VolFunctor<Context, T> col2vol;
phi::funcs::Col2VolFunctor<Context, T> col2vol;
funcs::ConcatFunctor<Context, T> concat_functor;
// convolution transpose: gemm + col2im or col2vol (similar to conv-backward
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册