Unverified commit b76343c3, authored Sep 17, 2019 by lvmengsi, committed via GitHub on Sep 17, 2019.
cpu Conv double grad (#19672)
* cpu conv_grad_grad
Parent: 754fd57e
Showing 5 changed files with 399 additions and 27 deletions (+399 −27)
paddle/fluid/operators/conv_cudnn_op.cu.cc (+5 −0)
paddle/fluid/operators/conv_op.cc (+52 −3)
paddle/fluid/operators/conv_op.h (+213 −0)
python/paddle/fluid/tests/unittests/test_conv_nn_grad.py (+129 −0)
python/paddle/fluid/tests/unittests/test_nn_grad.py (+0 −24)
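For reference, the new double-grad kernels produce the three quantities described in the comments of conv_op.h below. Writing X for Input, W for Filter, dY for the incoming output gradient, and ddX/ddW for the gradients of the input/filter gradients, the relations are (with \star standing for the appropriate convolution-style contraction in each case: roughly a transposed convolution for dX, a filter-gradient correlation for dW, and a forward convolution for ddY):

    dX = ddW \star dY, \qquad dW = ddX \star dY, \qquad ddY = W \star ddX + X \star ddW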
paddle/fluid/operators/conv_cudnn_op.cu.cc
@@ -510,3 +510,8 @@ REGISTER_OP_KERNEL(conv3d, CUDNN, plat::CUDAPlace,
REGISTER_OP_KERNEL(conv3d_grad, CUDNN, plat::CUDAPlace,
                   paddle::operators::CUDNNConvGradOpKernel<float>,
                   paddle::operators::CUDNNConvGradOpKernel<double>);
REGISTER_OP_KERNEL(
    conv3d_grad_grad, CUDNN, plat::CUDAPlace,
    paddle::operators::CUDNNConvDoubleGradOpKernel<float>,
    paddle::operators::CUDNNConvDoubleGradOpKernel<double>,
    paddle::operators::CUDNNConvDoubleGradOpKernel<plat::float16>);
paddle/fluid/operators/conv_op.cc
@@ -565,6 +565,40 @@ class Conv2DDoubleGradMaker : public framework::SingleGradOpDescMaker {
  }
};

/*
 * Inputs:  I, W, dO, ddI, ddW
 * Outputs: ddO, dW, dI
 */
class Conv3DDoubleGradMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

  std::unique_ptr<framework::OpDesc> Apply() const override {
    auto* op = new framework::OpDesc();
    op->SetType(this->ForwardOpType() + "_grad");
    // I, W, dO, ddI, ddW
    op->SetInput("Input", Input("Input"));
    op->SetInput("Filter", Input("Filter"));
    op->SetInput("DOutput", Input(framework::GradVarName("Output")));
    op->SetInput("DDInput", OutputGrad(framework::GradVarName("Input")));
    op->SetInput("DDFilter", OutputGrad(framework::GradVarName("Filter")));

    auto ddx = OutputGrad(framework::GradVarName("Input"));
    auto ddw = OutputGrad(framework::GradVarName("Filter"));
    std::vector<std::string> empty_str = {};

    op->SetOutput("DDOutput",
                  ddx.empty() ? empty_str
                              : InputGrad(framework::GradVarName("Output")));
    op->SetOutput("DFilter", ddx.empty() ? empty_str : InputGrad("Filter"));
    op->SetOutput("DInput", ddw.empty() ? empty_str : InputGrad("Input"));

    op->SetAttrMap(Attrs());

    return std::unique_ptr<framework::OpDesc>(op);
  }
};

void ConvOpDoubleGrad::InferShape(framework::InferShapeContext* ctx) const {
  auto x_dims = ctx->GetInputDim("Input");
  auto w_dims = ctx->GetInputDim("Filter");
@@ -592,8 +626,14 @@ framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType(
#ifdef PADDLE_WITH_CUDA
  if (platform::CanCUDNNBeUsed(ctx)) {
    library_ = framework::LibraryType::kCUDNN;
  } else {
    PADDLE_THROW("Now ConvDoubleGrad only supports cuDNN.");
  }
#endif
#ifdef PADDLE_WITH_MKLDNN
  if (library_ == framework::LibraryType::kPlain &&
      platform::CanMKLDNNBeUsed(ctx)) {
    library_ = framework::LibraryType::kMKLDNN;
    layout_ = framework::DataLayout::kMKLDNN;
    customized_type_value = kConvMKLDNNFP32;
  }
#endif
  auto type = framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
@@ -637,7 +677,8 @@ REGISTER_OPERATOR(depthwise_conv2d_grad, ops::ConvOpGrad);
REGISTER_OPERATOR(conv3d, ops::ConvOp, ops::Conv3DOpMaker,
                  ops::ConvOpInferVarType, ops::Conv3DGradMaker);
-REGISTER_OPERATOR(conv3d_grad, ops::ConvOpGrad);
+REGISTER_OPERATOR(conv3d_grad, ops::ConvOpGrad, ops::Conv3DDoubleGradMaker);
+REGISTER_OPERATOR(conv3d_grad_grad, ops::ConvOpDoubleGrad);

// depthwise conv kernel
// TODO(xingzhaolong): neon kernel for mobile
@@ -658,6 +699,10 @@ REGISTER_OP_CPU_KERNEL(
    conv2d_grad,
    ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
    conv2d_grad_grad,
    ops::GemmConvDoubleGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::GemmConvDoubleGradKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
    conv3d, ops::GemmConvKernel<paddle::platform::CPUDeviceContext, float>,
@@ -666,3 +711,7 @@ REGISTER_OP_CPU_KERNEL(
    conv3d_grad,
    ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
    conv3d_grad_grad,
    ops::GemmConvDoubleGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::GemmConvDoubleGradKernel<paddle::platform::CPUDeviceContext, double>);
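With the conv2d_grad_grad / conv3d_grad_grad CPU kernels registered above, conv double grad can be verified numerically on fluid.CPUPlace() as well as on CUDA. A minimal sketch mirroring the unit test added in this commit (the function name check_conv2d_double_grad_cpu is illustrative only, and the snippet assumes it runs from Paddle's unittests directory so that gradient_checker and decorator_helper are importable):

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import gradient_checker
from decorator_helper import prog_scope


@prog_scope()
def check_conv2d_double_grad_cpu():
    # Build a tiny conv2d program and numerically check its second-order
    # gradients on CPU, just as test_conv_nn_grad.py does below.
    shape = [2, 4, 7, 8]
    dtype = np.float64
    x = layers.data('x', shape, False, dtype)
    y = layers.conv2d(x, 4, 1, bias_attr=False)
    x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
    w = fluid.default_main_program().global_block().all_parameters()
    w_arr = [np.random.uniform(-1, 1, p.shape).astype(dtype) for p in w]
    gradient_checker.double_grad_check(
        [x] + w, y, x_init=[x_arr] + w_arr, place=fluid.CPUPlace(), eps=0.005)


check_conv2d_double_grad_cpu()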
paddle/fluid/operators/conv_op.h
@@ -19,6 +19,7 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/depthwise_conv.h"
#include "paddle/fluid/operators/math/im2col.h"
@@ -393,6 +394,218 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
  }
};

template <typename DeviceContext, typename T>
class GemmConvDoubleGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
                      "It must use CPUPlace.");
    const Tensor* X = ctx.Input<Tensor>("Input");
    const Tensor* dY = ctx.Input<Tensor>("DOutput");
    const Tensor* ddX = ctx.Input<Tensor>("DDInput");
    const Tensor* ddW_in = ctx.Input<Tensor>("DDFilter");

    Tensor* ddY = ctx.Output<Tensor>("DDOutput");
    Tensor* dW = ctx.Output<Tensor>("DFilter");
    Tensor* dX = ctx.Output<Tensor>("DInput");
    Tensor W = detail::Ref(ctx.Input<Tensor>("Filter"),
                           "Cannot find input Filter(%s) in scope)",
                           ctx.Inputs("Filter")[0]);
    if (!ddY && !dW && !dX) return;

    int groups = ctx.Attr<int>("groups");
    std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");

    const int batch_size = static_cast<int>(X->dims()[0]);
    std::vector<int64_t> filter_shape_vec(framework::vectorize(W.dims()));
    std::vector<int64_t> output_shape_vec(framework::vectorize(dY->dims()));

    size_t data_dim = filter_shape_vec.size() - 2;
    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
    // col_shape [in_channel/group, kh, kw, oh, ow]
    col_shape_vec[0] = X->dims()[1] / groups;
    for (size_t j = 0; j < data_dim; ++j) {
      col_shape_vec[j + 1] = filter_shape_vec[j + 2];
      col_shape_vec[j + data_dim + 1] = output_shape_vec[j + 2];
    }
    framework::DDim col_shape(framework::make_ddim(col_shape_vec));
    // col_matrix_shape [in_channel/group * kh * kw, oh * ow]
    framework::DDim col_matrix_shape =
        framework::flatten_to_2d(col_shape, data_dim + 1);
    // input_shape [Cin, H, W]
    framework::DDim input_shape =
        framework::slice_ddim(X->dims(), 1, X->dims().size());
    // filter_matrix_shape [Cout, Cin * kh * kw]
    framework::DDim filter_matrix_shape = {W.dims()[0],
                                           W.numel() / W.dims()[0]};
    W.Resize(filter_matrix_shape);
    framework::DDim output_matrix_shape = {
        dY->dims()[1], dY->numel() / (dY->dims()[0] * dY->dims()[1])};
    int in_step = static_cast<int>(X->dims()[1]) / groups;
    int out_step = static_cast<int>(dY->dims()[1]) / groups;

    bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
    Tensor col;
    Tensor col_matrix;
    if (is_expand) {
      col = ctx.AllocateTmpTensor<T, DeviceContext>(col_shape, dev_ctx);
      col_matrix.ShareDataWith(col);
      col_matrix.Resize(col_matrix_shape);
    }

    math::SetConstant<DeviceContext, T> set_zero;
    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);

    // dx convolution double grad:  gemm + col2im(col2vol)
    // dx = ddw * dy  ==> dx(N, Cin, H, W), ddw(Cout, Cin, kh, kw), dy(N, Cout,
    // oH, oW)
    if (dX && ddW_in) {
      Tensor ddW;
      ddW.ShareDataWith(*ddW_in).Resize(filter_matrix_shape);

      dX->mutable_data<T>(ctx.GetPlace());
      // if is_expand is false, the operation of set_zero is unnecessary
      // because math::matmul will reset dx
      if (is_expand) {
        set_zero(dev_ctx, dX, static_cast<T>(0));
      }
      math::Col2VolFunctor<DeviceContext, T> col2vol;
      math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im;

      for (int i = 0; i < batch_size; i++) {
        Tensor dy_batch = dY->Slice(i, i + 1).Resize(output_matrix_shape);
        Tensor dx_batch = dX->Slice(i, i + 1).Resize(input_shape);
        for (int g = 0; g < groups; g++) {
          // gemm
          Tensor dy_slice = dy_batch.Slice(g * out_step, (g + 1) * out_step);
          Tensor ddw_slice = ddW.Slice(g * out_step, (g + 1) * out_step);
          Tensor dx_slice = dx_batch.Slice(g * in_step, (g + 1) * in_step);
          if (!is_expand) {
            col_matrix.ShareDataWith(dx_slice);
            col_matrix.Resize(col_matrix_shape);
          }
          blas.MatMul(ddw_slice, true, dy_slice, false, T(1.0), &col_matrix,
                      T(0.0));

          if (is_expand && data_dim == 2U) {
            col2im(dev_ctx, col, dilations, strides,
                   std::vector<int>{paddings[0], paddings[1], paddings[0],
                                    paddings[1]},
                   &dx_slice);
          } else if (is_expand && data_dim == 3U) {
            col2vol(dev_ctx, col, dilations, strides, paddings, &dx_slice);
          }
        }
      }
    }

    // dw = ddx * dy  ==> dw(Cout, Cin, kh, kw), ddx(N, Cin, H, W), dy(N, Cout,
    // oH, oW)
    // dw convolution double grad:  im2col(vol2col) + gemm
    if (dW) {
      dW->mutable_data<T>(ctx.GetPlace());
      set_zero(dev_ctx, dW, static_cast<T>(0));
      Tensor dW_arr = *dW;
      dW_arr.Resize(filter_matrix_shape);
      math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
      math::Vol2ColFunctor<DeviceContext, T> vol2col;
      for (int i = 0; i < batch_size; ++i) {
        Tensor dy_batch = dY->Slice(i, i + 1).Resize(output_matrix_shape);
        Tensor ddx_batch = ddX->Slice(i, i + 1).Resize(input_shape);
        for (int g = 0; g < groups; ++g) {
          // im2col
          Tensor dy_slice = dy_batch.Slice(g * out_step, (g + 1) * out_step);
          Tensor ddx_slice = ddx_batch.Slice(g * in_step, (g + 1) * in_step);
          if (!is_expand) {
            col.ShareDataWith(ddx_slice);
            col_matrix.ShareDataWith(col);
            col_matrix.Resize(col_matrix_shape);
          } else if (data_dim == 2U) {
            im2col(dev_ctx, ddx_slice, dilations, strides,
                   std::vector<int>{paddings[0], paddings[1], paddings[0],
                                    paddings[1]},
                   &col);
          } else if (data_dim == 3U) {
            vol2col(dev_ctx, ddx_slice, dilations, strides, paddings, &col);
          }

          Tensor dw_slice = dW_arr.Slice(g * out_step, (g + 1) * out_step);
          blas.MatMul(dy_slice, false, col_matrix, true, T(1.0), &dw_slice,
                      T(1.0));
        }
      }
    }

    // ddy = w * ddx + x * ddw ==> ddy(N, Cout, oH, oW), x/ddx(N, Cin, H, W),
    // w/ddw(Cout, Cin, kh, kw)
    // ddy convolution double grad: im2col(vol2col) + gemm
    if (ddY) {
      ddY->mutable_data<T>(ctx.GetPlace());
      set_zero(dev_ctx, ddY, static_cast<T>(0));
      math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
      math::Vol2ColFunctor<DeviceContext, T> vol2col;
      for (int i = 0; i < batch_size; ++i) {
        Tensor ddx_batch = ddX->Slice(i, i + 1).Resize(input_shape);
        Tensor x_batch = X->Slice(i, i + 1).Resize(input_shape);
        Tensor ddy_batch = ddY->Slice(i, i + 1).Resize(output_matrix_shape);
        for (int g = 0; g < groups; ++g) {
          Tensor x_slice = x_batch.Slice(g * in_step, (g + 1) * in_step);
          Tensor ddx_slice = ddx_batch.Slice(g * in_step, (g + 1) * in_step);
          if (!is_expand) {
            col.ShareDataWith(ddx_slice);
            col_matrix.ShareDataWith(col);
            col_matrix.Resize(col_matrix_shape);
          } else if (data_dim == 2U) {
            // im2col
            im2col(dev_ctx, ddx_slice, dilations, strides,
                   std::vector<int>{paddings[0], paddings[1], paddings[0],
                                    paddings[1]},
                   &col);
          } else if (data_dim == 3U) {
            // vol2col
            vol2col(dev_ctx, ddx_slice, dilations, strides, paddings, &col);
          }

          // gemm
          Tensor ddy_slice = ddy_batch.Slice(g * out_step, (g + 1) * out_step);
          Tensor w_slice = W.Slice(g * out_step, (g + 1) * out_step);
          blas.MatMul(w_slice, false, col_matrix, false, T(1.0), &ddy_slice,
                      T(0.0));

          if (ddW_in) {
            Tensor ddW;
            ddW.ShareDataWith(*ddW_in).Resize(filter_matrix_shape);

            if (!is_expand) {
              col.ShareDataWith(x_slice);
              col_matrix.ShareDataWith(col);
              col_matrix.Resize(col_matrix_shape);
            } else if (data_dim == 2U) {
              // im2col
              im2col(dev_ctx, x_slice, dilations, strides,
                     std::vector<int>{paddings[0], paddings[1], paddings[0],
                                      paddings[1]},
                     &col);
            } else if (data_dim == 3U) {
              // vol2col
              vol2col(dev_ctx, x_slice, dilations, strides, paddings, &col);
            }

            // gemm
            Tensor ddw_slice = ddW.Slice(g * out_step, (g + 1) * out_step);
            blas.MatMul(ddw_slice, false, col_matrix, false, T(1.0),
                        &ddy_slice, T(1.0));
          }
        }
      }
    }
  }
};

template <typename DeviceContext, typename T>
class DepthwiseConvKernel : public framework::OpKernel<T> {
 public:
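The dW branch above is an im2col + GEMM formulation of the filter-gradient contraction dw = ddx * dy noted in the comments. A minimal NumPy sketch of that contraction for the 2-D, stride-1, unpadded, single-group case (the helper name conv2d_filter_grad is made up for illustration):

import numpy as np


def conv2d_filter_grad(ddx, dy, kh, kw):
    # ddx: (N, Cin, H, W); dy: (N, Cout, oH, oW) with oH = H - kh + 1, oW = W - kw + 1.
    # Returns dw of shape (Cout, Cin, kh, kw): for every filter tap (i, j),
    # correlate the window of ddx it sees with dy over batch and spatial dims.
    cout, cin = dy.shape[1], ddx.shape[1]
    _, _, oh, ow = dy.shape
    dw = np.zeros((cout, cin, kh, kw), dtype=ddx.dtype)
    for i in range(kh):
        for j in range(kw):
            patch = ddx[:, :, i:i + oh, j:j + ow]          # (N, Cin, oH, oW)
            dw[:, :, i, j] = np.einsum('ncpq,nkpq->kc', patch, dy)
    return dw

The kernel reaches the same result with one GEMM per (batch, group) slice by materializing those windows with im2col (vol2col in 3-D) and accumulating into dw_slice.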
python/paddle/fluid/tests/unittests/test_conv_nn_grad.py
new file mode 100644
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import unittest
import numpy as np

import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.core as core
import gradient_checker

from decorator_helper import prog_scope


class TestConvDoubleGradCheck(unittest.TestCase):
    @prog_scope()
    def func(self, place):
        shape = [2, 4, 7, 8]
        eps = 0.005
        dtype = np.float64
        x = layers.data('x', shape, False, dtype)
        y = layers.conv2d(x, 4, 1, bias_attr=False)
        x_arr = np.random.uniform(-1, 1, shape).astype(dtype)

        w = fluid.default_main_program().global_block().all_parameters()
        w_arr = []
        for p in w:
            w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype))
        gradient_checker.double_grad_check(
            [x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps)

    def test_grad(self):
        places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))
        for p in places:
            self.func(p)


class TestConvDoubleGradCheckTest1(unittest.TestCase):
    @prog_scope()
    def func(self, place):
        shape = [2, 3, 4, 5]
        eps = 0.005
        dtype = np.float64
        x = layers.data('x', shape, False, dtype)
        y = layers.conv2d(x, 4, 1, padding=1, bias_attr=False)
        x_arr = np.random.uniform(-1, 1, shape).astype(dtype)

        w = fluid.default_main_program().global_block().all_parameters()
        w_arr = []
        for p in w:
            w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype))
        gradient_checker.double_grad_check(
            [x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps)

    def test_grad(self):
        places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))
        for p in places:
            self.func(p)


class TestConv3DDoubleGradCheck(unittest.TestCase):
    @prog_scope()
    def func(self, place):
        shape = [2, 4, 3, 4, 2]
        eps = 0.005
        dtype = np.float64
        x = layers.data('x', shape, False, dtype)
        y = layers.conv3d(x, 4, 1, bias_attr=False)
        x_arr = np.random.uniform(-1, 1, shape).astype(dtype)

        w = fluid.default_main_program().global_block().all_parameters()
        w_arr = []
        for p in w:
            w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype))
        gradient_checker.double_grad_check(
            [x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps)

    def test_grad(self):
        places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))
        for p in places:
            self.func(p)


class TestConv3DDoubleGradCheckTest1(unittest.TestCase):
    @prog_scope()
    def func(self, place):
        shape = [2, 4, 5, 3, 2]
        eps = 0.005
        dtype = np.float64
        x = layers.data('x', shape, False, dtype)
        y = layers.conv3d(x, 4, 1, padding=1, bias_attr=False)
        x_arr = np.random.uniform(-1, 1, shape).astype(dtype)

        w = fluid.default_main_program().global_block().all_parameters()
        w_arr = []
        for p in w:
            w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype))
        gradient_checker.double_grad_check(
            [x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps)

    def test_grad(self):
        places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))
        for p in places:
            self.func(p)


if __name__ == "__main__":
    unittest.main()
python/paddle/fluid/tests/unittests/test_nn_grad.py
@@ -43,30 +43,6 @@ class TestMulGradCheck(unittest.TestCase):
            self.func(p)


class TestConvDoubleGradCheck(unittest.TestCase):
    @prog_scope()
    def func(self, place):
        shape = [2, 4, 14, 16]
        eps = 0.005
        dtype = np.float64
        x = layers.data('x', shape, False, dtype)
        y = layers.conv2d(x, 4, 1, bias_attr=False)
        x_arr = np.random.uniform(-1, 1, shape).astype(dtype)

        w = fluid.default_main_program().global_block().all_parameters()
        w_arr = []
        for p in w:
            w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype))
        gradient_checker.double_grad_check(
            [x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps)

    def test_grad(self):
        if core.is_compiled_with_cuda():
            places = [fluid.CUDAPlace(0)]
        for p in places:
            self.func(p)


class TestReduceMeanWithDimDoubleGradCheck(unittest.TestCase):
    @prog_scope()
    def func(self, place):