提交 8b2b59ad 编写于 作者: A Aston Zhang

fix gram matrix in style transfer

上级 d41eea0c
......@@ -123,7 +123,7 @@ def content_loss(y_hat, y):
def gram(x):
c, n = x.shape[1], x.size // x.shape[1]
y = x.reshape((c, n))
return nd.dot(y, y.T) / n
return nd.dot(y, y.T) / (c * n)
```
和对应的损失函数,这里假设样式图像的样式特征协方差已经预先计算好了。
......@@ -147,7 +147,7 @@ def tv_loss(y_hat):
```{.python .input n=12}
style_channels = [net[l].weight.shape[0] for l in style_layers]
style_weights = [1e4 / c**2 for c in style_channels]
style_weights = [1e4] * len(style_channels)
content_weights, tv_weight = [1], 10
```
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册