机器未来 / Paddle (forked from PaddlePaddle / Paddle)
Commit 56ae33b6 (unverified)
Authored Jun 01, 2022 by YuanRisheng; committed via GitHub on Jun 01, 2022
Add yaml and unittest for instance_norm op (#43060)

* add yaml
* fix infrt compile bugs
Parent: b23914c2
Showing 16 changed files with 134 additions and 12 deletions (+134, -12)
paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py   +2  -1
paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py       +1  -1
paddle/phi/infermeta/backward.cc                                                +1  -1
paddle/phi/infermeta/backward.h                                                 +1  -1
paddle/phi/kernels/cpu/instance_norm_grad_kernel.cc                             +1  -1
paddle/phi/kernels/gpu/instance_norm_grad_kernel.cu                             +1  -1
paddle/phi/kernels/instance_norm_grad_kernel.h                                  +1  -1
paddle/phi/ops/compat/instance_norm_sig.cc                                      +1  -1
python/paddle/fluid/dygraph/nn.py                                               +5  -1
python/paddle/fluid/tests/unittests/test_instance_norm_op.py                    +9  -0
python/paddle/fluid/tests/unittests/test_instance_norm_op_v2.py                 +6  -0
python/paddle/fluid/tests/unittests/test_norm_nn_grad.py                        +66 -0
python/paddle/nn/functional/norm.py                                             +4  -2
python/paddle/utils/code_gen/api.yaml                                           +11 -0
python/paddle/utils/code_gen/backward.yaml                                      +23 -0
tools/infrt/skipped_phi_api.json                                                +1  -1
paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py
@@ -31,7 +31,8 @@ ops_to_fill_zero_for_empty_grads = set([
     "leaky_relu_double_grad", "sqrt_double_grad", "rsqrt_double_grad",
     "square_double_grad", "celu_double_grad", "pad_double_grad",
     "pad3d_double_grad", "squeeze_double_grad", "unsqueeze_double_grad",
-    "conv3d_double_grad", "depthwise_conv2d_grad_grad"
+    "instance_norm_double_grad", "conv3d_double_grad",
+    "depthwise_conv2d_grad_grad"
 ])
 # For API dispatch used at python-level
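For context, ops_to_fill_zero_for_empty_grads controls which generated eager backward nodes replace empty incoming grads with zero tensors. A minimal standalone sketch of that lookup (hypothetical helper, not the actual codegen):

    # Sketch only: mirrors how the generated code consults the set.
    ops_to_fill_zero_for_empty_grads = {
        "instance_norm_double_grad",  # added by this commit
        "conv3d_double_grad",
        "depthwise_conv2d_grad_grad",
    }

    def needs_zero_fill(op_name):
        # Ops listed here get empty incoming grads replaced by zero tensors
        # instead of being forwarded as uninitialized outputs.
        return op_name in ops_to_fill_zero_for_empty_grads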
paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
@@ -1404,7 +1404,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
   const auto& out_metas = OutputMeta();
   paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallVectorSize> returns({slot_num_bwd_outputs});
   for (int i = 0; i < {slot_num_bwd_outputs}; ++i) {{
-    returns[i].resize(out_metas[i].size());
+    out_metas[i].size() == 0 ? returns[i].resize(1) : returns[i].resize(out_metas[i].size());
   }}
"""
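The generated C++ above now reserves one slot even when an output's meta list is empty, so a zero-filled grad has somewhere to land. A rough Python paraphrase of that loop (illustrative only):

    def build_returns(out_metas):
        returns = []
        for metas in out_metas:
            # An empty meta list previously produced a zero-length slot; it is
            # now resized to 1 so optional double-grad outputs can be written.
            returns.append([None] * (len(metas) if metas else 1))
        return returns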
paddle/phi/infermeta/backward.cc
@@ -313,10 +313,10 @@ void GumbelSoftmaxGradInferMeta(const MetaTensor& out,
 }

 void InstanceNormGradInferMeta(const MetaTensor& x,
-                               const MetaTensor& y_grad,
                                const MetaTensor& scale,
                                const MetaTensor& saved_mean,
                                const MetaTensor& saved_variance,
+                               const MetaTensor& y_grad,
                                float epsilon,
                                MetaTensor* x_grad,
                                MetaTensor* scale_grad,
paddle/phi/infermeta/backward.h
@@ -145,10 +145,10 @@ void GumbelSoftmaxGradInferMeta(const MetaTensor& out,
                                 MetaTensor* dx);

 void InstanceNormGradInferMeta(const MetaTensor& x,
-                               const MetaTensor& y_grad,
                                const MetaTensor& scale,
                                const MetaTensor& saved_mean,
                                const MetaTensor& saved_variance,
+                               const MetaTensor& y_grad,
                                float epsilon,
                                MetaTensor* x_grad,
                                MetaTensor* scale_grad,
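Both declarations move y_grad after saved_variance so the C++ parameter order matches the args order of the new instance_norm_grad yaml entry below. As a reminder of what this InferMeta computes, a hedged Python paraphrase (shape propagation only; the real function also sets dtype and layout):

    def instance_norm_grad_infer_meta(x_shape, scale_shape):
        # Gradients take the shapes of the tensors they correspond to:
        # d(x) matches x; d(scale) and d(bias) match scale (one per channel).
        return list(x_shape), list(scale_shape), list(scale_shape)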
paddle/phi/kernels/cpu/instance_norm_grad_kernel.cc
@@ -42,10 +42,10 @@ using EigenVectorArrayMap = Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>>;
 template <typename T, typename Context>
 void InstanceNormGradKernel(const Context& dev_ctx,
                             const DenseTensor& x,
-                            const DenseTensor& d_y,
                             const paddle::optional<DenseTensor>& scale,
                             const DenseTensor& saved_mean,
                             const DenseTensor& saved_variance,
+                            const DenseTensor& d_y,
                             float epsilon,
                             DenseTensor* d_x,
                             DenseTensor* d_scale,
paddle/phi/kernels/gpu/instance_norm_grad_kernel.cu
@@ -290,10 +290,10 @@ __global__ void DoubleGradComputeDScale(const T *x,
 template <typename T, typename Context>
 void InstanceNormGradKernel(const Context& dev_ctx,
                             const DenseTensor& x,
-                            const DenseTensor& d_y,
                             const paddle::optional<DenseTensor>& scale,
                             const DenseTensor& saved_mean,
                             const DenseTensor& saved_variance,
+                            const DenseTensor& d_y,
                             float epsilon_f,
                             DenseTensor* d_x,
                             DenseTensor* d_scale,
paddle/phi/kernels/instance_norm_grad_kernel.h
@@ -21,10 +21,10 @@ namespace phi {
 template <typename T, typename Context>
 void InstanceNormGradKernel(const Context& dev_ctx,
                             const DenseTensor& x,
-                            const DenseTensor& y_grad,
                             const paddle::optional<DenseTensor>& scale,
                             const DenseTensor& saved_mean,
                             const DenseTensor& saved_variance,
+                            const DenseTensor& y_grad,
                             float epsilon,
                             DenseTensor* x_grad,
                             DenseTensor* scale_grad,
paddle/phi/ops/compat/instance_norm_sig.cc
@@ -27,7 +27,7 @@ KernelSignature InstanceNormOpArgumentMapping(
 KernelSignature InstanceNormGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
   return KernelSignature("instance_norm_grad",
-                         {"X", "Y@GRAD", "Scale", "SavedMean", "SavedVariance"},
+                         {"X", "Scale", "SavedMean", "SavedVariance", "Y@GRAD"},
                          {"epsilon"},
                          {"X@GRAD", "Scale@GRAD", "Bias@GRAD"});
 }
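The reorder keeps the fluid-era input names aligned, position by position, with the phi kernel's new parameter order. A small sketch of that correspondence (names taken from the signature above):

    fluid_inputs = ["X", "Scale", "SavedMean", "SavedVariance", "Y@GRAD"]
    phi_params   = ["x", "scale", "saved_mean", "saved_variance", "y_grad"]
    # Each legacy input now feeds the kernel parameter at the same index.
    for legacy, modern in zip(fluid_inputs, phi_params):
        print(f"{legacy:14s} -> {modern}")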
python/paddle/fluid/dygraph/nn.py
@@ -1137,7 +1137,11 @@ class InstanceNorm(layers.Layer):
                 self.bias = None

     def forward(self, input):
-        if _non_static_mode():
+        if in_dygraph_mode():
+            out, _, _, = _C_ops.final_state_instance_norm(
+                input, self.scale, self.bias, self._epsilon)
+            return out
+        if _in_legacy_dygraph():
             out, _, _ = _C_ops.instance_norm(input, self.scale, self.bias,
                                              'epsilon', self._epsilon)
             return out
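A minimal eager-mode usage sketch that reaches the new final-state branch (assumes a build containing this commit; shapes are illustrative):

    import numpy as np
    import paddle

    x = paddle.to_tensor(np.random.rand(2, 3, 4, 5).astype("float32"))
    layer = paddle.nn.InstanceNorm2D(3)   # 3 channels
    y = layer(x)                          # in eager mode, hits final_state_instance_norm
    print(y.shape)                        # [2, 3, 4, 5]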
python/paddle/fluid/tests/unittests/test_instance_norm_op.py
@@ -22,6 +22,7 @@ from paddle.fluid.op import Operator
 from op_test import OpTest
 from paddle.fluid import Program, program_guard
 from paddle.fluid.dygraph import to_variable
+from paddle.fluid.framework import _test_eager_guard


 def _reference_instance_norm_naive(x, scale, bias, epsilon, mean, var):
@@ -253,6 +254,10 @@ class TestElasticNormOp(unittest.TestCase):
             outputs = instance_norm(to_variable(inputs))
             self.assertTrue(np.allclose(outputs.numpy(), out_np, atol=1e-6))

+    def test_eager_api(self):
+        with _test_eager_guard():
+            self.test_norm()
+

 class TestElasticNormOpCase2(unittest.TestCase):
     def init_test_case(self):
@@ -282,6 +287,10 @@ class TestElasticNormOpCase2(unittest.TestCase):
             outputs = instance_norm(to_variable(inputs))
             self.assertTrue(np.allclose(outputs.numpy(), out_np, atol=1e-6))

+    def test_eager_api(self):
+        with _test_eager_guard():
+            self.test_norm()
+

 if __name__ == '__main__':
     unittest.main()
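The new test_eager_api methods follow a common pattern in this suite: rerun an existing dygraph test under _test_eager_guard so both the legacy and final-state paths are exercised. In isolation (a sketch; run_on_both_paths is a hypothetical helper, _test_eager_guard is the internal fluid utility used above):

    from paddle.fluid.framework import _test_eager_guard

    def run_on_both_paths(test_fn):
        test_fn()                  # legacy dygraph path
        with _test_eager_guard():
            test_fn()              # final-state (eager) path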
python/paddle/fluid/tests/unittests/test_instance_norm_op_v2.py
@@ -22,6 +22,7 @@ from op_test import OpTest, _set_use_system_allocator
 from paddle.fluid.framework import grad_var_name
 import paddle.fluid as fluid
 from paddle.fluid import Program, program_guard
+from paddle.fluid.framework import _test_eager_guard
 import paddle
@@ -116,6 +117,11 @@ class TestInstanceNorm(unittest.TestCase):
             y2 = compute_v2(x)
             self.assertTrue(np.allclose(y1, y2))

+    def test_eager_api(self):
+        with _test_eager_guard():
+            self.test_dygraph()
+            self.test_error()
+

 if __name__ == '__main__':
     unittest.main()
python/paddle/fluid/tests/unittests/test_norm_nn_grad.py
@@ -70,6 +70,72 @@ class TestInstanceNormDoubleGradCheckWithoutParamBias(
             [x], z, x_init=x_arr, atol=atol, place=place, eps=eps)


+class TestInstanceNormDoubleGradEagerCheck(unittest.TestCase):
+    def instance_norm_wrapper(self, x):
+        return paddle.nn.functional.instance_norm(x[0])
+
+    @prog_scope()
+    def func(self, place):
+        prog = fluid.Program()
+        with fluid.program_guard(prog):
+            np.random.seed()
+            shape = [2, 3, 4, 5]
+            dtype = "float32"
+            eps = 0.005
+            atol = 1e-4
+            x = layers.create_parameter(dtype=dtype, shape=shape, name='x')
+            z = paddle.nn.functional.instance_norm(x)
+            x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
+            # check for static mode
+            gradient_checker.double_grad_check(
+                [x], z, x_init=x_arr, atol=atol, place=place, eps=eps)
+            # check for eager mode
+            gradient_checker.double_grad_check_for_dygraph(
+                self.instance_norm_wrapper, [x],
+                z,
+                x_init=x_arr,
+                atol=atol,
+                place=place)
+
+    def test_grad(self):
+        paddle.enable_static()
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for p in places:
+            self.func(p)
+
+
+class TestInstanceNormDoubleGradEagerCheckWithParams(
+        TestInstanceNormDoubleGradEagerCheck):
+    def instance_norm_wrapper(self, x):
+        instance_norm = paddle.nn.InstanceNorm2D(3)
+        return instance_norm(x[0])
+
+    @prog_scope()
+    def func(self, place):
+        prog = fluid.Program()
+        with fluid.program_guard(prog):
+            np.random.seed()
+            shape = [2, 3, 4, 5]
+            dtype = "float32"
+            eps = 0.005
+            atol = 1e-4
+            x = layers.create_parameter(dtype=dtype, shape=shape, name='x')
+            z = paddle.nn.InstanceNorm2D(3)(x)
+            x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
+            # check for static mode
+            gradient_checker.double_grad_check(
+                [x], z, x_init=x_arr, atol=atol, place=place, eps=eps)
+            # check for eager mode
+            gradient_checker.double_grad_check_for_dygraph(
+                self.instance_norm_wrapper, [x],
+                z,
+                x_init=x_arr,
+                atol=atol,
+                place=place)
+
+
 class TestBatchNormDoubleGradCheck(unittest.TestCase):
     def setUp(self):
         self.init_test()
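For intuition, double_grad_check compares analytic gradients against central finite differences; the eps=0.005 above is the perturbation step and atol=1e-4 the comparison tolerance. A toy version of the numeric side (hypothetical helper for a scalar-valued f, not the gradient_checker internals):

    import numpy as np

    def numeric_grad(f, x, eps=0.005):
        # Central difference: (f(x+eps) - f(x-eps)) / (2*eps), per element.
        grad = np.zeros_like(x)
        it = np.nditer(x, flags=["multi_index"])
        while not it.finished:
            idx = it.multi_index
            orig = x[idx]
            x[idx] = orig + eps
            f_pos = f(x)
            x[idx] = orig - eps
            f_neg = f(x)
            x[idx] = orig
            grad[idx] = (f_pos - f_neg) / (2 * eps)
            it.iternext()
        return grad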
python/paddle/nn/functional/norm.py
@@ -407,8 +407,10 @@ def instance_norm(x,
             print(instance_norm_out)

     """
-    if in_dynamic_mode():
+    if in_dygraph_mode():
+        out, _, _, = _C_ops.final_state_instance_norm(x, weight, bias, eps)
+        return out
+    if _in_legacy_dygraph():
         out, _, _ = _C_ops.instance_norm(x, weight, bias, "epsilon", eps,
                                          "momentum", momentum, "data_format",
                                          data_format)
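A minimal call through the functional API touched above, with explicit affine parameters (a sketch; the eps value shown is the documented default):

    import paddle

    x = paddle.rand([2, 3, 4, 5])
    w = paddle.ones([3])    # per-channel scale
    b = paddle.zeros([3])   # per-channel shift
    y = paddle.nn.functional.instance_norm(x, weight=w, bias=b, eps=1e-5)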
python/paddle/utils/code_gen/api.yaml
@@ -1030,6 +1030,17 @@
     data_type : x
   backward : index_select_grad

+- api : instance_norm
+  args : (Tensor x, Tensor scale, Tensor bias, float epsilon)
+  output : Tensor(y), Tensor(saved_mean), Tensor(saved_variance)
+  infer_meta :
+    func : InstanceNormInferMeta
+  kernel :
+    func : instance_norm
+    data_type : x
+  optional : scale, bias
+  backward : instance_norm_grad
+
 # is_empty
 - api : is_empty
   args : (Tensor x)
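The yaml entry's outputs match the op's definition: y is x normalized per sample and per channel over the spatial dims, with saved_mean/saved_variance retained for the backward pass. A hedged NumPy cross-check of the forward result:

    import numpy as np
    import paddle

    x_np = np.random.rand(2, 3, 4, 5).astype("float32")
    mean = x_np.mean(axis=(2, 3), keepdims=True)   # per (sample, channel)
    var = x_np.var(axis=(2, 3), keepdims=True)
    ref = (x_np - mean) / np.sqrt(var + 1e-5)      # epsilon = 1e-5

    out = paddle.nn.functional.instance_norm(paddle.to_tensor(x_np))
    np.testing.assert_allclose(out.numpy(), ref, rtol=1e-5, atol=1e-5)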
python/paddle/utils/code_gen/backward.yaml
@@ -927,6 +927,29 @@
     data_type : x
   no_need_buffer : x

+- backward_api : instance_norm_double_grad
+  forward : instance_norm_grad(Tensor x, Tensor fwd_scale, Tensor saved_mean, Tensor saved_variance, Tensor grad_y, float epsilon) -> Tensor(grad_x), Tensor(grad_scale), Tensor(grad_bias)
+  args : (Tensor x, Tensor fwd_scale, Tensor saved_mean, Tensor saved_variance, Tensor grad_y, Tensor grad_x_grad, Tensor grad_scale_grad, Tensor grad_bias_grad, float epsilon)
+  output : Tensor(x_grad), Tensor(fwd_scale_grad), Tensor(grad_y_grad)
+  infer_meta :
+    func : InstanceNormDoubleGradInferMeta
+  kernel :
+    func : instance_norm_double_grad
+    data_type : x
+  optional : fwd_scale, grad_x_grad, grad_scale_grad, grad_bias_grad
+
+- backward_api : instance_norm_grad
+  forward : instance_norm(Tensor x, Tensor scale, Tensor bias, float epsilon) -> Tensor(y), Tensor(saved_mean), Tensor(saved_variance)
+  args : (Tensor x, Tensor scale, Tensor saved_mean, Tensor saved_variance, Tensor y_grad, float epsilon)
+  output : Tensor(x_grad), Tensor(scale_grad), Tensor(bias_grad)
+  infer_meta :
+    func : InstanceNormGradInferMeta
+  kernel :
+    func : instance_norm_grad
+    data_type : x
+  optional : scale
+  backward : instance_norm_double_grad
+
 - backward_api : kldiv_loss_grad
   forward : kldiv_loss(Tensor x, Tensor label, str reduction) -> Tensor(out)
   args : (Tensor x, Tensor label, Tensor out_grad, str reduction)
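With instance_norm_grad pointing at instance_norm_double_grad via its backward field, second-order differentiation now works in eager mode; nothing beyond second order is wired up. A quick sketch of the chain:

    import paddle

    x = paddle.rand([2, 3, 4, 5])
    x.stop_gradient = False
    y = paddle.nn.functional.instance_norm(x)
    (dx,) = paddle.grad(y, x, create_graph=True)   # uses instance_norm_grad
    (ddx,) = paddle.grad(dx.sum(), x)              # uses instance_norm_double_grad
    print(ddx.shape)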
tools/infrt/skipped_phi_api.json
-{"phi_apis":["conj", "deformable_conv", "dropout", "expand_as", "nll_loss", "psroi_pool", "roi_align", "roi_pool", "label_smooth", "layer_norm"],
+{"phi_apis":["conj", "deformable_conv", "dropout", "expand_as", "nll_loss", "psroi_pool", "roi_align", "roi_pool", "label_smooth", "layer_norm", "instance_norm"],
 "phi_kernels":["equal_all"]}
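skipped_phi_api.json tells the infrt tooling which phi APIs to leave out of its code generation; adding instance_norm here is what addresses the infrt compile error mentioned in the commit message. A sketch of how such a skip list can be consumed (hypothetical reader, not infrt's actual code):

    import json

    with open("tools/infrt/skipped_phi_api.json") as f:
        skipped = set(json.load(f)["phi_apis"])

    # all_phi_apis is a hypothetical stand-in for the generator's full API list.
    all_phi_apis = ["conj", "instance_norm", "matmul"]
    apis_to_generate = [api for api in all_phi_apis if api not in skipped]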