未验证 提交 10f12f25 编写于 作者: 1 123malin 提交者: GitHub

Merge branch 'master' into gnn

......@@ -99,7 +99,8 @@ class SingleNetwork(NetworkBase):
context["dataset"] = {}
for dataset in context["env"]["dataset"]:
type = envs.get_global_env("dataset." + dataset["name"] + ".type")
if type != "DataLoader":
if type == "QueueDataset":
dataset_class = QueueDataset(context)
context["dataset"][dataset[
"name"]] = dataset_class.create_dataset(dataset["name"],
......@@ -133,9 +134,7 @@ class PSNetwork(NetworkBase):
if envs.get_global_env("dataset." + dataset_name +
".type") == "DataLoader":
model._init_dataloader(is_infer=False)
data_loader = DataLoader(context)
data_loader.get_dataloader(context, dataset_name,
model._data_loader)
model.net(model._data_var, False)
optimizer = model.optimizer()
strategy = self._build_strategy(context)
......@@ -160,7 +159,11 @@ class PSNetwork(NetworkBase):
for dataset in context["env"]["dataset"]:
type = envs.get_global_env("dataset." + dataset["name"] +
".type")
if type != "DataLoader":
if type == "DataLoader":
data_loader = DataLoader(context)
data_loader.get_dataloader(context, dataset_name,
model._data_loader)
elif type == "QueueDataset":
dataset_class = QueueDataset(context)
context["dataset"][dataset[
"name"]] = dataset_class.create_dataset(
......@@ -229,9 +232,6 @@ class PslibNetwork(NetworkBase):
if envs.get_global_env("dataset." + dataset_name +
".type") == "DataLoader":
model._init_dataloader(is_infer=False)
data_loader = DataLoader(context)
data_loader.get_dataloader(context, dataset_name,
model._data_loader)
model.net(model._data_var, False)
optimizer = model.optimizer()
......@@ -257,7 +257,11 @@ class PslibNetwork(NetworkBase):
for dataset in context["env"]["dataset"]:
type = envs.get_global_env("dataset." + dataset["name"] +
".type")
if type != "DataLoader":
if type == "DataLoader":
data_loader = DataLoader(context)
data_loader.get_dataloader(context, dataset_name, context[
"model"][model_dict["name"]]["model"]._data_loader)
elif type == "QueueDataset":
dataset_class = QueueDataset(context)
context["dataset"][dataset[
"name"]] = dataset_class.create_dataset(
......@@ -323,7 +327,10 @@ class CollectiveNetwork(NetworkBase):
context["dataset"] = {}
for dataset in context["env"]["dataset"]:
type = envs.get_global_env("dataset." + dataset["name"] + ".type")
if type != "DataLoader":
if type == "QueueDataset":
raise ValueError(
"Collective don't support QueueDataset training, please use DataLoader."
)
dataset_class = QueueDataset(context)
context["dataset"][dataset[
"name"]] = dataset_class.create_dataset(dataset["name"],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册