未验证 提交 e6d20888 编写于 作者: Z Zth9730 提交者: GitHub

支持0维Tensor需要的修改 (#2621)

上级 114f3802
......@@ -108,7 +108,7 @@ for epoch in range(1, epochs + 1):
optimizer.clear_grad()
# Calculate loss
avg_loss = loss.numpy()[0]
avg_loss = float(loss)
# Calculate metrics
preds = paddle.argmax(logits, axis=1)
......
......@@ -509,7 +509,7 @@
" optimizer.clear_grad()\n",
"\n",
" # Calculate loss\n",
" avg_loss += loss.numpy()[0]\n",
" avg_loss += float(loss)\n",
"\n",
" # Calculate metrics\n",
" preds = paddle.argmax(logits, axis=1)\n",
......
......@@ -101,7 +101,7 @@ if __name__ == "__main__":
optimizer.clear_grad()
# Calculate loss
avg_loss += loss.numpy()[0]
avg_loss += float(loss)
# Calculate metrics
preds = paddle.argmax(logits, axis=1)
......
......@@ -110,7 +110,7 @@ if __name__ == '__main__':
optimizer.clear_grad()
# Calculate loss
avg_loss += loss.numpy()[0]
avg_loss += float(loss)
# Calculate metrics
num_corrects += corrects
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import paddle
import torch
from paddle.device.cuda import synchronize
from parallel_wavegan.layers import residual_block
from parallel_wavegan.layers import upsample
from parallel_wavegan.models import parallel_wavegan as pwgan
......@@ -24,7 +25,6 @@ from paddlespeech.t2s.models.parallel_wavegan import PWGGenerator
from paddlespeech.t2s.models.parallel_wavegan import ResidualBlock
from paddlespeech.t2s.models.parallel_wavegan import ResidualPWGDiscriminator
from paddlespeech.t2s.utils.layer_tools import summary
from paddlespeech.t2s.utils.profile import synchronize
paddle.set_device("gpu:0")
device = torch.device("cuda:0")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册