提交 c8e0ed60 编写于 作者: A Aston Zhang

simplify _get_batch

上级 d0442473
......@@ -134,7 +134,7 @@ print('output:', splitted)
```{.python .input n=6}
def train_batch(X, y, gpu_params, ctx, lr):
# 划分小批量数据样本并复制到各个 GPU 上。
# 当 ctx 包含多个GPU时,划分小批量数据样本并复制到各个 GPU 上。
gpu_Xs = split_and_load(X, ctx)
gpu_ys = split_and_load(y, ctx)
# 在各个 GPU 上计算损失。
......
......@@ -111,11 +111,7 @@ def evaluate_accuracy(data_iter, net, ctx=[mx.cpu()]):
def _get_batch(batch, ctx):
"""Return features and labels on ctx."""
if isinstance(batch, mx.io.DataBatch):
features = batch.data[0]
labels = batch.label[0]
else:
features, labels = batch
features, labels = batch
if labels.dtype != features.dtype:
labels = labels.astype(features.dtype)
return (gutils.split_and_load(features, ctx),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册