diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py index 5b80a8af..c57e081e 100644 --- a/comfy/extra_samplers/uni_pc.py +++ b/comfy/extra_samplers/uni_pc.py @@ -661,7 +661,7 @@ class UniPC: if x_t is None: if use_predictor: - pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s) + pred_res = torch.tensordot(D1s, rhos_p, dims=([1], [0])) # torch.einsum('k,bkchw->bchw', rhos_p, D1s) else: pred_res = 0 x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res @@ -669,7 +669,7 @@ class UniPC: if use_corrector: model_t = self.model_fn(x_t, t) if D1s is not None: - corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s) + corr_res = torch.tensordot(D1s, rhos_c[:-1], dims=([1], [0])) # torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s) else: corr_res = 0 D1_t = (model_t - model_prev_0)