From 1e91f7da354c61e7b12740240524e0ce0330e689 Mon Sep 17 00:00:00 2001 From: Zhangjingyu06 Date: Tue, 24 May 2022 10:10:38 +0000 Subject: [PATCH] deepspeech2 modify for kunlun --- paddlespeech/s2t/exps/deepspeech2/bin/export.py | 5 +++++ paddlespeech/s2t/exps/deepspeech2/bin/test.py | 5 +++++ paddlespeech/s2t/exps/deepspeech2/bin/test_export.py | 5 +++++ paddlespeech/s2t/exps/deepspeech2/bin/train.py | 5 +++++ paddlespeech/s2t/training/trainer.py | 8 +++++++- 5 files changed, 27 insertions(+), 1 deletion(-) diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/export.py b/paddlespeech/s2t/exps/deepspeech2/bin/export.py index ee013d79..ae43bf82 100644 --- a/paddlespeech/s2t/exps/deepspeech2/bin/export.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/export.py @@ -37,6 +37,11 @@ if __name__ == "__main__": "--export_path", type=str, help="path of the jit model to save") parser.add_argument( "--model_type", type=str, default='offline', help="offline/online") + parser.add_argument( + '--nxpu', + type=int, + default=1, + help="if nxpu == 0 and ngpu == 0, use cpu.") args = parser.parse_args() print("model_type:{}".format(args.model_type)) print_arguments(args) diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/test.py b/paddlespeech/s2t/exps/deepspeech2/bin/test.py index 388b380d..f29f5083 100644 --- a/paddlespeech/s2t/exps/deepspeech2/bin/test.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/test.py @@ -37,6 +37,11 @@ if __name__ == "__main__": # save asr result to parser.add_argument( "--result_file", type=str, help="path of save the asr result") + parser.add_argument( + '--nxpu', + type=int, + default=1, + help="if nxpu == 0 and ngpu == 0, use cpu.") args = parser.parse_args() print_arguments(args, globals()) print("model_type:{}".format(args.model_type)) diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/test_export.py b/paddlespeech/s2t/exps/deepspeech2/bin/test_export.py index 707eb9e1..c136ddf2 100644 --- a/paddlespeech/s2t/exps/deepspeech2/bin/test_export.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/test_export.py @@ -40,6 +40,11 @@ if __name__ == "__main__": "--export_path", type=str, help="path of the jit model to save") parser.add_argument( "--model_type", type=str, default='offline', help='offline/online') + parser.add_argument( + '--nxpu', + type=int, + default=1, + help="if nxpu == 0 and ngpu == 0, use cpu.") parser.add_argument( "--enable-auto-log", action="store_true", help="use auto log") args = parser.parse_args() diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/train.py b/paddlespeech/s2t/exps/deepspeech2/bin/train.py index e2c68d4b..cb4867ef 100644 --- a/paddlespeech/s2t/exps/deepspeech2/bin/train.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/train.py @@ -33,6 +33,11 @@ if __name__ == "__main__": parser = default_argument_parser() parser.add_argument( "--model_type", type=str, default='offline', help='offline/online') + parser.add_argument( + '--nxpu', + type=int, + default=1, + help="if nxpu == 0 and ngpu == 0, use cpu.") args = parser.parse_args() print("model_type:{}".format(args.model_type)) print_arguments(args, globals()) diff --git a/paddlespeech/s2t/training/trainer.py b/paddlespeech/s2t/training/trainer.py index 84da251a..d30556ca 100644 --- a/paddlespeech/s2t/training/trainer.py +++ b/paddlespeech/s2t/training/trainer.py @@ -112,7 +112,13 @@ class Trainer(): logger.info(f"Rank: {self.rank}/{self.world_size}") # set device - paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu') + if self.args.ngpu == 0: + if self.args.nxpu == 0: + paddle.set_device('cpu') + else: + paddle.set_device('xpu') + elif self.args.ngpu > 0: + paddle.set_device("gpu") if self.parallel: self.init_parallel() -- GitLab