diff --git a/main.py b/main.py index a162e1ed..54c66dac 100644 --- a/main.py +++ b/main.py @@ -209,6 +209,7 @@ class PromptExecutor: executed = set(executed) for x in executed: self.old_prompt[x] = copy.deepcopy(prompt[x]) + torch.cuda.empty_cache() def validate_inputs(prompt, item): unique_id = item