未验证 提交 aa36c6aa 编写于 作者: H huangjiyi 提交者: GitHub

[PHI decoupling] move vol2col from fluid to phi (#48175)

* move vol2col from fluid to phi

* update copyright year
上级 48d5c36b
...@@ -23,8 +23,8 @@ limitations under the License. */ ...@@ -23,8 +23,8 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/layout_utils.h" #include "paddle/fluid/operators/layout_utils.h"
#include "paddle/fluid/operators/math/im2col.h" #include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/vol2col.h"
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/vol2col.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -44,7 +44,6 @@ endif() ...@@ -44,7 +44,6 @@ endif()
math_library(matrix_bit_code) math_library(matrix_bit_code)
math_library(unpooling) math_library(unpooling)
math_library(vol2col)
math_library(prelu) math_library(prelu)
math_library(bert_encoder_functor) math_library(bert_encoder_functor)
math_library(tree2col DEPS math_function) math_library(tree2col DEPS math_function)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/math/vol2col.h" #include "paddle/phi/kernels/funcs/vol2col.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
...@@ -84,7 +85,7 @@ void testVol2col() { ...@@ -84,7 +85,7 @@ void testVol2col() {
output_width}, output_width},
*place); *place);
paddle::operators::math::Vol2ColFunctor<DeviceContext, float> vol2col; phi::funcs::Vol2ColFunctor<DeviceContext, float> vol2col;
vol2col(*context, input, dilations, strides, paddings, &output); vol2col(*context, input, dilations, strides, paddings, &output);
float vol_2_col[] = {0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 7, 8, 9, 10, 10, 11}; float vol_2_col[] = {0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 7, 8, 9, 10, 10, 11};
...@@ -110,7 +111,7 @@ void testVol2col() { ...@@ -110,7 +111,7 @@ void testVol2col() {
paddle::framework::TensorCopySync(input_tmp, *place, &input); paddle::framework::TensorCopySync(input_tmp, *place, &input);
} }
paddle::operators::math::Col2VolFunctor<DeviceContext, float> col2vol; phi::funcs::Col2VolFunctor<DeviceContext, float> col2vol;
col2vol(*context, output, dilations, strides, paddings, &input); col2vol(*context, output, dilations, strides, paddings, &input);
float* in_ptr; float* in_ptr;
...@@ -201,7 +202,7 @@ void testVol2col<phi::GPUContext, paddle::platform::CUDAPlace>() { ...@@ -201,7 +202,7 @@ void testVol2col<phi::GPUContext, paddle::platform::CUDAPlace>() {
output_width}, output_width},
*place); *place);
paddle::operators::math::Vol2ColFunctor<phi::GPUContext, float> vol2col; phi::funcs::Vol2ColFunctor<phi::GPUContext, float> vol2col;
vol2col(*context, input, dilations, strides, paddings, &output); vol2col(*context, input, dilations, strides, paddings, &output);
float vol_2_col[] = {0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 7, 8, 9, 10, 10, 11}; float vol_2_col[] = {0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 7, 8, 9, 10, 10, 11};
...@@ -227,7 +228,7 @@ void testVol2col<phi::GPUContext, paddle::platform::CUDAPlace>() { ...@@ -227,7 +228,7 @@ void testVol2col<phi::GPUContext, paddle::platform::CUDAPlace>() {
paddle::framework::TensorCopySync(input_tmp, *place, &input); paddle::framework::TensorCopySync(input_tmp, *place, &input);
} }
paddle::operators::math::Col2VolFunctor<phi::GPUContext, float> col2vol; phi::funcs::Col2VolFunctor<phi::GPUContext, float> col2vol;
col2vol(*context, output, dilations, strides, paddings, &input); col2vol(*context, output, dilations, strides, paddings, &input);
float* in_ptr; float* in_ptr;
......
...@@ -17,6 +17,7 @@ math_library(segment_pooling) ...@@ -17,6 +17,7 @@ math_library(segment_pooling)
math_library(sequence2batch) math_library(sequence2batch)
math_library(matrix_solve DEPS dense_tensor eigen3 blas math_function) math_library(matrix_solve DEPS dense_tensor eigen3 blas math_function)
math_library(cross_entropy) math_library(cross_entropy)
math_library(vol2col)
cc_library( cc_library(
phi_data_layout_transform phi_data_layout_transform
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -12,13 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,13 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/math/vol2col.h" #include "paddle/phi/kernels/funcs/vol2col.h"
#include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/cpu/cpu_context.h"
namespace paddle { namespace phi {
namespace operators { namespace funcs {
namespace math {
/* /*
* vol = [input_channels, input_depth, input_height, input_width] * vol = [input_channels, input_depth, input_height, input_width]
...@@ -38,13 +37,13 @@ class Vol2ColFunctor<phi::CPUContext, T> { ...@@ -38,13 +37,13 @@ class Vol2ColFunctor<phi::CPUContext, T> {
const DataLayout data_layout) const { const DataLayout data_layout) const {
PADDLE_ENFORCE_EQ(vol.dims().size(), PADDLE_ENFORCE_EQ(vol.dims().size(),
4, 4,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The dimension of vol should be 4, but received %d.", "The dimension of vol should be 4, but received %d.",
vol.dims().size())); vol.dims().size()));
PADDLE_ENFORCE_EQ(col->dims().size(), PADDLE_ENFORCE_EQ(col->dims().size(),
7, 7,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The dimension of col should be 7, but received %d.", "The dimension of col should be 7, but received %d.",
col->dims().size())); col->dims().size()));
...@@ -81,7 +80,7 @@ class Vol2ColFunctor<phi::CPUContext, T> { ...@@ -81,7 +80,7 @@ class Vol2ColFunctor<phi::CPUContext, T> {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
input_depth_tmp, input_depth_tmp,
output_depth, output_depth,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"input_depth(%d) and output_depth(%d) are mismatching.", "input_depth(%d) and output_depth(%d) are mismatching.",
input_depth_tmp, input_depth_tmp,
output_depth)); output_depth));
...@@ -92,7 +91,7 @@ class Vol2ColFunctor<phi::CPUContext, T> { ...@@ -92,7 +91,7 @@ class Vol2ColFunctor<phi::CPUContext, T> {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
input_height_tmp, input_height_tmp,
output_height, output_height,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"input_height(%d) and output_height(%d) are mismatching.", "input_height(%d) and output_height(%d) are mismatching.",
input_height_tmp, input_height_tmp,
output_height)); output_height));
...@@ -103,7 +102,7 @@ class Vol2ColFunctor<phi::CPUContext, T> { ...@@ -103,7 +102,7 @@ class Vol2ColFunctor<phi::CPUContext, T> {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
input_width_tmp, input_width_tmp,
output_width, output_width,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"input_width(%d) and output_width(%d) are mismatching.", "input_width(%d) and output_width(%d) are mismatching.",
input_width_tmp, input_width_tmp,
output_width)); output_width));
...@@ -164,13 +163,13 @@ class Col2VolFunctor<phi::CPUContext, T> { ...@@ -164,13 +163,13 @@ class Col2VolFunctor<phi::CPUContext, T> {
const DataLayout data_layout) const { const DataLayout data_layout) const {
PADDLE_ENFORCE_EQ(vol->dims().size(), PADDLE_ENFORCE_EQ(vol->dims().size(),
4, 4,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The dimension of vol should be 4, but received %d.", "The dimension of vol should be 4, but received %d.",
vol->dims().size())); vol->dims().size()));
PADDLE_ENFORCE_EQ(col.dims().size(), PADDLE_ENFORCE_EQ(col.dims().size(),
7, 7,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The dimension of col should be 7, but received %d.", "The dimension of col should be 7, but received %d.",
col.dims().size())); col.dims().size()));
...@@ -206,7 +205,7 @@ class Col2VolFunctor<phi::CPUContext, T> { ...@@ -206,7 +205,7 @@ class Col2VolFunctor<phi::CPUContext, T> {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
input_depth_tmp, input_depth_tmp,
output_depth, output_depth,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"input_depth(%d) and output_depth(%d) are mismatching.", "input_depth(%d) and output_depth(%d) are mismatching.",
input_depth_tmp, input_depth_tmp,
output_depth)); output_depth));
...@@ -217,7 +216,7 @@ class Col2VolFunctor<phi::CPUContext, T> { ...@@ -217,7 +216,7 @@ class Col2VolFunctor<phi::CPUContext, T> {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
input_height_tmp, input_height_tmp,
output_height, output_height,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"input_height(%d) and output_height(%d) are mismatching.", "input_height(%d) and output_height(%d) are mismatching.",
input_height_tmp, input_height_tmp,
output_height)); output_height));
...@@ -228,7 +227,7 @@ class Col2VolFunctor<phi::CPUContext, T> { ...@@ -228,7 +227,7 @@ class Col2VolFunctor<phi::CPUContext, T> {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
input_width_tmp, input_width_tmp,
output_width, output_width,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"input_width(%d) and output_width(%d) are mismatching.", "input_width(%d) and output_width(%d) are mismatching.",
input_width_tmp, input_width_tmp,
output_width)); output_width));
...@@ -278,6 +277,5 @@ template class Vol2ColFunctor<phi::CPUContext, double>; ...@@ -278,6 +277,5 @@ template class Vol2ColFunctor<phi::CPUContext, double>;
template class Col2VolFunctor<phi::CPUContext, float>; template class Col2VolFunctor<phi::CPUContext, float>;
template class Col2VolFunctor<phi::CPUContext, double>; template class Col2VolFunctor<phi::CPUContext, double>;
} // namespace math } // namespace funcs
} // namespace operators } // namespace phi
} // namespace paddle
...@@ -15,14 +15,13 @@ limitations under the License. */ ...@@ -15,14 +15,13 @@ limitations under the License. */
#include <algorithm> #include <algorithm>
#include <vector> #include <vector>
#include "paddle/fluid/operators/math/vol2col.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h" #include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/kernels/funcs/vol2col.h"
namespace paddle { namespace phi {
namespace operators { namespace funcs {
namespace math {
template <class T> template <class T>
__global__ void vol2col(int num_kernels, __global__ void vol2col(int num_kernels,
...@@ -112,12 +111,12 @@ void Vol2ColFunctor<DeviceContext, T>::operator()( ...@@ -112,12 +111,12 @@ void Vol2ColFunctor<DeviceContext, T>::operator()(
const DataLayout data_layout) const { const DataLayout data_layout) const {
PADDLE_ENFORCE_EQ(vol.dims().size(), PADDLE_ENFORCE_EQ(vol.dims().size(),
4, 4,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The dimension of vol should be 4, but received %d.", "The dimension of vol should be 4, but received %d.",
vol.dims().size())); vol.dims().size()));
PADDLE_ENFORCE_EQ(col->dims().size(), PADDLE_ENFORCE_EQ(col->dims().size(),
7, 7,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The dimension of col should be 7, but received %d.", "The dimension of col should be 7, but received %d.",
col->dims().size())); col->dims().size()));
...@@ -149,7 +148,7 @@ void Vol2ColFunctor<DeviceContext, T>::operator()( ...@@ -149,7 +148,7 @@ void Vol2ColFunctor<DeviceContext, T>::operator()(
1; 1;
PADDLE_ENFORCE_EQ(input_depth_tmp, PADDLE_ENFORCE_EQ(input_depth_tmp,
output_depth, output_depth,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"input_depth(%d) and output_depth(%d) are mismatching.", "input_depth(%d) and output_depth(%d) are mismatching.",
input_depth_tmp, input_depth_tmp,
output_depth)); output_depth));
...@@ -160,7 +159,7 @@ void Vol2ColFunctor<DeviceContext, T>::operator()( ...@@ -160,7 +159,7 @@ void Vol2ColFunctor<DeviceContext, T>::operator()(
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
input_height_tmp, input_height_tmp,
output_height, output_height,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"input_height(%d) and output_height(%d) are mismatching.", "input_height(%d) and output_height(%d) are mismatching.",
input_height_tmp, input_height_tmp,
output_height)); output_height));
...@@ -170,7 +169,7 @@ void Vol2ColFunctor<DeviceContext, T>::operator()( ...@@ -170,7 +169,7 @@ void Vol2ColFunctor<DeviceContext, T>::operator()(
1; 1;
PADDLE_ENFORCE_EQ(input_width_tmp, PADDLE_ENFORCE_EQ(input_width_tmp,
output_width, output_width,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"input_width(%d) and output_width(%d) are mismatching.", "input_width(%d) and output_width(%d) are mismatching.",
input_width_tmp, input_width_tmp,
output_width)); output_width));
...@@ -180,7 +179,7 @@ void Vol2ColFunctor<DeviceContext, T>::operator()( ...@@ -180,7 +179,7 @@ void Vol2ColFunctor<DeviceContext, T>::operator()(
int max_threads = 1024; int max_threads = 1024;
#ifdef WITH_NV_JETSON #ifdef WITH_NV_JETSON
platform::ChangeThreadNum(context, &max_threads); phi::backends::gpu::ChangeThreadNum(context, &max_threads);
#endif #endif
const int threads = max_threads; const int threads = max_threads;
...@@ -318,12 +317,12 @@ void Col2VolFunctor<DeviceContext, T>::operator()( ...@@ -318,12 +317,12 @@ void Col2VolFunctor<DeviceContext, T>::operator()(
const DataLayout data_layout) const { const DataLayout data_layout) const {
PADDLE_ENFORCE_EQ(vol->dims().size(), PADDLE_ENFORCE_EQ(vol->dims().size(),
4, 4,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The dimension of vol should be 4, but received %d.", "The dimension of vol should be 4, but received %d.",
vol->dims().size())); vol->dims().size()));
PADDLE_ENFORCE_EQ(col.dims().size(), PADDLE_ENFORCE_EQ(col.dims().size(),
7, 7,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The dimension of col should be 7, but received %d.", "The dimension of col should be 7, but received %d.",
col.dims().size())); col.dims().size()));
...@@ -356,7 +355,7 @@ void Col2VolFunctor<DeviceContext, T>::operator()( ...@@ -356,7 +355,7 @@ void Col2VolFunctor<DeviceContext, T>::operator()(
1; 1;
PADDLE_ENFORCE_EQ(input_depth_tmp, PADDLE_ENFORCE_EQ(input_depth_tmp,
output_depth, output_depth,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"input_depth(%d) and output_depth(%d) are mismatching.", "input_depth(%d) and output_depth(%d) are mismatching.",
input_depth_tmp, input_depth_tmp,
output_depth)); output_depth));
...@@ -367,7 +366,7 @@ void Col2VolFunctor<DeviceContext, T>::operator()( ...@@ -367,7 +366,7 @@ void Col2VolFunctor<DeviceContext, T>::operator()(
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
input_height_tmp, input_height_tmp,
output_height, output_height,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"input_height(%d) and output_height(%d) are mismatching.", "input_height(%d) and output_height(%d) are mismatching.",
input_height_tmp, input_height_tmp,
output_height)); output_height));
...@@ -377,7 +376,7 @@ void Col2VolFunctor<DeviceContext, T>::operator()( ...@@ -377,7 +376,7 @@ void Col2VolFunctor<DeviceContext, T>::operator()(
1; 1;
PADDLE_ENFORCE_EQ(input_width_tmp, PADDLE_ENFORCE_EQ(input_width_tmp,
output_width, output_width,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"input_width(%d) and output_width(%d) are mismatching.", "input_width(%d) and output_width(%d) are mismatching.",
input_width_tmp, input_width_tmp,
output_width)); output_width));
...@@ -386,7 +385,7 @@ void Col2VolFunctor<DeviceContext, T>::operator()( ...@@ -386,7 +385,7 @@ void Col2VolFunctor<DeviceContext, T>::operator()(
int max_threads = 1024; int max_threads = 1024;
#ifdef WITH_NV_JETSON #ifdef WITH_NV_JETSON
platform::ChangeThreadNum(context, &max_threads); phi::backends::gpu::ChangeThreadNum(context, &max_threads);
#endif #endif
const int threads = max_threads; const int threads = max_threads;
...@@ -423,6 +422,5 @@ template class Vol2ColFunctor<phi::GPUContext, double>; ...@@ -423,6 +422,5 @@ template class Vol2ColFunctor<phi::GPUContext, double>;
template class Col2VolFunctor<phi::GPUContext, float>; template class Col2VolFunctor<phi::GPUContext, float>;
template class Col2VolFunctor<phi::GPUContext, double>; template class Col2VolFunctor<phi::GPUContext, double>;
} // namespace math } // namespace funcs
} // namespace operators } // namespace phi
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -16,13 +16,11 @@ limitations under the License. */ ...@@ -16,13 +16,11 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/framework/tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/phi/core/errors.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle { namespace phi {
namespace operators { namespace funcs {
namespace math {
using DataLayout = phi::DataLayout; using DataLayout = phi::DataLayout;
...@@ -92,6 +90,5 @@ class Col2VolFunctor { ...@@ -92,6 +90,5 @@ class Col2VolFunctor {
const DataLayout data_layout = DataLayout::kNCHW) const; const DataLayout data_layout = DataLayout::kNCHW) const;
}; };
} // namespace math } // namespace funcs
} // namespace operators } // namespace phi
} // namespace paddle
...@@ -15,11 +15,11 @@ ...@@ -15,11 +15,11 @@
#pragma once #pragma once
#include "paddle/fluid/operators/math/im2col.h" #include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/vol2col.h"
#include "paddle/phi/kernels/cpu/conv_util.h" #include "paddle/phi/kernels/cpu/conv_util.h"
#include "paddle/phi/kernels/funcs/batch_norm_utils.h" #include "paddle/phi/kernels/funcs/batch_norm_utils.h"
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/vol2col.h"
namespace phi { namespace phi {
...@@ -147,7 +147,7 @@ void ConvGradKernel(const Context& dev_ctx, ...@@ -147,7 +147,7 @@ void ConvGradKernel(const Context& dev_ctx,
if (is_expand) { if (is_expand) {
set_zero(dev_ctx, &transformed_input_grad, static_cast<T>(0)); set_zero(dev_ctx, &transformed_input_grad, static_cast<T>(0));
} }
paddle::operators::math::Col2VolFunctor<Context, T> col2vol; phi::funcs::Col2VolFunctor<Context, T> col2vol;
paddle::operators::math:: paddle::operators::math::
Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, Context, T> Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
col2im; col2im;
...@@ -206,7 +206,7 @@ void ConvGradKernel(const Context& dev_ctx, ...@@ -206,7 +206,7 @@ void ConvGradKernel(const Context& dev_ctx,
paddle::operators::math:: paddle::operators::math::
Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, Context, T> Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
im2col; im2col;
paddle::operators::math::Vol2ColFunctor<Context, T> vol2col; phi::funcs::Vol2ColFunctor<Context, T> vol2col;
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
DenseTensor out_grad_batch = DenseTensor out_grad_batch =
transformed_output_grad.Slice(i, i + 1).Resize(output_matrix_shape); transformed_output_grad.Slice(i, i + 1).Resize(output_matrix_shape);
...@@ -381,7 +381,7 @@ void ConvGradGradKernel(const Context& dev_ctx, ...@@ -381,7 +381,7 @@ void ConvGradGradKernel(const Context& dev_ctx,
if (is_expand) { if (is_expand) {
set_zero(dev_ctx, &transformed_dX, static_cast<T>(0)); set_zero(dev_ctx, &transformed_dX, static_cast<T>(0));
} }
paddle::operators::math::Col2VolFunctor<Context, T> col2vol; phi::funcs::Col2VolFunctor<Context, T> col2vol;
paddle::operators::math:: paddle::operators::math::
Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, Context, T> Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
col2im; col2im;
...@@ -431,7 +431,7 @@ void ConvGradGradKernel(const Context& dev_ctx, ...@@ -431,7 +431,7 @@ void ConvGradGradKernel(const Context& dev_ctx,
paddle::operators::math:: paddle::operators::math::
Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, Context, T> Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
im2col; im2col;
paddle::operators::math::Vol2ColFunctor<Context, T> vol2col; phi::funcs::Vol2ColFunctor<Context, T> vol2col;
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
DenseTensor dy_batch = DenseTensor dy_batch =
transformed_dY.Slice(i, i + 1).Resize(output_matrix_shape); transformed_dY.Slice(i, i + 1).Resize(output_matrix_shape);
...@@ -480,7 +480,7 @@ void ConvGradGradKernel(const Context& dev_ctx, ...@@ -480,7 +480,7 @@ void ConvGradGradKernel(const Context& dev_ctx,
paddle::operators::math:: paddle::operators::math::
Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, Context, T> Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
im2col; im2col;
paddle::operators::math::Vol2ColFunctor<Context, T> vol2col; phi::funcs::Vol2ColFunctor<Context, T> vol2col;
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
DenseTensor ddy_batch = DenseTensor ddy_batch =
transformed_ddY.Slice(i, i + 1).Resize(output_matrix_shape); transformed_ddY.Slice(i, i + 1).Resize(output_matrix_shape);
......
...@@ -15,12 +15,12 @@ ...@@ -15,12 +15,12 @@
#pragma once #pragma once
#include "paddle/fluid/operators/math/im2col.h" #include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/vol2col.h"
#include "paddle/phi/kernels/conv_kernel.h" #include "paddle/phi/kernels/conv_kernel.h"
#include "paddle/phi/kernels/cpu/conv_util.h" #include "paddle/phi/kernels/cpu/conv_util.h"
#include "paddle/phi/kernels/funcs/batch_norm_utils.h" #include "paddle/phi/kernels/funcs/batch_norm_utils.h"
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/vol2col.h"
namespace phi { namespace phi {
...@@ -133,7 +133,7 @@ void ConvKernelImpl(const Context& dev_ctx, ...@@ -133,7 +133,7 @@ void ConvKernelImpl(const Context& dev_ctx,
int in_step = static_cast<int>(transformed_input.dims()[1]) / groups; int in_step = static_cast<int>(transformed_input.dims()[1]) / groups;
int out_step = static_cast<int>(transformed_output.dims()[1]) / groups; int out_step = static_cast<int>(transformed_output.dims()[1]) / groups;
paddle::operators::math::Vol2ColFunctor<Context, T> vol2col; phi::funcs::Vol2ColFunctor<Context, T> vol2col;
paddle::operators::math:: paddle::operators::math::
Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, Context, T> Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
im2col; im2col;
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
#pragma once #pragma once
#include "paddle/fluid/operators/math/im2col.h" #include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/vol2col.h"
#include "paddle/phi/common/layout.h" #include "paddle/phi/common/layout.h"
#include "paddle/phi/core/ddim.h" #include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/conv_transpose_grad_kernel.h" #include "paddle/phi/kernels/conv_transpose_grad_kernel.h"
...@@ -23,6 +22,7 @@ ...@@ -23,6 +22,7 @@
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" #include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
#include "paddle/phi/kernels/funcs/slice.h" #include "paddle/phi/kernels/funcs/slice.h"
#include "paddle/phi/kernels/funcs/vol2col.h"
namespace phi { namespace phi {
...@@ -146,7 +146,7 @@ void ConvTransposeGradRawKernel(const Context& ctx, ...@@ -146,7 +146,7 @@ void ConvTransposeGradRawKernel(const Context& ctx,
paddle::operators::math:: paddle::operators::math::
Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, Context, T> Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
im2col; im2col;
paddle::operators::math::Vol2ColFunctor<Context, T> vol2col; phi::funcs::Vol2ColFunctor<Context, T> vol2col;
funcs::ConcatFunctor<Context, T> concat_functor; funcs::ConcatFunctor<Context, T> concat_functor;
if (dx) { if (dx) {
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
#pragma once #pragma once
#include "paddle/fluid/operators/math/im2col.h" #include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/vol2col.h"
#include "paddle/phi/common/layout.h" #include "paddle/phi/common/layout.h"
#include "paddle/phi/core/ddim.h" #include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/conv_transpose_kernel.h" #include "paddle/phi/kernels/conv_transpose_kernel.h"
...@@ -23,6 +22,7 @@ ...@@ -23,6 +22,7 @@
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" #include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
#include "paddle/phi/kernels/funcs/slice.h" #include "paddle/phi/kernels/funcs/slice.h"
#include "paddle/phi/kernels/funcs/vol2col.h"
namespace phi { namespace phi {
...@@ -139,7 +139,7 @@ void ConvTransposeRawKernel(const Context& ctx, ...@@ -139,7 +139,7 @@ void ConvTransposeRawKernel(const Context& ctx,
paddle::operators::math:: paddle::operators::math::
Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, Context, T> Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
col2im; col2im;
paddle::operators::math::Col2VolFunctor<Context, T> col2vol; phi::funcs::Col2VolFunctor<Context, T> col2vol;
funcs::ConcatFunctor<Context, T> concat_functor; funcs::ConcatFunctor<Context, T> concat_functor;
// convolution transpose: gemm + col2im or col2vol (similar to conv-backward // convolution transpose: gemm + col2im or col2vol (similar to conv-backward
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册