r/LocalLLaMA · 7 min read

Anyone using Flash Attention 2 (ai-bond) on their V100s? How is the performance?

Mirrored from r/LocalLLaMA for archival readability. Support the source by reading on the original site.

I just installed Flash Attention 2 from here: https://github.com/ai-bond/flash-attention-v100

I ran some basic benchmarks and am seeing roughly 4x-7x lower memory usage than the PyTorch baseline. However, benchmarks don't always translate to real-world scenarios.

**I have noticed that the thinking time before answering is noticeably shorter.**
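For reference, the derived figures in the results below are simple arithmetic: Δ is Custom minus PyTorch (also shown as a percentage of the PyTorch figure), a memory ratio like the "4x-7x" above is presumably PyTorch MB divided by Custom MB, and each speedup is PyTorch time divided by Custom time. Here is a minimal Python sketch. The function names are made up for illustration, and the inputs are copied from the B=1, H=32, M=N=2048, D=128, causal=False test.

```python
def memory_summary(custom_mb: float, pytorch_mb: float) -> dict:
    """Memory delta and ratio between the custom kernel and the PyTorch baseline."""
    delta_mb = custom_mb - pytorch_mb  # negative means the custom kernel uses less
    return {
        "delta_mb": delta_mb,
        "delta_pct": 100.0 * delta_mb / pytorch_mb,  # relative to the PyTorch figure
        "ratio": pytorch_mb / custom_mb,             # "N x less memory"
    }


def speedup(custom_ms: float, pytorch_ms: float) -> float:
    """How many times faster the custom kernel is than the PyTorch baseline."""
    return pytorch_ms / custom_ms


if __name__ == "__main__":
    # Figures from the B=1, H=32, M=N=2048, D=128, causal=False test.
    mem = memory_summary(401.3, 3072.8)
    print(f"delta {mem['delta_mb']:.1f} MB ({mem['delta_pct']:.1f}%), ratio {mem['ratio']:.2f}x")
    for phase, custom_ms, pytorch_ms in [("fwd", 3.67, 9.60), ("bwd", 14.67, 28.74), ("tot", 18.34, 38.34)]:
        print(f"{phase}: {speedup(custom_ms, pytorch_ms):.2f}x")
```

Run as a script, it prints a Δ of -2671.5 MB (-86.9%) and a ratio of about 7.66x for that test. Recomputing from the printed, rounded times only approximately reproduces the printed speedups. The forward pass comes out as 2.62x versus the reported 2.61x, and 0.90 / 0.09 in the first test gives 10.0x versus 9.63x. The benchmark script therefore presumably divides unrounded timings.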

Here are some of my results. All 19 tests reported ✅ Forward match OK and ✅ Backward match OK.

Performance (Δ = Custom - PyTorch; speedup in parentheses):

| B | H | M | N | D | causal | Mem: Custom / PyTorch (MB) | Δ Mem | fwd: Custom / PyTorch (ms) | bwd: Custom / PyTorch (ms) | tot: Custom / PyTorch (ms) |
|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 1 | 128 | 128 | 128 | True | 17.1 / 17.6 | -0.5 MB (-3.1%) | 0.09 / 0.90 (9.63x) | 0.10 / 2.48 (24.31x) | 0.20 / 3.38 (17.28x) |
| 1 | 1 | 256 | 256 | 256 | False | 19.3 / 21.4 | -2.1 MB (-9.9%) | 0.10 / 0.67 (7.06x) | 0.12 / 2.18 (18.49x) | 0.21 / 2.85 (13.38x) |
| 1 | 1 | 256 | 256 | 256 | True | 19.6 / 21.8 | -2.2 MB (-10.0%) | 0.09 / 0.90 (9.57x) | 0.12 / 2.29 (19.64x) | 0.21 / 3.19 (15.14x) |
| 1 | 16 | 1024 | 1024 | 16 | False | 28.5 / 351.9 | -323.4 MB (-91.9%) | 0.28 / 0.94 (3.36x) | 0.70 / 2.46 (3.53x) | 0.98 / 3.40 (3.48x) |
| 1 | 16 | 1024 | 1024 | 16 | True | 30.0 / 354.4 | -324.4 MB (-91.5%) | 0.20 / 1.30 (6.38x) | 0.41 / 3.06 (7.42x) | 0.62 / 4.36 (7.07x) |
| 1 | 32 | 1024 | 1024 | 16 | False | 41.8 / 688.5 | -646.8 MB (-93.9%) | 0.45 / 1.35 (3.03x) | 1.15 / 3.77 (3.29x) | 1.59 / 5.12 (3.21x) |
| 1 | 32 | 1024 | 1024 | 16 | True | 43.8 / 691.5 | -647.8 MB (-93.7%) | 0.35 / 2.01 (5.72x) | 0.76 / 5.09 (6.72x) | 1.11 / 7.10 (6.40x) |
| 1 | 16 | 1024 | 1024 | 32 | False | 43.5 / 370.4 | -326.9 MB (-88.3%) | 0.25 / 0.93 (3.74x) | 0.69 / 2.37 (3.43x) | 0.94 / 3.30 (3.51x) |
| 1 | 16 | 1024 | 1024 | 32 | True | 43.5 / 371.4 | -327.9 MB (-88.3%) | 0.18 / 1.26 (7.09x) | 0.45 / 3.00 (6.61x) | 0.63 / 4.26 (6.75x) |
| 1 | 32 | 1024 | 1024 | 32 | False | 66.8 / 720.5 | -653.8 MB (-90.7%) | 0.46 / 1.44 (3.16x) | 1.38 / 3.93 (2.85x) | 1.84 / 5.37 (2.93x) |
| 1 | 32 | 1024 | 1024 | 32 | True | 70.8 / 725.5 | -654.8 MB (-90.2%) | 0.30 / 2.07 (6.89x) | 0.82 / 5.27 (6.46x) | 1.12 / 7.34 (6.58x) |
| 1 | 16 | 1024 | 1024 | 64 | False | 70.5 / 404.4 | -333.9 MB (-82.6%) | 0.34 / 1.02 (2.97x) | 1.00 / 2.63 (2.63x) | 1.34 / 3.65 (2.72x) |
| 1 | 16 | 1024 | 1024 | 64 | True | 70.5 / 405.4 | -334.9 MB (-82.6%) | 0.24 / 1.38 (5.73x) | 0.69 / 3.27 (4.76x) | 0.93 / 4.65 (5.01x) |
| 1 | 32 | 1024 | 1024 | 64 | False | 116.8 / 784.5 | -667.8 MB (-85.1%) | 0.57 / 1.74 (3.04x) | 1.94 / 4.81 (2.48x) | 2.51 / 6.54 (2.61x) |
| 1 | 32 | 1024 | 1024 | 64 | True | 124.8 / 793.5 | -668.8 MB (-84.3%) | 0.35 / 2.37 (6.74x) | 1.15 / 6.15 (5.37x) | 1.50 / 8.52 (5.69x) |
| 1 | 16 | 1024 | 1024 | 128 | False | 124.5 / 472.4 | -347.9 MB (-73.6%) | 0.56 / 1.34 (2.37x) | 1.88 / 3.69 (1.96x) | 2.44 / 5.03 (2.06x) |
| 1 | 16 | 1024 | 1024 | 128 | True | 124.5 / 473.4 | -348.9 MB (-73.7%) | 0.38 / 1.67 (4.38x) | 1.19 / 4.36 (3.66x) | 1.57 / 6.03 (3.84x) |
| 1 | 32 | 2048 | 2048 | 128 | False | 401.3 / 3072.8 | -2671.5 MB (-86.9%) | 3.67 / 9.60 (2.61x) | 14.67 / 28.74 (1.96x) | 18.34 / 38.34 (2.09x) |
| 1 | 32 | 2048 | 2048 | 128 | True | 449.3 / 3124.8 | -2675.5 MB (-85.6%) | 2.05 / 12.46 (6.07x) | 7.82 / 33.30 (4.26x) | 9.87 / 45.76 (4.64x) |

Validation (as printed by the test script):

| B | H | M | N | D | causal | Fwd: dO err | Bwd: dQ err | Bwd: dK err | Bwd: dV err |
|---|---|---|---|---|---|---|---|---|---|
| 1 | 1 | 128 | 128 | 128 | True | 9.77e-04 ≤ 2×9.77e-04 | 9.77e-04 ≤ 3×1.95e-03 | 9.77e-04 ≤ 3×1.95e-03 | 9.77e-04 ≤ 3×1.95e-03 |
| 1 | 1 | 256 | 256 | 256 | False | 2.44e-04 ≤ 2×7.32e-04 | 2.44e-04 ≤ 3×4.88e-04 | 4.88e-04 ≤ 3×4.88e-04 | 4.88e-04 ≤ 3×9.77e-04 |
| 1 | 1 | 256 | 256 | 256 | True | 9.77e-04 ≤ 2×1.95e-03 | 9.77e-04 ≤ 3×9.77e-04 | 9.77e-04 ≤ 3×1.95e-03 | 1.95e-03 ≤ 3×1.95e-03 |
| 1 | 16 | 1024 | 1024 | 16 | False | 2.44e-04 ≤ 2×4.88e-04 | 4.88e-04 ≤ 3×9.77e-04 | 4.88e-04 ≤ 3×9.77e-04 | 4.88e-04 ≤ 3×7.32e-04 |
| 1 | 16 | 1024 | 1024 | 16 | True | 9.77e-04 ≤ 2×1.95e-03 | 1.95e-03 ≤ 3×3.91e-03 | 1.95e-03 ≤ 3×1.95e-03 | 1.95e-03 ≤ 3×1.95e-03 |
| 1 | 32 | 1024 | 1024 | 16 | False | 2.44e-04 ≤ 2×4.88e-04 | 4.88e-04 ≤ 3×9.77e-04 | 4.88e-04 ≤ 3×9.77e-04 | 4.88e-04 ≤ 3×7.32e-04 |
| 1 | 32 | 1024 | 1024 | 16 | True | 9.77e-04 ≤ 2×1.95e-03 | 1.95e-03 ≤ 3×3.91e-03 | 1.95e-03 ≤ 3×1.95e-03 | 1.95e-03 ≤ 3×1.95e-03 |
| 1 | 16 | 1024 | 1024 | 32 | False | 2.44e-04 ≤ 2×7.32e-04 | 2.44e-04 ≤ 3×1.22e-03 | 2.44e-04 ≤ 3×1.22e-03 | 2.44e-04 ≤ 3×9.77e-04 |
| 1 | 16 | 1024 | 1024 | 32 | True | 9.77e-04 ≤ 2×1.95e-03 | 9.77e-04 ≤ 3×1.95e-03 | 1.95e-03 ≤ 3×1.95e-03 | 1.95e-03 ≤ 3×3.91e-03 |
| 1 | 32 | 1024 | 1024 | 32 | False | 2.44e-04 ≤ 2×1.22e-03 | 4.88e-04 ≤ 3×1.46e-03 | 4.88e-04 ≤ 3×1.46e-03 | 4.88e-04 ≤ 3×1.10e-03 |
| 1 | 32 | 1024 | 1024 | 32 | True | 9.77e-04 ≤ 2×1.95e-03 | 1.46e-03 ≤ 3×2.93e-03 | 1.95e-03 ≤ 3×2.93e-03 | 1.95e-03 ≤ 3×1.95e-03 |
| 1 | 16 | 1024 | 1024 | 64 | False | 1.22e-04 ≤ 2×4.88e-04 | 4.88e-04 ≤ 3×7.32e-04 | 4.88e-04 ≤ 3×7.32e-04 | 2.44e-04 ≤ 3×4.88e-04 |
| 1 | 16 | 1024 | 1024 | 64 | True | 9.77e-04 ≤ 2×1.95e-03 | 1.95e-03 ≤ 3×1.95e-03 | 1.95e-03 ≤ 3×1.95e-03 | 1.95e-03 ≤ 3×1.95e-03 |
| 1 | 32 | 1024 | 1024 | 64 | False | 2.44e-04 ≤ 2×4.88e-04 | 4.88e-04 ≤ 3×1.46e-03 | 4.88e-04 ≤ 3×1.46e-03 | 4.88e-04 ≤ 3×9.77e-04 |
| 1 | 32 | 1024 | 1024 | 64 | True | 9.77e-04 ≤ 2×1.95e-03 | 1.95e-03 ≤ 3×1.95e-03 | 1.95e-03 ≤ 3×1.95e-03 | 1.95e-03 ≤ 3×3.91e-03 |
| 1 | 16 | 1024 | 1024 | 128 | False | 1.22e-04 ≤ 2×7.32e-04 | 2.44e-04 ≤ 3×9.77e-04 | 2.44e-04 ≤ 3×1.46e-03 | 2.44e-04 ≤ 3×7.32e-04 |
| 1 | 16 | 1024 | 1024 | 128 | True | 1.95e-03 ≤ 2×1.95e-03 | 1.95e-03 ≤ 3×1.95e-03 | 1.95e-03 ≤ 3×1.95e-03 | 1.95e-03 ≤ 3×3.91e-03 |
| 1 | 32 | 2048 | 2048 | 128 | False | 2.44e-04 ≤ 2×6.10e-04 | 2.44e-04 ≤ 3×9.77e-04 | 2.44e-04 ≤ 3×9.77e-04 | 2.44e-04 ≤ 3×7.93e-04 |
| 1 | 32 | 2048 | 2048 | 128 | True | 9.77e-04 ≤ 2×2.20e-03 | 3.91e-03 ≤ 3×3.91e-03 | 1.95e-03 ≤ 3×1.95e-03 | 1.95e-03 ≤ 3×3.91e-03 |
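Each Validation entry has the form `err ≤ k×ref`, with k = 2 for the forward output and k = 3 for the gradients. The post does not say what the right-hand reference error is. One plausible reading, and it is only an assumption about this repo's test script, is a check similar to one in the upstream flash-attention test suite. Under that reading, the custom kernel's max absolute error against a higher-precision reference must stay within k times the error of a plain PyTorch implementation run at the same low precision. A minimal pure-Python sketch of that check, with short lists standing in for tensors and hypothetical function names:

```python
def max_abs_err(a: list[float], b: list[float]) -> float:
    """Largest element-wise absolute difference between two equal-length sequences."""
    if len(a) != len(b):
        raise ValueError("length mismatch")
    return max(abs(x - y) for x, y in zip(a, b))


def within_tolerance(
    custom: list[float],
    baseline: list[float],
    reference: list[float],
    factor: float,
) -> tuple[bool, float, float]:
    """Check err(custom vs reference) <= factor * err(baseline vs reference).

    Assumption: `reference` is a high-precision (e.g. fp32) result and `baseline`
    is a plain PyTorch implementation at the same low precision as the custom kernel.
    """
    err = max_abs_err(custom, reference)
    ref_err = max_abs_err(baseline, reference)
    return err <= factor * ref_err, err, ref_err


if __name__ == "__main__":
    reference = [0.5000, 1.2500, -0.7500]  # hypothetical high-precision output
    baseline = [0.5005, 1.2490, -0.7500]   # hypothetical low-precision PyTorch output
    custom = [0.5010, 1.2500, -0.7495]     # hypothetical custom-kernel output
    ok, err, ref_err = within_tolerance(custom, baseline, reference, factor=2.0)
    print(f"err={err:.2e} <= 2x{ref_err:.2e}: {ok}")
```

The point of a relative bound like this is that low-precision attention is never bit-exact. The kernel is judged against how far a straightforward implementation at the same precision already drifts, rather than against a fixed absolute threshold.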

Have you seen any major improvements?

submitted by /u/UltraFOV