diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py index 9bf1ad57e..510e3485b 100644 --- a/funasr/auto/auto_model.py +++ b/funasr/auto/auto_model.py @@ -29,7 +29,13 @@ from funasr.train_utils.load_pretrained_model import load_pretrained_model from funasr.utils import export_utils from funasr.utils import misc - +def is_npu_available(): + """检查NPU是否可用。""" + try: + import torch_npu + return torch_npu.npu.is_available() + except ImportError: + return False def _resolve_ncpu(config, fallback=4): """Return a positive integer representing CPU threads from config.""" @@ -199,6 +205,7 @@ def build_model(**kwargs): if ((device =="cuda" and not torch.cuda.is_available()) or (device == "xpu" and not torch.xpu.is_available()) or (device == "mps" and not torch.backends.mps.is_available()) + or (device == "npu" and not is_npu_available()) or kwargs.get("ngpu", 1) == 0): device = "cpu" kwargs["batch_size"] = 1