提交 703306ce 编写于 作者: L liuyuhui

fix Collective multi dataset_name

上级 70b6bbe8
...@@ -363,6 +363,7 @@ class CollectiveNetwork(NetworkBase): ...@@ -363,6 +363,7 @@ class CollectiveNetwork(NetworkBase):
def build_network(self, context): def build_network(self, context):
context["model"] = {} context["model"] = {}
if len(context["env"]["phase"]) > 1: if len(context["env"]["phase"]) > 1:
print("CollectiveNetwork phase:{}".format(context["env"]["phase"]))
warnings.warn( warnings.warn(
"Cluster Train Only Support One Phase.", "Cluster Train Only Support One Phase.",
category=UserWarning, category=UserWarning,
...@@ -407,16 +408,17 @@ class CollectiveNetwork(NetworkBase): ...@@ -407,16 +408,17 @@ class CollectiveNetwork(NetworkBase):
context["model"][model_dict["name"]]["compiled_program"] = None context["model"][model_dict["name"]]["compiled_program"] = None
context["dataset"] = {} context["dataset"] = {}
for dataset in context["env"]["dataset"]: for phase in context["env"]["phase"]:
type = envs.get_global_env("dataset." + dataset["name"] + ".type") type = envs.get_global_env("dataset." + phase["dataset_name"] +
".type")
if type == "QueueDataset": if type == "QueueDataset":
raise ValueError( raise ValueError(
"Collective don't support QueueDataset training, please use DataLoader." "Collective don't support QueueDataset training, please use DataLoader."
) )
dataset_class = QueueDataset(context) dataset_class = QueueDataset(context)
context["dataset"][dataset[ context["dataset"][phase[
"name"]] = dataset_class.create_dataset(dataset["name"], "dataset_name"]] = dataset_class.create_dataset(
context) phase["dataset_name"], context)
context["status"] = "startup_pass" context["status"] = "startup_pass"
def _build_strategy(self, context): def _build_strategy(self, context):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册