提交 05816bdf 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!555 optimize mindwizard prompt hints

Merge pull request !555 from liangyongxiong/wizard
...@@ -9,8 +9,8 @@ ...@@ -9,8 +9,8 @@
* Web UI supports language internationalization, including both Chinese and English. * Web UI supports language internationalization, including both Chinese and English.
## Bugfixes ## Bugfixes
* Optimize UI page initialization to handle timeout requests. [!503](https://gitee.com/mindspore/mindinsight/pulls/503) * Optimize UI page initialization to handle timeout requests. ([!503](https://gitee.com/mindspore/mindinsight/pulls/503))
* Fix the line break problem when the profiling file number is too long. [532](https://gitee.com/mindspore/mindinsight/pulls/532) * Fix the line break problem when the profiling file number is too long. ([!532](https://gitee.com/mindspore/mindinsight/pulls/532))
## Thanks to our Contributors ## Thanks to our Contributors
Thanks goes to these wonderful people: Thanks goes to these wonderful people:
......
...@@ -28,7 +28,7 @@ def cli_entry(): ...@@ -28,7 +28,7 @@ def cli_entry():
os.umask(permissions << 3 | permissions) os.umask(permissions << 3 | permissions)
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
prog='wizard', prog='mindwizard',
description='MindWizard CLI entry point (version: {})'.format(mindinsight.__version__)) description='MindWizard CLI entry point (version: {})'.format(mindinsight.__version__))
parser.add_argument( parser.add_argument(
......
...@@ -41,7 +41,7 @@ if __name__ == "__main__": ...@@ -41,7 +41,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore AlexNet Example') parser = argparse.ArgumentParser(description='MindSpore AlexNet Example')
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute') parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
parser.add_argument('--device_num', type=int, default=1, help='Device num') parser.add_argument('--device_num', type=int, default=1, help='Device num')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'], parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)') help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--dataset_path', type=str, default="./", help='path where the dataset is saved') parser.add_argument('--dataset_path', type=str, default="./", help='path where the dataset is saved')
parser.add_argument('--pre_trained', type=str, default=None, help='Pre-trained checkpoint path') parser.add_argument('--pre_trained', type=str, default=None, help='Pre-trained checkpoint path')
......
...@@ -33,7 +33,7 @@ from src.lenet import LeNet5 ...@@ -33,7 +33,7 @@ from src.lenet import LeNet5
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore Lenet Example') parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'], parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)') help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--dataset_path', type=str, default="./Data", parser.add_argument('--dataset_path', type=str, default="./Data",
help='path where the dataset is saved') help='path where the dataset is saved')
......
...@@ -36,7 +36,7 @@ if __name__ == "__main__": ...@@ -36,7 +36,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore Lenet Example') parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute') parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
parser.add_argument('--device_num', type=int, default=1, help='Device num.') parser.add_argument('--device_num', type=int, default=1, help='Device num.')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'], parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)') help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--dataset_path', type=str, default="./Data", parser.add_argument('--dataset_path', type=str, default="./Data",
help='path where the dataset is saved') help='path where the dataset is saved')
...@@ -45,9 +45,6 @@ if __name__ == "__main__": ...@@ -45,9 +45,6 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
if args.device_target == "CPU":
args.dataset_sink = False
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
ckpt_save_dir = './' ckpt_save_dir = './'
if args.run_distribute: if args.run_distribute:
......
...@@ -35,7 +35,7 @@ parser.add_argument('--run_distribute', type=bool, default=False, help='Run dist ...@@ -35,7 +35,7 @@ parser.add_argument('--run_distribute', type=bool, default=False, help='Run dist
parser.add_argument('--device_num', type=int, default=1, help='Device num.') parser.add_argument('--device_num', type=int, default=1, help='Device num.')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target: "Ascend", "GPU", "CPU"') parser.add_argument('--device_target', type=str, default='Ascend', help='Device target: "Ascend", "GPU"')
parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path') parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
parser.add_argument('--dataset_sink_mode', type=str, default='True', choices = ['True', 'False'], parser.add_argument('--dataset_sink_mode', type=str, default='True', choices = ['True', 'False'],
help='DataSet sink mode is True or False') help='DataSet sink mode is True or False')
......
...@@ -59,7 +59,8 @@ class CreateProject(BaseCommand): ...@@ -59,7 +59,8 @@ class CreateProject(BaseCommand):
def _check_project_dir(project_name): def _check_project_dir(project_name):
"""Check project directory whether empty or exist.""" """Check project directory whether empty or exist."""
if not re.search('^[A-Za-z0-9][A-Za-z0-9._-]*$', project_name): if not re.search('^[A-Za-z0-9][A-Za-z0-9._-]*$', project_name):
raise CommandError("'%s' is not a valid project name. Please input a valid name" % project_name) raise CommandError("'%s' is not a valid project name. Please input a valid name matching "
"regex ^[A-Za-z0-9][A-Za-z0-9._-]*$" % project_name)
project_dir = os.path.join(os.getcwd(), project_name) project_dir = os.path.join(os.getcwd(), project_name)
if os.path.exists(project_dir): if os.path.exists(project_dir):
output_path = Path(project_dir) output_path = Path(project_dir)
...@@ -81,19 +82,23 @@ class CreateProject(BaseCommand): ...@@ -81,19 +82,23 @@ class CreateProject(BaseCommand):
'\n'.join(f'{idx: >4}: {choice}' for idx, choice in enumerate(network_type_choices, start=1)) '\n'.join(f'{idx: >4}: {choice}' for idx, choice in enumerate(network_type_choices, start=1))
) )
prompt_type = click.IntRange(min=1, max=len(network_type_choices)) prompt_type = click.IntRange(min=1, max=len(network_type_choices))
choice = click.prompt(prompt_msg, type=prompt_type, hide_input=False, show_choices=False, choice = 0
confirmation_prompt=False, while not choice:
choice = click.prompt(prompt_msg, default=0, type=prompt_type,
hide_input=False, show_choices=False,
confirmation_prompt=False, show_default=False,
value_proc=lambda x: process_prompt_choice(x, prompt_type)) value_proc=lambda x: process_prompt_choice(x, prompt_type))
if not choice:
click.secho(textwrap.dedent("Network is required."), fg='red')
return network_type_choices[choice - 1] return network_type_choices[choice - 1]
@staticmethod @staticmethod
def echo_notice(): def echo_notice():
"""Echo notice for depending environment.""" """Echo notice for depending environment."""
click.secho(textwrap.dedent(""" click.secho(textwrap.dedent(
[NOTICE] To ensure the final generated scripts run under specific environment with the following "[NOTICE] The final generated scripts should be run under environment "
"where mindspore==%s and related device drivers are installed. " % SUPPORT_MINDSPORE_VERSION), fg='yellow')
mindspore : %s
""" % SUPPORT_MINDSPORE_VERSION), fg='red')
def run(self, args): def run(self, args):
"""Override run method to start.""" """Override run method to start."""
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# ============================================================================ # ============================================================================
"""GenericNetwork module.""" """GenericNetwork module."""
import os import os
import textwrap
import click import click
...@@ -93,6 +94,7 @@ class GenericNetwork(BaseNetwork): ...@@ -93,6 +94,7 @@ class GenericNetwork(BaseNetwork):
choice = click.prompt(prompt_msg, type=prompt_type, hide_input=False, show_choices=False, choice = click.prompt(prompt_msg, type=prompt_type, hide_input=False, show_choices=False,
confirmation_prompt=False, default=default_choice, confirmation_prompt=False, default=default_choice,
value_proc=lambda x: process_prompt_choice(x, prompt_type)) value_proc=lambda x: process_prompt_choice(x, prompt_type))
click.secho(textwrap.dedent("Your choice is %s." % choice_contents[choice - 1]), fg='yellow')
return choice_contents[choice - 1] return choice_contents[choice - 1]
def ask_loss_function(self): def ask_loss_function(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册