Add experimental --async-offload lowvram weight offloading. (#7820)
This should speed up lowvram mode a bit. It is currently only enabled when --async-offload is passed, but it will be enabled by default in the future if no problems come up.
This commit is contained in:
parent b685b8a4e0, commit 0dcc75ca54
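For context, the trick here is to issue the host-to-device weight copy on a secondary CUDA stream so it can overlap with compute still running on the default stream. A minimal standalone sketch of that pattern (illustrative names, not the code from this commit; a pinned source tensor is assumed, since non_blocking copies are only truly asynchronous from pinned memory):

import torch

def prefetch_weight(weight_cpu, device, side_stream):
    # Keep the side stream ordered after work already queued on the
    # default stream before using it for the copy.
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        w = torch.empty_like(weight_cpu, device=device)
        w.copy_(weight_cpu, non_blocking=True)
    return w

def wait_for_weight(side_stream):
    # The default stream must not touch the weight until the copy lands.
    torch.cuda.current_stream().wait_stream(side_stream)

if torch.cuda.is_available():
    side = torch.cuda.Stream()
    weight_cpu = torch.randn(1024, 1024).pin_memory()
    w = prefetch_weight(weight_cpu, "cuda", side)
    # ... other kernels can run on the default stream here ...
    wait_for_weight(side)
    # w is now safe to use in compute

This is the same wait_stream / copy_ / wait_stream dance the diff below packages into get_offload_stream, cast_to and sync_stream.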
comfy/cli_args.py
@@ -128,6 +128,7 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e
 parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
 
+parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")
 
 parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
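The flag keeps the feature opt-in, matching the commit message. Assuming the usual ComfyUI entry point, enabling it would look like: python main.py --lowvram --async-offload (main.py and --lowvram already exist; only --async-offload is new here).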
comfy/model_management.py
@@ -939,15 +939,56 @@ def force_channels_last():
     #TODO
     return False
 
-def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
+
+STREAMS = {}
+NUM_STREAMS = 1
+if args.async_offload:
+    NUM_STREAMS = 2
+    logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
+
+stream_counter = 0
+def get_offload_stream(device):
+    global stream_counter
+    if NUM_STREAMS <= 1:
+        return None
+
+    if device in STREAMS:
+        ss = STREAMS[device]
+        s = ss[stream_counter]
+        stream_counter = (stream_counter + 1) % len(ss)
+        if is_device_cuda(device):
+            ss[stream_counter].wait_stream(torch.cuda.current_stream())
+        return s
+    elif is_device_cuda(device):
+        ss = []
+        for k in range(NUM_STREAMS):
+            ss.append(torch.cuda.Stream(device=device, priority=10))
+        STREAMS[device] = ss
+        s = ss[stream_counter]
+        stream_counter = (stream_counter + 1) % len(ss)
+        return s
+    return None
+
+def sync_stream(device, stream):
+    if stream is None:
+        return
+    if is_device_cuda(device):
+        torch.cuda.current_stream().wait_stream(stream)
+
+def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
     if device is None or weight.device == device:
         if not copy:
             if dtype is None or weight.dtype == dtype:
                 return weight
         return weight.to(dtype=dtype, copy=copy)
 
-    r = torch.empty_like(weight, dtype=dtype, device=device)
-    r.copy_(weight, non_blocking=non_blocking)
+    if stream is not None:
+        with stream:
+            r = torch.empty_like(weight, dtype=dtype, device=device)
+            r.copy_(weight, non_blocking=non_blocking)
+    else:
+        r = torch.empty_like(weight, dtype=dtype, device=device)
+        r.copy_(weight, non_blocking=non_blocking)
     return r
 
 def cast_to_device(tensor, device, dtype, copy=False):
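Taken together, the three new helpers are meant to be called in sequence: fetch a stream, queue the cast/copy on it, then fence before the result is consumed. A sketch of that call pattern (weight_cpu is an illustrative tensor name, not from the diff):

import torch
import comfy.model_management as mm

device = torch.device("cuda")
s = mm.get_offload_stream(device)  # None unless --async-offload is active
w = mm.cast_to(weight_cpu, dtype=torch.float16, device=device,
               non_blocking=True, stream=s)  # copy queued on the side stream
mm.sync_stream(device, s)  # default stream now waits for the copy
# w is safe to read in compute from here on

Note the rotation detail in get_offload_stream: once the streams exist, each call makes the next stream in the rotation wait on the default stream, so a stream handed out later is already ordered after compute that was queued before it.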
comfy/ops.py
@@ -37,20 +37,23 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
         if device is None:
             device = input.device
 
+    offload_stream = comfy.model_management.get_offload_stream(device)
     bias = None
     non_blocking = comfy.model_management.device_supports_non_blocking(device)
     if s.bias is not None:
         has_function = len(s.bias_function) > 0
-        bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
+        bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
         if has_function:
             for f in s.bias_function:
                 bias = f(bias)
 
     has_function = len(s.weight_function) > 0
-    weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
+    weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
     if has_function:
         for f in s.weight_function:
             weight = f(weight)
 
+    comfy.model_management.sync_stream(device, offload_stream)
     return weight, bias
 
 class CastWeightBiasOp:
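With NUM_STREAMS = 2 this amounts to double buffering: while the default stream consumes one layer's weights, the next layer's copies can already be in flight on the other stream. A rough way to check that the overlap pays off (a standalone benchmark sketch, not part of the commit; sizes and the matmul stand-in are arbitrary, and it skips allocator niceties like Tensor.record_stream for brevity):

import torch

def step(x, w_cpu, side=None):
    if side is not None:
        # Overlapped path: queue the H2D copy on the side stream.
        side.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side):
            w = torch.empty_like(w_cpu, device="cuda")
            w.copy_(w_cpu, non_blocking=True)
    else:
        # Blocking path: copy on the default stream.
        w = w_cpu.to("cuda")
    y = x @ x  # compute that can overlap with the copy
    if side is not None:
        torch.cuda.current_stream().wait_stream(side)
    return y @ w

if torch.cuda.is_available():
    x = torch.randn(4096, 4096, device="cuda")
    w_cpu = torch.randn(4096, 4096).pin_memory()
    for name, side in (("blocking", None), ("overlapped", torch.cuda.Stream())):
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(20):
            step(x, w_cpu, side)
        end.record()
        torch.cuda.synchronize()
        print(name, start.elapsed_time(end), "ms")

If the copy really overlaps, the second variant should come out measurably faster on hardware with a dedicated copy engine.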