diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 1b971be3..f89a7aab 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -128,6 +128,7 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e
 
 parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
 
+parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")
 
 parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
 
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 43e40224..d118f6b9 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -939,15 +939,56 @@ def force_channels_last():
     #TODO
     return False
 
-def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
+
+STREAMS = {}
+NUM_STREAMS = 1
+if args.async_offload:
+    NUM_STREAMS = 2
+    logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
+
+stream_counter = 0
+def get_offload_stream(device):
+    global stream_counter
+    if NUM_STREAMS <= 1:
+        return None
+
+    if device in STREAMS:
+        ss = STREAMS[device]
+        s = ss[stream_counter]
+        stream_counter = (stream_counter + 1) % len(ss)
+        if is_device_cuda(device):
+            ss[stream_counter].wait_stream(torch.cuda.current_stream())
+        return s
+    elif is_device_cuda(device):
+        ss = []
+        for k in range(NUM_STREAMS):
+            ss.append(torch.cuda.Stream(device=device, priority=10))
+        STREAMS[device] = ss
+        s = ss[stream_counter]
+        stream_counter = (stream_counter + 1) % len(ss)
+        return s
+    return None
+
+def sync_stream(device, stream):
+    if stream is None:
+        return
+    if is_device_cuda(device):
+        torch.cuda.current_stream().wait_stream(stream)
+
+def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
     if device is None or weight.device == device:
         if not copy:
             if dtype is None or weight.dtype == dtype:
                 return weight
         return weight.to(dtype=dtype, copy=copy)
 
-    r = torch.empty_like(weight, dtype=dtype, device=device)
-    r.copy_(weight, non_blocking=non_blocking)
+    if stream is not None:
+        with stream:
+            r = torch.empty_like(weight, dtype=dtype, device=device)
+            r.copy_(weight, non_blocking=non_blocking)
+    else:
+        r = torch.empty_like(weight, dtype=dtype, device=device)
+        r.copy_(weight, non_blocking=non_blocking)
     return r
 
 def cast_to_device(tensor, device, dtype, copy=False):
diff --git a/comfy/ops.py b/comfy/ops.py
index aae6cafa..62daf447 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -37,20 +37,23 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
         if device is None:
             device = input.device
 
+    offload_stream = comfy.model_management.get_offload_stream(device)
     bias = None
     non_blocking = comfy.model_management.device_supports_non_blocking(device)
     if s.bias is not None:
         has_function = len(s.bias_function) > 0
-        bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
+        bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
         if has_function:
             for f in s.bias_function:
                 bias = f(bias)
 
     has_function = len(s.weight_function) > 0
-    weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
+    weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
     if has_function:
         for f in s.weight_function:
             weight = f(weight)
+
+    comfy.model_management.sync_stream(device, offload_stream)
     return weight, bias
 
 class CastWeightBiasOp:
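
For context, the pattern these hunks implement is: get_offload_stream() hands out a dedicated copy stream (round-robin over NUM_STREAMS of them, made to wait on the current compute stream), cast_to(..., stream=...) issues the host-to-device weight copy on that stream so it can overlap with compute, and sync_stream() makes the compute stream wait for the copy before cast_bias_weight() returns the weight for use. The snippet below is a minimal standalone sketch of that same prefetch-on-a-side-stream idea, assuming a CUDA device is available; the prefetch helper, tensor names, and sizes are illustrative only and not part of the patch.

import torch

def prefetch(weight, device, stream):
    # Allocate on the target device and issue the copy on the side stream so
    # the transfer can overlap with work queued on the default stream.
    with torch.cuda.stream(stream):
        r = torch.empty_like(weight, device=device)
        r.copy_(weight, non_blocking=True)
    return r

if torch.cuda.is_available():
    device = torch.device("cuda")
    copy_stream = torch.cuda.Stream(device=device)

    # Pinned host memory is required for the H2D copy to be truly asynchronous.
    cpu_weight = torch.randn(4096, 4096, pin_memory=True)

    # The copy stream must not run ahead of work already queued on the
    # compute stream (mirrors the wait_stream call in get_offload_stream).
    copy_stream.wait_stream(torch.cuda.current_stream())
    gpu_weight = prefetch(cpu_weight, device, copy_stream)

    # Before the compute stream consumes the weight, it waits for the copy
    # (mirrors sync_stream).
    torch.cuda.current_stream().wait_stream(copy_stream)
    out = gpu_weight @ torch.randn(4096, 16, device=device)
    print(out.shape)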