Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
12882b2f
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
12882b2f
编写于
10月 15, 2021
作者:
Z
Zhang Zheng
提交者:
GitHub
10月 15, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add ResNetUnit Python API (#35426)
上级
2de0b58e
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
289 addition
and
14 deletion
+289
-14
paddle/fluid/framework/ir/memory_optimize_pass/inplace_addto_op_pass.cc
...ramework/ir/memory_optimize_pass/inplace_addto_op_pass.cc
+6
-3
paddle/fluid/operators/fused/resnet_unit_op.cc
paddle/fluid/operators/fused/resnet_unit_op.cc
+3
-2
paddle/fluid/operators/fused/resnet_unit_op.cu
paddle/fluid/operators/fused/resnet_unit_op.cu
+10
-9
python/paddle/incubate/operators/__init__.py
python/paddle/incubate/operators/__init__.py
+1
-0
python/paddle/incubate/operators/resnet_unit.py
python/paddle/incubate/operators/resnet_unit.py
+269
-0
未找到文件。
paddle/fluid/framework/ir/memory_optimize_pass/inplace_addto_op_pass.cc
浏览文件 @
12882b2f
...
@@ -179,7 +179,8 @@ void InplaceAddToOpPass::Run(Graph *graph) const {
...
@@ -179,7 +179,8 @@ void InplaceAddToOpPass::Run(Graph *graph) const {
out_var_ptr
->
GeneratedOp
());
out_var_ptr
->
GeneratedOp
());
// NOTE(zhiqiu): currently, only conv2d_grad supports addto strategy
// NOTE(zhiqiu): currently, only conv2d_grad supports addto strategy
if
(
right_generated_op
->
Name
()
!=
"conv2d_grad"
)
{
if
(
right_generated_op
->
Name
()
!=
"conv2d_grad"
&&
right_generated_op
->
Name
()
!=
"resnet_unit_grad"
)
{
continue
;
continue
;
}
}
...
@@ -224,11 +225,13 @@ static bool IsValidConv2DGradDataGradNode(const Node &node) {
...
@@ -224,11 +225,13 @@ static bool IsValidConv2DGradDataGradNode(const Node &node) {
if
(
node
.
inputs
.
empty
())
return
false
;
if
(
node
.
inputs
.
empty
())
return
false
;
auto
*
generated_op
=
node
.
inputs
[
0
];
auto
*
generated_op
=
node
.
inputs
[
0
];
auto
*
op_desc
=
generated_op
->
Op
();
auto
*
op_desc
=
generated_op
->
Op
();
if
(
op_desc
==
nullptr
||
op_desc
->
Type
()
!=
"conv2d_grad"
)
{
if
(
op_desc
==
nullptr
||
(
op_desc
->
Type
()
!=
"conv2d_grad"
&&
op_desc
->
Type
()
!=
"resnet_unit_grad"
))
{
return
false
;
return
false
;
}
}
const
auto
&
outputs
=
op_desc
->
Outputs
();
const
auto
&
outputs
=
op_desc
->
Outputs
();
auto
iter
=
outputs
.
find
(
GradVarName
(
"Input"
));
std
::
string
grad_var_name
=
op_desc
->
Type
()
==
"conv2d_grad"
?
"Input"
:
"X"
;
auto
iter
=
outputs
.
find
(
GradVarName
(
grad_var_name
));
return
iter
!=
outputs
.
end
()
&&
!
iter
->
second
.
empty
()
&&
return
iter
!=
outputs
.
end
()
&&
!
iter
->
second
.
empty
()
&&
iter
->
second
[
0
]
==
node
.
Name
()
&&
iter
->
second
[
0
]
==
node
.
Name
()
&&
!
op_desc
->
GetAttrIfExists
<
bool
>
(
"use_addto"
);
!
op_desc
->
GetAttrIfExists
<
bool
>
(
"use_addto"
);
...
...
paddle/fluid/operators/fused/resnet_unit_op.cc
浏览文件 @
12882b2f
...
@@ -232,6 +232,7 @@ class ResNetUnitOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -232,6 +232,7 @@ class ResNetUnitOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default false) Set to true for inference only, false "
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true."
)
"for training. Some layers may run faster when this is true."
)
.
SetDefault
(
false
);
.
SetDefault
(
false
);
AddAttr
<
bool
>
(
"use_addto"
,
""
).
SetDefault
(
false
);
AddAttr
<
std
::
string
>
(
"act_type"
,
"The activation type to be fused."
)
AddAttr
<
std
::
string
>
(
"act_type"
,
"The activation type to be fused."
)
.
SetDefault
(
"relu"
);
.
SetDefault
(
"relu"
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
...
...
paddle/fluid/operators/fused/resnet_unit_op.cu
浏览文件 @
12882b2f
...
@@ -55,7 +55,7 @@ class ResNetUnitKernel : public framework::OpKernel<T> {
...
@@ -55,7 +55,7 @@ class ResNetUnitKernel : public framework::OpKernel<T> {
int
padding
=
ctx
.
Attr
<
int
>
(
"padding"
);
int
padding
=
ctx
.
Attr
<
int
>
(
"padding"
);
int
stride
=
ctx
.
Attr
<
int
>
(
"stride"
);
int
stride
=
ctx
.
Attr
<
int
>
(
"stride"
);
int
stride_z
=
ctx
.
Attr
<
int
>
(
"stride_z"
);
int
stride_z
=
ctx
.
Attr
<
int
>
(
"stride_z"
);
int
dilat
e
=
ctx
.
Attr
<
int
>
(
"dilate
"
);
int
dilat
ion
=
ctx
.
Attr
<
int
>
(
"dilation
"
);
int
group
=
ctx
.
Attr
<
int
>
(
"group"
);
int
group
=
ctx
.
Attr
<
int
>
(
"group"
);
double
eps
=
static_cast
<
double
>
(
ctx
.
Attr
<
float
>
(
"epsilon"
));
double
eps
=
static_cast
<
double
>
(
ctx
.
Attr
<
float
>
(
"epsilon"
));
double
momentum
=
static_cast
<
double
>
(
ctx
.
Attr
<
float
>
(
"momentum"
));
double
momentum
=
static_cast
<
double
>
(
ctx
.
Attr
<
float
>
(
"momentum"
));
...
@@ -87,7 +87,7 @@ class ResNetUnitKernel : public framework::OpKernel<T> {
...
@@ -87,7 +87,7 @@ class ResNetUnitKernel : public framework::OpKernel<T> {
sum_x
.
Resize
(
param_dims
);
sum_x
.
Resize
(
param_dims
);
sum_of_squares_x
.
Resize
(
param_dims
);
sum_of_squares_x
.
Resize
(
param_dims
);
CudnnNormConvolution
<
T
>
conv_x_op
(
dev_ctx
,
input_x_shape
,
filter_x_shape
,
CudnnNormConvolution
<
T
>
conv_x_op
(
dev_ctx
,
input_x_shape
,
filter_x_shape
,
output_shape
,
padding
,
stride
,
dilat
e
,
output_shape
,
padding
,
stride
,
dilat
ion
,
group
);
group
);
conv_x_op
.
Forward
(
dev_ctx
,
*
input_x
,
*
filter_x
,
conv_out_x
,
&
sum_x
,
conv_x_op
.
Forward
(
dev_ctx
,
*
input_x
,
*
filter_x
,
conv_out_x
,
&
sum_x
,
&
sum_of_squares_x
);
&
sum_of_squares_x
);
...
@@ -129,8 +129,8 @@ class ResNetUnitKernel : public framework::OpKernel<T> {
...
@@ -129,8 +129,8 @@ class ResNetUnitKernel : public framework::OpKernel<T> {
sum_z
.
Resize
(
param_dims
);
sum_z
.
Resize
(
param_dims
);
sum_of_squares_z
.
Resize
(
param_dims
);
sum_of_squares_z
.
Resize
(
param_dims
);
CudnnNormConvolution
<
T
>
conv_z_op
(
dev_ctx
,
input_z_shape
,
filter_z_shape
,
CudnnNormConvolution
<
T
>
conv_z_op
(
dev_ctx
,
input_z_shape
,
filter_z_shape
,
output_shape
,
padding
,
stride_z
,
dilate
,
output_shape
,
padding
,
stride_z
,
group
);
dilation
,
group
);
conv_z_op
.
Forward
(
dev_ctx
,
*
input_z
,
*
filter_z
,
conv_out_z
,
&
sum_z
,
conv_z_op
.
Forward
(
dev_ctx
,
*
input_z
,
*
filter_z
,
conv_out_z
,
&
sum_z
,
&
sum_of_squares_z
);
&
sum_of_squares_z
);
...
@@ -189,7 +189,7 @@ class ResNetUnitGradKernel : public framework::OpKernel<T> {
...
@@ -189,7 +189,7 @@ class ResNetUnitGradKernel : public framework::OpKernel<T> {
int
padding
=
ctx
.
Attr
<
int
>
(
"padding"
);
int
padding
=
ctx
.
Attr
<
int
>
(
"padding"
);
int
stride
=
ctx
.
Attr
<
int
>
(
"stride"
);
int
stride
=
ctx
.
Attr
<
int
>
(
"stride"
);
int
stride_z
=
ctx
.
Attr
<
int
>
(
"stride_z"
);
int
stride_z
=
ctx
.
Attr
<
int
>
(
"stride_z"
);
int
dilat
e
=
ctx
.
Attr
<
int
>
(
"dilate
"
);
int
dilat
ion
=
ctx
.
Attr
<
int
>
(
"dilation
"
);
int
group
=
ctx
.
Attr
<
int
>
(
"group"
);
int
group
=
ctx
.
Attr
<
int
>
(
"group"
);
double
eps
=
static_cast
<
double
>
(
ctx
.
Attr
<
float
>
(
"epsilon"
));
double
eps
=
static_cast
<
double
>
(
ctx
.
Attr
<
float
>
(
"epsilon"
));
double
momentum
=
static_cast
<
double
>
(
ctx
.
Attr
<
float
>
(
"momentum"
));
double
momentum
=
static_cast
<
double
>
(
ctx
.
Attr
<
float
>
(
"momentum"
));
...
@@ -263,7 +263,7 @@ class ResNetUnitGradKernel : public framework::OpKernel<T> {
...
@@ -263,7 +263,7 @@ class ResNetUnitGradKernel : public framework::OpKernel<T> {
auto
filter_z_shape
=
framework
::
vectorize
<
int
>
(
filter_z
->
dims
());
auto
filter_z_shape
=
framework
::
vectorize
<
int
>
(
filter_z
->
dims
());
CudnnNormConvolutionGrad
<
T
>
conv_z_op
(
dev_ctx
,
z_shape
,
filter_z_shape
,
CudnnNormConvolutionGrad
<
T
>
conv_z_op
(
dev_ctx
,
z_shape
,
filter_z_shape
,
output_shape
,
padding
,
stride_z
,
output_shape
,
padding
,
stride_z
,
dilat
e
,
group
);
dilat
ion
,
group
);
conv_z_op
.
Backward
(
dev_ctx
,
*
z
,
*
filter_z
,
conv_out_z_grad
,
z_grad
,
conv_z_op
.
Backward
(
dev_ctx
,
*
z
,
*
filter_z
,
conv_out_z_grad
,
z_grad
,
filter_z_grad
);
filter_z_grad
);
}
else
{
}
else
{
...
@@ -278,11 +278,12 @@ class ResNetUnitGradKernel : public framework::OpKernel<T> {
...
@@ -278,11 +278,12 @@ class ResNetUnitGradKernel : public framework::OpKernel<T> {
}
}
// 2. Backward of Conv for x, get x_grad and filter_x_grad
// 2. Backward of Conv for x, get x_grad and filter_x_grad
bool
use_addto
=
ctx
.
Attr
<
bool
>
(
"use_addto"
);
CudnnNormConvolutionGrad
<
T
>
conv_x_op
(
dev_ctx
,
x_shape
,
filter_x_shape
,
CudnnNormConvolutionGrad
<
T
>
conv_x_op
(
dev_ctx
,
x_shape
,
filter_x_shape
,
output_shape
,
padding
,
stride
,
dilate
,
output_shape
,
padding
,
stride
,
group
);
dilation
,
group
);
conv_x_op
.
Backward
(
dev_ctx
,
*
x
,
*
filter_x
,
conv_out_x_grad
,
x_grad
,
conv_x_op
.
Backward
(
dev_ctx
,
*
x
,
*
filter_x
,
conv_out_x_grad
,
x_grad
,
filter_x_grad
);
filter_x_grad
,
use_addto
);
}
}
};
};
...
...
python/paddle/incubate/operators/__init__.py
浏览文件 @
12882b2f
...
@@ -14,3 +14,4 @@
...
@@ -14,3 +14,4 @@
from
.softmax_mask_fuse_upper_triangle
import
softmax_mask_fuse_upper_triangle
# noqa: F401
from
.softmax_mask_fuse_upper_triangle
import
softmax_mask_fuse_upper_triangle
# noqa: F401
from
.softmax_mask_fuse
import
softmax_mask_fuse
# noqa: F401
from
.softmax_mask_fuse
import
softmax_mask_fuse
# noqa: F401
from
.resnet_unit
import
ResNetUnit
#noqa: F401
python/paddle/incubate/operators/resnet_unit.py
0 → 100644
浏览文件 @
12882b2f
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
copy
import
collections
import
itertools
import
six
import
math
import
sys
import
warnings
from
functools
import
partial
,
reduce
import
numpy
as
np
import
paddle
import
paddle.fluid
as
fluid
from
paddle
import
framework
from
paddle.device
import
get_device
,
get_cudnn_version
from
paddle.nn
import
initializer
as
I
from
paddle.nn
import
Layer
,
LayerList
from
paddle.fluid.layers
import
utils
from
paddle.fluid.layer_helper
import
LayerHelper
from
paddle.fluid.layers.utils
import
map_structure
,
flatten
,
pack_sequence_as
from
paddle.fluid.data_feeder
import
convert_dtype
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle
import
_C_ops
__all__
=
[
'resnet_unit'
,
'ResNetUnit'
]
def
resnet_unit
(
x
,
filter_x
,
scale_x
,
bias_x
,
mean_x
,
var_x
,
z
,
filter_z
,
scale_z
,
bias_z
,
mean_z
,
var_z
,
stride
,
stride_z
,
padding
,
dilation
,
groups
,
momentum
,
eps
,
data_format
,
fuse_add
,
has_shortcut
,
use_global_stats
,
is_test
,
act
):
helper
=
LayerHelper
(
'resnet_unit'
,
**
locals
())
bn_param_dtype
=
fluid
.
core
.
VarDesc
.
VarType
.
FP32
bit_mask_dtype
=
fluid
.
core
.
VarDesc
.
VarType
.
INT32
out
=
helper
.
create_variable_for_type_inference
(
x
.
dtype
)
bit_mask
=
helper
.
create_variable_for_type_inference
(
dtype
=
bit_mask_dtype
,
stop_gradient
=
True
)
# intermediate_out for x
conv_x
=
helper
.
create_variable_for_type_inference
(
dtype
=
x
.
dtype
,
stop_gradient
=
True
)
saved_mean_x
=
helper
.
create_variable_for_type_inference
(
dtype
=
bn_param_dtype
,
stop_gradient
=
True
)
saved_invstd_x
=
helper
.
create_variable_for_type_inference
(
dtype
=
bn_param_dtype
,
stop_gradient
=
True
)
running_mean_x
=
mean_x
running_var_x
=
var_x
# intermediate_out for z
conv_z
=
helper
.
create_variable_for_type_inference
(
dtype
=
x
.
dtype
,
stop_gradient
=
True
)
saved_mean_z
=
helper
.
create_variable_for_type_inference
(
dtype
=
bn_param_dtype
,
stop_gradient
=
True
)
saved_invstd_z
=
helper
.
create_variable_for_type_inference
(
dtype
=
bn_param_dtype
,
stop_gradient
=
True
)
running_mean_z
=
helper
.
create_variable_for_type_inference
(
dtype
=
bn_param_dtype
,
stop_gradient
=
True
)
if
mean_z
is
None
else
mean_z
running_var_z
=
helper
.
create_variable_for_type_inference
(
dtype
=
bn_param_dtype
,
stop_gradient
=
True
)
if
var_z
is
None
else
var_z
inputs
=
{
'X'
:
x
,
'FilterX'
:
filter_x
,
'ScaleX'
:
scale_x
,
'BiasX'
:
bias_x
,
'MeanX'
:
mean_x
,
'VarX'
:
var_x
,
'Z'
:
z
,
'FilterZ'
:
filter_z
,
'ScaleZ'
:
scale_z
,
'BiasZ'
:
bias_z
,
'MeanZ'
:
mean_z
,
'VarZ'
:
var_z
}
attrs
=
{
'stride'
:
stride
,
'stride_z'
:
stride_z
,
'padding'
:
padding
,
'dilation'
:
dilation
,
'group'
:
groups
,
'momentum'
:
momentum
,
'epsilon'
:
eps
,
'data_format'
:
data_format
,
'fuse_add'
:
fuse_add
,
'has_shortcut'
:
has_shortcut
,
'use_global_stats'
:
use_global_stats
,
'is_test'
:
is_test
,
'act_type'
:
act
}
outputs
=
{
'Y'
:
out
,
'BitMask'
:
bit_mask
,
'ConvX'
:
conv_x
,
'SavedMeanX'
:
saved_mean_x
,
'SavedInvstdX'
:
saved_invstd_x
,
'RunningMeanX'
:
running_mean_x
,
'RunningVarX'
:
running_var_x
,
'ConvZ'
:
conv_z
,
'SavedMeanZ'
:
saved_mean_z
,
'SavedInvstdZ'
:
saved_invstd_z
,
'RunningMeanZ'
:
running_mean_z
,
'RunningVarZ'
:
running_var_z
,
}
helper
.
append_op
(
type
=
'resnet_unit'
,
inputs
=
inputs
,
outputs
=
outputs
,
attrs
=
attrs
)
return
out
class
ResNetUnit
(
Layer
):
r
"""
******Temporary version******.
ResNetUnit is designed for optimize the performence by using cudnnv8 API.
"""
def
__init__
(
self
,
num_channels_x
,
num_filters
,
filter_size
,
stride
=
1
,
momentum
=
0.9
,
eps
=
1e-5
,
data_format
=
'NHWC'
,
act
=
'relu'
,
fuse_add
=
False
,
has_shortcut
=
False
,
use_global_stats
=
False
,
is_test
=
False
,
filter_x_attr
=
None
,
scale_x_attr
=
None
,
bias_x_attr
=
None
,
moving_mean_x_name
=
None
,
moving_var_x_name
=
None
,
num_channels_z
=
1
,
stride_z
=
1
,
filter_z_attr
=
None
,
scale_z_attr
=
None
,
bias_z_attr
=
None
,
moving_mean_z_name
=
None
,
moving_var_z_name
=
None
):
super
(
ResNetUnit
,
self
).
__init__
()
self
.
_stride
=
stride
self
.
_stride_z
=
stride_z
self
.
_dilation
=
1
self
.
_kernel_size
=
utils
.
convert_to_list
(
filter_size
,
2
,
'kernel_size'
)
self
.
_padding
=
(
filter_size
-
1
)
//
2
self
.
_groups
=
1
self
.
_momentum
=
momentum
self
.
_eps
=
eps
self
.
_data_format
=
data_format
self
.
_act
=
act
self
.
_fuse_add
=
fuse_add
self
.
_has_shortcut
=
has_shortcut
self
.
_use_global_stats
=
use_global_stats
self
.
_is_test
=
is_test
# check format
valid_format
=
{
'NHWC'
}
if
data_format
not
in
valid_format
:
raise
ValueError
(
"conv_format must be one of {}, but got conv_format='{}'"
.
format
(
valid_format
,
data_format
))
def
_get_default_param_initializer
(
channels
):
filter_elem_num
=
np
.
prod
(
self
.
_kernel_size
)
*
channels
std
=
(
2.0
/
filter_elem_num
)
**
0.5
return
I
.
Normal
(
0.0
,
std
)
# initial filter
bn_param_dtype
=
fluid
.
core
.
VarDesc
.
VarType
.
FP32
bn_param_shape
=
[
1
,
1
,
1
,
num_filters
]
filter_x_shape
=
[
num_filters
,
filter_size
,
filter_size
,
num_channels_x
]
filter_z_shape
=
[
num_filters
,
filter_size
,
filter_size
,
num_channels_z
]
self
.
filter_x
=
self
.
create_parameter
(
shape
=
filter_x_shape
,
attr
=
filter_x_attr
,
default_initializer
=
_get_default_param_initializer
(
num_channels_x
))
self
.
scale_x
=
self
.
create_parameter
(
shape
=
bn_param_shape
,
attr
=
scale_x_attr
,
dtype
=
bn_param_dtype
,
default_initializer
=
I
.
Constant
(
1.0
))
self
.
bias_x
=
self
.
create_parameter
(
shape
=
bn_param_shape
,
attr
=
bias_x_attr
,
dtype
=
bn_param_dtype
,
is_bias
=
True
)
self
.
mean_x
=
self
.
create_parameter
(
attr
=
ParamAttr
(
name
=
moving_mean_x_name
,
initializer
=
I
.
Constant
(
0.0
),
trainable
=
False
),
shape
=
bn_param_shape
,
dtype
=
bn_param_dtype
)
self
.
mean_x
.
stop_gradient
=
True
self
.
var_x
=
self
.
create_parameter
(
attr
=
ParamAttr
(
name
=
moving_var_x_name
,
initializer
=
I
.
Constant
(
1.0
),
trainable
=
False
),
shape
=
bn_param_shape
,
dtype
=
bn_param_dtype
)
self
.
var_x
.
stop_gradient
=
True
if
has_shortcut
:
self
.
filter_z
=
self
.
create_parameter
(
shape
=
filter_z_shape
,
attr
=
filter_z_attr
,
default_initializer
=
_get_default_param_initializer
(
num_channels_z
))
self
.
scale_z
=
self
.
create_parameter
(
shape
=
bn_param_shape
,
attr
=
scale_z_attr
,
dtype
=
bn_param_dtype
,
default_initializer
=
I
.
Constant
(
1.0
))
self
.
bias_z
=
self
.
create_parameter
(
shape
=
bn_param_shape
,
attr
=
bias_z_attr
,
dtype
=
bn_param_dtype
,
is_bias
=
True
)
self
.
mean_z
=
self
.
create_parameter
(
attr
=
ParamAttr
(
name
=
moving_mean_z_name
,
initializer
=
I
.
Constant
(
0.0
),
trainable
=
False
),
shape
=
bn_param_shape
,
dtype
=
bn_param_dtype
)
self
.
mean_z
.
stop_gradient
=
True
self
.
var_z
=
self
.
create_parameter
(
attr
=
ParamAttr
(
name
=
moving_var_z_name
,
initializer
=
I
.
Constant
(
1.0
),
trainable
=
False
),
shape
=
bn_param_shape
,
dtype
=
bn_param_dtype
)
self
.
var_z
.
stop_gradient
=
True
else
:
self
.
filter_z
=
None
self
.
scale_z
=
None
self
.
bias_z
=
None
self
.
mean_z
=
None
self
.
var_z
=
None
def
forward
(
self
,
x
,
z
=
None
):
if
self
.
_fuse_add
and
z
is
None
:
raise
ValueError
(
"z can not be None"
)
out
=
resnet_unit
(
x
,
self
.
filter_x
,
self
.
scale_x
,
self
.
bias_x
,
self
.
mean_x
,
self
.
var_x
,
z
,
self
.
filter_z
,
self
.
scale_z
,
self
.
bias_z
,
self
.
mean_z
,
self
.
var_z
,
self
.
_stride
,
self
.
_stride_z
,
self
.
_padding
,
self
.
_dilation
,
self
.
_groups
,
self
.
_momentum
,
self
.
_eps
,
self
.
_data_format
,
self
.
_fuse_add
,
self
.
_has_shortcut
,
self
.
_use_global_stats
,
self
.
_is_test
,
self
.
_act
)
return
out
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录