Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
1aa2bde0
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1aa2bde0
编写于
8月 20, 2021
作者:
S
shangliang Xu
提交者:
GitHub
8月 20, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[bug fix] fix spectral_norm bug (#35005)
上级
096b0f2e
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
10 addition
and
0 deletion
+10
-0
python/paddle/fluid/dygraph/nn.py
python/paddle/fluid/dygraph/nn.py
+6
-0
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+4
-0
未找到文件。
python/paddle/fluid/dygraph/nn.py
浏览文件 @
1aa2bde0
...
@@ -3062,6 +3062,12 @@ class SpectralNorm(layers.Layer):
...
@@ -3062,6 +3062,12 @@ class SpectralNorm(layers.Layer):
self
.
_dtype
=
dtype
self
.
_dtype
=
dtype
self
.
_weight_shape
=
list
(
weight_shape
)
self
.
_weight_shape
=
list
(
weight_shape
)
assert
np
.
prod
(
self
.
_weight_shape
)
>
0
,
\
"Any dimension of `weight_shape` cannot be equal to 0."
assert
dim
<
len
(
self
.
_weight_shape
),
\
(
"The input `dim` should be less than the "
"length of `weight_shape`, but received dim="
"{}"
.
format
(
dim
))
h
=
self
.
_weight_shape
[
self
.
_dim
]
h
=
self
.
_weight_shape
[
self
.
_dim
]
w
=
np
.
prod
(
self
.
_weight_shape
)
//
h
w
=
np
.
prod
(
self
.
_weight_shape
)
//
h
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
1aa2bde0
...
@@ -3720,6 +3720,10 @@ def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None):
...
@@ -3720,6 +3720,10 @@ def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None):
# create intput and parameters
# create intput and parameters
inputs = {'Weight': weight}
inputs = {'Weight': weight}
input_shape = weight.shape
input_shape = weight.shape
assert weight.numel() > 0, "Any dimension of input cannot be equal to 0."
assert dim < len(input_shape), ("The input `dim` should be less than the "
"rank of `weight`, but received dim="
"{}".format(dim))
h = input_shape[dim]
h = input_shape[dim]
w = np.prod(input_shape) // h
w = np.prod(input_shape) // h
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录