diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/export.py b/paddlespeech/s2t/exps/deepspeech2/bin/export.py
index ee013d79e6ed3d39516ee65d5c4df5ec30a24b42..ae43bf82cb6b258ac87ee55221e94421d7b203d1 100644
--- a/paddlespeech/s2t/exps/deepspeech2/bin/export.py
+++ b/paddlespeech/s2t/exps/deepspeech2/bin/export.py
@@ -37,6 +37,11 @@ if __name__ == "__main__":
         "--export_path", type=str, help="path of the jit model to save")
     parser.add_argument(
         "--model_type", type=str, default='offline', help="offline/online")
+    parser.add_argument(
+        '--nxpu',
+        type=int,
+        default=1,
+        help="if nxpu == 0 and ngpu == 0, use cpu.")
     args = parser.parse_args()
     print("model_type:{}".format(args.model_type))
     print_arguments(args)
diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/test.py b/paddlespeech/s2t/exps/deepspeech2/bin/test.py
index 388b380d1c78aeb45970486091285f4c1248eb55..f29f508326100d1d66312754fa1eee8cbacee844 100644
--- a/paddlespeech/s2t/exps/deepspeech2/bin/test.py
+++ b/paddlespeech/s2t/exps/deepspeech2/bin/test.py
@@ -37,6 +37,11 @@ if __name__ == "__main__":
     # save asr result to
     parser.add_argument(
         "--result_file", type=str, help="path of save the asr result")
+    parser.add_argument(
+        '--nxpu',
+        type=int,
+        default=1,
+        help="if nxpu == 0 and ngpu == 0, use cpu.")
     args = parser.parse_args()
     print_arguments(args, globals())
     print("model_type:{}".format(args.model_type))
diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/test_export.py b/paddlespeech/s2t/exps/deepspeech2/bin/test_export.py
index 707eb9e1bc26204fe5b6a9070e02f7ad95d5f334..c136ddf29f317c00409125cdcbc7c7a970a7095e 100644
--- a/paddlespeech/s2t/exps/deepspeech2/bin/test_export.py
+++ b/paddlespeech/s2t/exps/deepspeech2/bin/test_export.py
@@ -40,6 +40,11 @@ if __name__ == "__main__":
         "--export_path", type=str, help="path of the jit model to save")
     parser.add_argument(
         "--model_type", type=str, default='offline', help='offline/online')
+    parser.add_argument(
+        '--nxpu',
+        type=int,
+        default=1,
+        help="if nxpu == 0 and ngpu == 0, use cpu.")
     parser.add_argument(
         "--enable-auto-log", action="store_true", help="use auto log")
     args = parser.parse_args()
diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/train.py b/paddlespeech/s2t/exps/deepspeech2/bin/train.py
index e2c68d4be9ae2de9875c4a95d06c6542fd397ce3..cb4867ef25c03e4a1ccb4bb52547be9e10114a82 100644
--- a/paddlespeech/s2t/exps/deepspeech2/bin/train.py
+++ b/paddlespeech/s2t/exps/deepspeech2/bin/train.py
@@ -33,6 +33,11 @@ if __name__ == "__main__":
     parser = default_argument_parser()
     parser.add_argument(
         "--model_type", type=str, default='offline', help='offline/online')
+    parser.add_argument(
+        '--nxpu',
+        type=int,
+        default=1,
+        help="if nxpu == 0 and ngpu == 0, use cpu.")
     args = parser.parse_args()
     print("model_type:{}".format(args.model_type))
     print_arguments(args, globals())
diff --git a/paddlespeech/s2t/training/trainer.py b/paddlespeech/s2t/training/trainer.py
index 84da251aa062b3a82f0c4d1a4f2c012361b86ae6..d30556ca1bdf9ea375743fdc051b646944297b26 100644
--- a/paddlespeech/s2t/training/trainer.py
+++ b/paddlespeech/s2t/training/trainer.py
@@ -112,7 +112,13 @@ class Trainer():
         logger.info(f"Rank: {self.rank}/{self.world_size}")

         # set device
-        paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu')
+        if self.args.ngpu == 0:
+            if self.args.nxpu == 0:
+                paddle.set_device('cpu')
+            else:
+                paddle.set_device('xpu')
+        elif self.args.ngpu > 0:
+            paddle.set_device("gpu")

         if self.parallel:
             self.init_parallel()