Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
aa36c6aa
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
aa36c6aa
编写于
11月 22, 2022
作者:
H
huangjiyi
提交者:
GitHub
11月 22, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[PHI decoupling] move vol2col from fluid to phi (#48175)
* move vol2col from fluid to phi * update copyright year
上级
48d5c36b
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
62 addition
and
68 deletion
+62
-68
paddle/fluid/operators/conv_op.h
paddle/fluid/operators/conv_op.h
+1
-1
paddle/fluid/operators/math/CMakeLists.txt
paddle/fluid/operators/math/CMakeLists.txt
+0
-1
paddle/fluid/operators/math/vol2col_test.cc
paddle/fluid/operators/math/vol2col_test.cc
+7
-6
paddle/phi/kernels/funcs/CMakeLists.txt
paddle/phi/kernels/funcs/CMakeLists.txt
+1
-0
paddle/phi/kernels/funcs/vol2col.cc
paddle/phi/kernels/funcs/vol2col.cc
+16
-18
paddle/phi/kernels/funcs/vol2col.cu
paddle/phi/kernels/funcs/vol2col.cu
+18
-20
paddle/phi/kernels/funcs/vol2col.h
paddle/phi/kernels/funcs/vol2col.h
+7
-10
paddle/phi/kernels/impl/conv_grad_kernel_impl.h
paddle/phi/kernels/impl/conv_grad_kernel_impl.h
+6
-6
paddle/phi/kernels/impl/conv_kernel_impl.h
paddle/phi/kernels/impl/conv_kernel_impl.h
+2
-2
paddle/phi/kernels/impl/conv_transpose_grad_kernel_impl.h
paddle/phi/kernels/impl/conv_transpose_grad_kernel_impl.h
+2
-2
paddle/phi/kernels/impl/conv_transpose_kernel_impl.h
paddle/phi/kernels/impl/conv_transpose_kernel_impl.h
+2
-2
未找到文件。
paddle/fluid/operators/conv_op.h
浏览文件 @
aa36c6aa
...
...
@@ -23,8 +23,8 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/layout_utils.h"
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/vol2col.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/vol2col.h"
namespace
paddle
{
namespace
operators
{
...
...
paddle/fluid/operators/math/CMakeLists.txt
浏览文件 @
aa36c6aa
...
...
@@ -44,7 +44,6 @@ endif()
math_library
(
matrix_bit_code
)
math_library
(
unpooling
)
math_library
(
vol2col
)
math_library
(
prelu
)
math_library
(
bert_encoder_functor
)
math_library
(
tree2col DEPS math_function
)
...
...
paddle/fluid/operators/math/vol2col_test.cc
浏览文件 @
aa36c6aa
/* Copyright (c) 20
16
PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 20
22
PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
...
...
@@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/
fluid/operators/math
/vol2col.h"
#include "paddle/
phi/kernels/funcs
/vol2col.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
...
...
@@ -84,7 +85,7 @@ void testVol2col() {
output_width
},
*
place
);
p
addle
::
operators
::
math
::
Vol2ColFunctor
<
DeviceContext
,
float
>
vol2col
;
p
hi
::
funcs
::
Vol2ColFunctor
<
DeviceContext
,
float
>
vol2col
;
vol2col
(
*
context
,
input
,
dilations
,
strides
,
paddings
,
&
output
);
float
vol_2_col
[]
=
{
0
,
1
,
1
,
2
,
3
,
4
,
4
,
5
,
6
,
7
,
7
,
8
,
9
,
10
,
10
,
11
};
...
...
@@ -110,7 +111,7 @@ void testVol2col() {
paddle
::
framework
::
TensorCopySync
(
input_tmp
,
*
place
,
&
input
);
}
p
addle
::
operators
::
math
::
Col2VolFunctor
<
DeviceContext
,
float
>
col2vol
;
p
hi
::
funcs
::
Col2VolFunctor
<
DeviceContext
,
float
>
col2vol
;
col2vol
(
*
context
,
output
,
dilations
,
strides
,
paddings
,
&
input
);
float
*
in_ptr
;
...
...
@@ -201,7 +202,7 @@ void testVol2col<phi::GPUContext, paddle::platform::CUDAPlace>() {
output_width
},
*
place
);
p
addle
::
operators
::
math
::
Vol2ColFunctor
<
phi
::
GPUContext
,
float
>
vol2col
;
p
hi
::
funcs
::
Vol2ColFunctor
<
phi
::
GPUContext
,
float
>
vol2col
;
vol2col
(
*
context
,
input
,
dilations
,
strides
,
paddings
,
&
output
);
float
vol_2_col
[]
=
{
0
,
1
,
1
,
2
,
3
,
4
,
4
,
5
,
6
,
7
,
7
,
8
,
9
,
10
,
10
,
11
};
...
...
@@ -227,7 +228,7 @@ void testVol2col<phi::GPUContext, paddle::platform::CUDAPlace>() {
paddle
::
framework
::
TensorCopySync
(
input_tmp
,
*
place
,
&
input
);
}
p
addle
::
operators
::
math
::
Col2VolFunctor
<
phi
::
GPUContext
,
float
>
col2vol
;
p
hi
::
funcs
::
Col2VolFunctor
<
phi
::
GPUContext
,
float
>
col2vol
;
col2vol
(
*
context
,
output
,
dilations
,
strides
,
paddings
,
&
input
);
float
*
in_ptr
;
...
...
paddle/phi/kernels/funcs/CMakeLists.txt
浏览文件 @
aa36c6aa
...
...
@@ -17,6 +17,7 @@ math_library(segment_pooling)
math_library
(
sequence2batch
)
math_library
(
matrix_solve DEPS dense_tensor eigen3 blas math_function
)
math_library
(
cross_entropy
)
math_library
(
vol2col
)
cc_library
(
phi_data_layout_transform
...
...
paddle/
fluid/operators/math
/vol2col.cc
→
paddle/
phi/kernels/funcs
/vol2col.cc
浏览文件 @
aa36c6aa
/* Copyright (c) 20
16
PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 20
22
PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
...
...
@@ -12,13 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/
fluid/operators/math
/vol2col.h"
#include "paddle/
phi/kernels/funcs
/vol2col.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
phi
{
namespace
funcs
{
/*
* vol = [input_channels, input_depth, input_height, input_width]
...
...
@@ -38,13 +37,13 @@ class Vol2ColFunctor<phi::CPUContext, T> {
const
DataLayout
data_layout
)
const
{
PADDLE_ENFORCE_EQ
(
vol
.
dims
().
size
(),
4
,
p
latform
::
errors
::
InvalidArgument
(
p
hi
::
errors
::
InvalidArgument
(
"The dimension of vol should be 4, but received %d."
,
vol
.
dims
().
size
()));
PADDLE_ENFORCE_EQ
(
col
->
dims
().
size
(),
7
,
p
latform
::
errors
::
InvalidArgument
(
p
hi
::
errors
::
InvalidArgument
(
"The dimension of col should be 7, but received %d."
,
col
->
dims
().
size
()));
...
...
@@ -81,7 +80,7 @@ class Vol2ColFunctor<phi::CPUContext, T> {
PADDLE_ENFORCE_EQ
(
input_depth_tmp
,
output_depth
,
p
latform
::
errors
::
InvalidArgument
(
p
hi
::
errors
::
InvalidArgument
(
"input_depth(%d) and output_depth(%d) are mismatching."
,
input_depth_tmp
,
output_depth
));
...
...
@@ -92,7 +91,7 @@ class Vol2ColFunctor<phi::CPUContext, T> {
PADDLE_ENFORCE_EQ
(
input_height_tmp
,
output_height
,
p
latform
::
errors
::
InvalidArgument
(
p
hi
::
errors
::
InvalidArgument
(
"input_height(%d) and output_height(%d) are mismatching."
,
input_height_tmp
,
output_height
));
...
...
@@ -103,7 +102,7 @@ class Vol2ColFunctor<phi::CPUContext, T> {
PADDLE_ENFORCE_EQ
(
input_width_tmp
,
output_width
,
p
latform
::
errors
::
InvalidArgument
(
p
hi
::
errors
::
InvalidArgument
(
"input_width(%d) and output_width(%d) are mismatching."
,
input_width_tmp
,
output_width
));
...
...
@@ -164,13 +163,13 @@ class Col2VolFunctor<phi::CPUContext, T> {
const
DataLayout
data_layout
)
const
{
PADDLE_ENFORCE_EQ
(
vol
->
dims
().
size
(),
4
,
p
latform
::
errors
::
InvalidArgument
(
p
hi
::
errors
::
InvalidArgument
(
"The dimension of vol should be 4, but received %d."
,
vol
->
dims
().
size
()));
PADDLE_ENFORCE_EQ
(
col
.
dims
().
size
(),
7
,
p
latform
::
errors
::
InvalidArgument
(
p
hi
::
errors
::
InvalidArgument
(
"The dimension of col should be 7, but received %d."
,
col
.
dims
().
size
()));
...
...
@@ -206,7 +205,7 @@ class Col2VolFunctor<phi::CPUContext, T> {
PADDLE_ENFORCE_EQ
(
input_depth_tmp
,
output_depth
,
p
latform
::
errors
::
InvalidArgument
(
p
hi
::
errors
::
InvalidArgument
(
"input_depth(%d) and output_depth(%d) are mismatching."
,
input_depth_tmp
,
output_depth
));
...
...
@@ -217,7 +216,7 @@ class Col2VolFunctor<phi::CPUContext, T> {
PADDLE_ENFORCE_EQ
(
input_height_tmp
,
output_height
,
p
latform
::
errors
::
InvalidArgument
(
p
hi
::
errors
::
InvalidArgument
(
"input_height(%d) and output_height(%d) are mismatching."
,
input_height_tmp
,
output_height
));
...
...
@@ -228,7 +227,7 @@ class Col2VolFunctor<phi::CPUContext, T> {
PADDLE_ENFORCE_EQ
(
input_width_tmp
,
output_width
,
p
latform
::
errors
::
InvalidArgument
(
p
hi
::
errors
::
InvalidArgument
(
"input_width(%d) and output_width(%d) are mismatching."
,
input_width_tmp
,
output_width
));
...
...
@@ -278,6 +277,5 @@ template class Vol2ColFunctor<phi::CPUContext, double>;
template
class
Col2VolFunctor
<
phi
::
CPUContext
,
float
>;
template
class
Col2VolFunctor
<
phi
::
CPUContext
,
double
>;
}
// namespace math
}
// namespace operators
}
// namespace paddle
}
// namespace funcs
}
// namespace phi
paddle/
fluid/operators/math
/vol2col.cu
→
paddle/
phi/kernels/funcs
/vol2col.cu
浏览文件 @
aa36c6aa
...
...
@@ -15,14 +15,13 @@ limitations under the License. */
#include <algorithm>
#include <vector>
#include "paddle/fluid/operators/math/vol2col.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/kernels/funcs/vol2col.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
phi
{
namespace
funcs
{
template
<
class
T
>
__global__
void
vol2col
(
int
num_kernels
,
...
...
@@ -112,12 +111,12 @@ void Vol2ColFunctor<DeviceContext, T>::operator()(
const
DataLayout
data_layout
)
const
{
PADDLE_ENFORCE_EQ
(
vol
.
dims
().
size
(),
4
,
p
latform
::
errors
::
InvalidArgument
(
p
hi
::
errors
::
InvalidArgument
(
"The dimension of vol should be 4, but received %d."
,
vol
.
dims
().
size
()));
PADDLE_ENFORCE_EQ
(
col
->
dims
().
size
(),
7
,
p
latform
::
errors
::
InvalidArgument
(
p
hi
::
errors
::
InvalidArgument
(
"The dimension of col should be 7, but received %d."
,
col
->
dims
().
size
()));
...
...
@@ -149,7 +148,7 @@ void Vol2ColFunctor<DeviceContext, T>::operator()(
1
;
PADDLE_ENFORCE_EQ
(
input_depth_tmp
,
output_depth
,
p
latform
::
errors
::
InvalidArgument
(
p
hi
::
errors
::
InvalidArgument
(
"input_depth(%d) and output_depth(%d) are mismatching."
,
input_depth_tmp
,
output_depth
));
...
...
@@ -160,7 +159,7 @@ void Vol2ColFunctor<DeviceContext, T>::operator()(
PADDLE_ENFORCE_EQ
(
input_height_tmp
,
output_height
,
p
latform
::
errors
::
InvalidArgument
(
p
hi
::
errors
::
InvalidArgument
(
"input_height(%d) and output_height(%d) are mismatching."
,
input_height_tmp
,
output_height
));
...
...
@@ -170,7 +169,7 @@ void Vol2ColFunctor<DeviceContext, T>::operator()(
1
;
PADDLE_ENFORCE_EQ
(
input_width_tmp
,
output_width
,
p
latform
::
errors
::
InvalidArgument
(
p
hi
::
errors
::
InvalidArgument
(
"input_width(%d) and output_width(%d) are mismatching."
,
input_width_tmp
,
output_width
));
...
...
@@ -180,7 +179,7 @@ void Vol2ColFunctor<DeviceContext, T>::operator()(
int
max_threads
=
1024
;
#ifdef WITH_NV_JETSON
p
latform
::
ChangeThreadNum
(
context
,
&
max_threads
);
p
hi
::
backends
::
gpu
::
ChangeThreadNum
(
context
,
&
max_threads
);
#endif
const
int
threads
=
max_threads
;
...
...
@@ -318,12 +317,12 @@ void Col2VolFunctor<DeviceContext, T>::operator()(
const
DataLayout
data_layout
)
const
{
PADDLE_ENFORCE_EQ
(
vol
->
dims
().
size
(),
4
,
p
latform
::
errors
::
InvalidArgument
(
p
hi
::
errors
::
InvalidArgument
(
"The dimension of vol should be 4, but received %d."
,
vol
->
dims
().
size
()));
PADDLE_ENFORCE_EQ
(
col
.
dims
().
size
(),
7
,
p
latform
::
errors
::
InvalidArgument
(
p
hi
::
errors
::
InvalidArgument
(
"The dimension of col should be 7, but received %d."
,
col
.
dims
().
size
()));
...
...
@@ -356,7 +355,7 @@ void Col2VolFunctor<DeviceContext, T>::operator()(
1
;
PADDLE_ENFORCE_EQ
(
input_depth_tmp
,
output_depth
,
p
latform
::
errors
::
InvalidArgument
(
p
hi
::
errors
::
InvalidArgument
(
"input_depth(%d) and output_depth(%d) are mismatching."
,
input_depth_tmp
,
output_depth
));
...
...
@@ -367,7 +366,7 @@ void Col2VolFunctor<DeviceContext, T>::operator()(
PADDLE_ENFORCE_EQ
(
input_height_tmp
,
output_height
,
p
latform
::
errors
::
InvalidArgument
(
p
hi
::
errors
::
InvalidArgument
(
"input_height(%d) and output_height(%d) are mismatching."
,
input_height_tmp
,
output_height
));
...
...
@@ -377,7 +376,7 @@ void Col2VolFunctor<DeviceContext, T>::operator()(
1
;
PADDLE_ENFORCE_EQ
(
input_width_tmp
,
output_width
,
p
latform
::
errors
::
InvalidArgument
(
p
hi
::
errors
::
InvalidArgument
(
"input_width(%d) and output_width(%d) are mismatching."
,
input_width_tmp
,
output_width
));
...
...
@@ -386,7 +385,7 @@ void Col2VolFunctor<DeviceContext, T>::operator()(
int
max_threads
=
1024
;
#ifdef WITH_NV_JETSON
p
latform
::
ChangeThreadNum
(
context
,
&
max_threads
);
p
hi
::
backends
::
gpu
::
ChangeThreadNum
(
context
,
&
max_threads
);
#endif
const
int
threads
=
max_threads
;
...
...
@@ -423,6 +422,5 @@ template class Vol2ColFunctor<phi::GPUContext, double>;
template
class
Col2VolFunctor
<
phi
::
GPUContext
,
float
>;
template
class
Col2VolFunctor
<
phi
::
GPUContext
,
double
>;
}
// namespace math
}
// namespace operators
}
// namespace paddle
}
// namespace funcs
}
// namespace phi
paddle/
fluid/operators/math
/vol2col.h
→
paddle/
phi/kernels/funcs
/vol2col.h
浏览文件 @
aa36c6aa
/* Copyright (c) 20
16
PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 20
22
PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
...
...
@@ -16,13 +16,11 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/errors.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
phi
{
namespace
funcs
{
using
DataLayout
=
phi
::
DataLayout
;
...
...
@@ -92,6 +90,5 @@ class Col2VolFunctor {
const
DataLayout
data_layout
=
DataLayout
::
kNCHW
)
const
;
};
}
// namespace math
}
// namespace operators
}
// namespace paddle
}
// namespace funcs
}
// namespace phi
paddle/phi/kernels/impl/conv_grad_kernel_impl.h
浏览文件 @
aa36c6aa
...
...
@@ -15,11 +15,11 @@
#pragma once
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/vol2col.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
#include "paddle/phi/kernels/funcs/batch_norm_utils.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/vol2col.h"
namespace
phi
{
...
...
@@ -147,7 +147,7 @@ void ConvGradKernel(const Context& dev_ctx,
if
(
is_expand
)
{
set_zero
(
dev_ctx
,
&
transformed_input_grad
,
static_cast
<
T
>
(
0
));
}
p
addle
::
operators
::
math
::
Col2VolFunctor
<
Context
,
T
>
col2vol
;
p
hi
::
funcs
::
Col2VolFunctor
<
Context
,
T
>
col2vol
;
paddle
::
operators
::
math
::
Col2ImFunctor
<
paddle
::
operators
::
math
::
ColFormat
::
kCFO
,
Context
,
T
>
col2im
;
...
...
@@ -206,7 +206,7 @@ void ConvGradKernel(const Context& dev_ctx,
paddle
::
operators
::
math
::
Im2ColFunctor
<
paddle
::
operators
::
math
::
ColFormat
::
kCFO
,
Context
,
T
>
im2col
;
p
addle
::
operators
::
math
::
Vol2ColFunctor
<
Context
,
T
>
vol2col
;
p
hi
::
funcs
::
Vol2ColFunctor
<
Context
,
T
>
vol2col
;
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
DenseTensor
out_grad_batch
=
transformed_output_grad
.
Slice
(
i
,
i
+
1
).
Resize
(
output_matrix_shape
);
...
...
@@ -381,7 +381,7 @@ void ConvGradGradKernel(const Context& dev_ctx,
if
(
is_expand
)
{
set_zero
(
dev_ctx
,
&
transformed_dX
,
static_cast
<
T
>
(
0
));
}
p
addle
::
operators
::
math
::
Col2VolFunctor
<
Context
,
T
>
col2vol
;
p
hi
::
funcs
::
Col2VolFunctor
<
Context
,
T
>
col2vol
;
paddle
::
operators
::
math
::
Col2ImFunctor
<
paddle
::
operators
::
math
::
ColFormat
::
kCFO
,
Context
,
T
>
col2im
;
...
...
@@ -431,7 +431,7 @@ void ConvGradGradKernel(const Context& dev_ctx,
paddle
::
operators
::
math
::
Im2ColFunctor
<
paddle
::
operators
::
math
::
ColFormat
::
kCFO
,
Context
,
T
>
im2col
;
p
addle
::
operators
::
math
::
Vol2ColFunctor
<
Context
,
T
>
vol2col
;
p
hi
::
funcs
::
Vol2ColFunctor
<
Context
,
T
>
vol2col
;
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
DenseTensor
dy_batch
=
transformed_dY
.
Slice
(
i
,
i
+
1
).
Resize
(
output_matrix_shape
);
...
...
@@ -480,7 +480,7 @@ void ConvGradGradKernel(const Context& dev_ctx,
paddle
::
operators
::
math
::
Im2ColFunctor
<
paddle
::
operators
::
math
::
ColFormat
::
kCFO
,
Context
,
T
>
im2col
;
p
addle
::
operators
::
math
::
Vol2ColFunctor
<
Context
,
T
>
vol2col
;
p
hi
::
funcs
::
Vol2ColFunctor
<
Context
,
T
>
vol2col
;
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
DenseTensor
ddy_batch
=
transformed_ddY
.
Slice
(
i
,
i
+
1
).
Resize
(
output_matrix_shape
);
...
...
paddle/phi/kernels/impl/conv_kernel_impl.h
浏览文件 @
aa36c6aa
...
...
@@ -15,12 +15,12 @@
#pragma once
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/vol2col.h"
#include "paddle/phi/kernels/conv_kernel.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
#include "paddle/phi/kernels/funcs/batch_norm_utils.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/vol2col.h"
namespace
phi
{
...
...
@@ -133,7 +133,7 @@ void ConvKernelImpl(const Context& dev_ctx,
int
in_step
=
static_cast
<
int
>
(
transformed_input
.
dims
()[
1
])
/
groups
;
int
out_step
=
static_cast
<
int
>
(
transformed_output
.
dims
()[
1
])
/
groups
;
p
addle
::
operators
::
math
::
Vol2ColFunctor
<
Context
,
T
>
vol2col
;
p
hi
::
funcs
::
Vol2ColFunctor
<
Context
,
T
>
vol2col
;
paddle
::
operators
::
math
::
Im2ColFunctor
<
paddle
::
operators
::
math
::
ColFormat
::
kCFO
,
Context
,
T
>
im2col
;
...
...
paddle/phi/kernels/impl/conv_transpose_grad_kernel_impl.h
浏览文件 @
aa36c6aa
...
...
@@ -15,7 +15,6 @@
#pragma once
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/vol2col.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/conv_transpose_grad_kernel.h"
...
...
@@ -23,6 +22,7 @@
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
#include "paddle/phi/kernels/funcs/slice.h"
#include "paddle/phi/kernels/funcs/vol2col.h"
namespace
phi
{
...
...
@@ -146,7 +146,7 @@ void ConvTransposeGradRawKernel(const Context& ctx,
paddle
::
operators
::
math
::
Im2ColFunctor
<
paddle
::
operators
::
math
::
ColFormat
::
kCFO
,
Context
,
T
>
im2col
;
p
addle
::
operators
::
math
::
Vol2ColFunctor
<
Context
,
T
>
vol2col
;
p
hi
::
funcs
::
Vol2ColFunctor
<
Context
,
T
>
vol2col
;
funcs
::
ConcatFunctor
<
Context
,
T
>
concat_functor
;
if
(
dx
)
{
...
...
paddle/phi/kernels/impl/conv_transpose_kernel_impl.h
浏览文件 @
aa36c6aa
...
...
@@ -15,7 +15,6 @@
#pragma once
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/vol2col.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/conv_transpose_kernel.h"
...
...
@@ -23,6 +22,7 @@
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
#include "paddle/phi/kernels/funcs/slice.h"
#include "paddle/phi/kernels/funcs/vol2col.h"
namespace
phi
{
...
...
@@ -139,7 +139,7 @@ void ConvTransposeRawKernel(const Context& ctx,
paddle
::
operators
::
math
::
Col2ImFunctor
<
paddle
::
operators
::
math
::
ColFormat
::
kCFO
,
Context
,
T
>
col2im
;
p
addle
::
operators
::
math
::
Col2VolFunctor
<
Context
,
T
>
col2vol
;
p
hi
::
funcs
::
Col2VolFunctor
<
Context
,
T
>
col2vol
;
funcs
::
ConcatFunctor
<
Context
,
T
>
concat_functor
;
// convolution transpose: gemm + col2im or col2vol (similar to conv-backward
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录