r/MachineLearning

I built a Mamba1 variant I call SM1 with d_state=1 that runs on Blackwell in pure PyTorch [P]

Mirrored from r/MachineLearning.

On Windows, mamba-ssm is not easily available, and it doesn't compile for sm_120 (consumer Blackwell). SM1 (Scalar Mamba1) replaces the entire selective scan with two native PyTorch ops, cumprod and cumsum:

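```python
import torch

# dA, dBx: per-step decay and input, time on dim 1; h0: state before the first step
L = torch.cumprod(dA, dim=1)  # running product of the decays
h = L * (h0.unsqueeze(1) + torch.cumsum(dBx / L.clamp(min=1e-6), dim=1))  # state at every step
y = h * C  # readout
```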
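For one channel, write a_t for the entry of dA at step t and u_t for the entry of dBx (notation introduced here), and assume the usual Mamba1 discretized recurrence with a scalar state. Dividing by the running product L_t makes the recurrence telescope, which is what the cumprod/cumsum pair evaluates in parallel:

```latex
\begin{aligned}
h_t &= a_t\,h_{t-1} + u_t, \qquad L_t = \prod_{k=1}^{t} a_k, \quad L_0 = 1 \\
\frac{h_t}{L_t} &= \frac{h_{t-1}}{L_{t-1}} + \frac{u_t}{L_t} \qquad (\text{since } a_t / L_t = 1 / L_{t-1}) \\
\frac{h_t}{L_t} &= h_0 + \sum_{s=1}^{t} \frac{u_s}{L_s} \\
h_t &= L_t \Big( h_0 + \sum_{s=1}^{t} \frac{u_s}{L_s} \Big), \qquad y_t = C_t\,h_t
\end{aligned}
```

The cumprod gives L_t, the cumsum gives the bracketed sum, and the division requires L_t to be nonzero, which is presumably why the code clamps L.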

This is the exact closed-form solution of the d_state=1 recurrence h_t = dA_t * h_{t-1} + dBx_t, derived via variation of parameters. It is not an approximation: it matches sequential computation up to floating-point precision, as long as the running product L stays above the 1e-6 clamp. d_state=2 breaks it; d_state=1 is the boundary where the closed form exists.
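A plain-Python sketch of the equivalence claim for a single channel, written without PyTorch so it runs anywhere: itertools.accumulate stands in for cumprod and cumsum, and all inputs are made-up numbers. It compares the closed form against a step-by-step loop, and also shows what happens in this sketch once the running product drops below the 1e-6 clamp.

```python
from itertools import accumulate
import operator


def sequential_scan(dA, dBx, h0):
    """Reference loop: h_t = dA_t * h_{t-1} + dBx_t for one channel."""
    h, out = h0, []
    for a, u in zip(dA, dBx):
        h = a * h + u
        out.append(h)
    return out


def closed_form_scan(dA, dBx, h0, eps=1e-6):
    """Same recurrence via a cumulative product and a cumulative sum."""
    L = list(accumulate(dA, operator.mul))                # like torch.cumprod(dA, dim=1)
    scaled = [u / max(lt, eps) for u, lt in zip(dBx, L)]  # like dBx / L.clamp(min=eps)
    S = list(accumulate(scaled))                          # like torch.cumsum(..., dim=1)
    return [lt * (h0 + s) for lt, s in zip(L, S)]


# Short sequence: L stays far above the clamp, so both paths agree.
dA = [0.9, 0.8, 0.95, 0.7, 0.85]
dBx = [0.5, -0.2, 0.1, 0.3, -0.4]
short_err = max(
    abs(x - y)
    for x, y in zip(sequential_scan(dA, dBx, 0.25), closed_form_scan(dA, dBx, 0.25))
)
print(short_err)  # float64 rounding noise only

# Strong decay: L = 0.5**t drops below 1e-6 at t = 20.
dA_long, dBx_long = [0.5] * 40, [1.0] * 40
seq_last = sequential_scan(dA_long, dBx_long, 0.0)[-1]                  # about 2.0
clamped_last = closed_form_scan(dA_long, dBx_long, 0.0)[-1]             # about 2e-5
unclamped_last = closed_form_scan(dA_long, dBx_long, 0.0, eps=0.0)[-1]  # about 2.0
print(seq_last, clamped_last, unclamped_last)
```

On the short sequence the two agree to rounding error. On the 40-step sequence the clamped version collapses toward zero while the unclamped one still matches, so exactness holds only while L stays above the clamp, and with strong decay that limit can be reached within a few dozen steps.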

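The Mamba1 scan intermediates have shape (B, T, F, S), i.e. batch, time, feature channels, and state. SM1 eliminates the S axis entirely, so scan memory is 16x smaller than in a Mamba1 with d_state=16. The inference state for a 130M-param model is about 14,080 floats (about 56 KB in fp32), with no KV cache and O(1) cost per token no matter how long the sequence runs.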
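To make the memory and O(1)-per-token claims concrete, a plain-Python sketch: scan_bytes and sm1_step are hypothetical helpers, the shapes are made up, and C_t is assumed to be one scalar per step broadcast over channels (consistent with y = h * C at d_state=1). Decoding a token only needs the previous scalar state of each channel.

```python
def scan_bytes(batch, time, channels, d_state, bytes_per_el=4):
    """Size of one (B, T, F, S) fp32 scan intermediate, in bytes."""
    return batch * time * channels * d_state * bytes_per_el


def sm1_step(h, dA_t, dBx_t, C_t):
    """Advance the per-channel scalar states by one token.

    The state is one float per channel, so memory and work per token stay
    constant no matter how many tokens came before.
    """
    h = [a * h_f + b for h_f, a, b in zip(h, dA_t, dBx_t)]
    y = [h_f * C_t for h_f in h]  # C_t: one scalar per step, broadcast over channels
    return h, y


# Hypothetical shapes, chosen only to illustrate the ratio.
B, T, F = 8, 2048, 1536
mamba1_bytes = scan_bytes(B, T, F, d_state=16)
sm1_bytes = scan_bytes(B, T, F, d_state=1)
print(mamba1_bytes // sm1_bytes)  # 16

# The post's inference-state figure, assuming 4-byte fp32 floats.
state_bytes = 14_080 * 4
print(state_bytes)  # 56320, i.e. about 56 KB

# Decode 1000 tokens for a toy 4-channel layer: the state never grows.
h = [0.0] * 4
for t in range(1000):
    h, y = sm1_step(h, [0.9] * 4, [0.1] * 4, C_t=1.0)
print(len(h))  # 4
```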

I am currently training it on 163K MIDI files, roughly 2.5B tokens in my custom format. The 130M-param model fits in under half of the 16 GB on my RTX 5060 Ti.
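Back-of-envelope arithmetic on the figures above. The per-file average follows directly from the post; the training precision and optimizer are hypothetical assumptions (fp32 weights and gradients plus AdamW's two fp32 moment buffers), not something the post states.

```python
# Figures from the post.
files = 163_000
tokens = 2_500_000_000
params = 130_000_000

# Average token count per MIDI file.
tokens_per_file = tokens / files
print(round(tokens_per_file))  # 15337

# Hypothetical setup: fp32 weights + fp32 grads + two fp32 AdamW moments.
bytes_per_param = 4 + 4 + 4 + 4
static_gb = params * bytes_per_param / 1e9
print(static_gb)  # 2.08 (GB, excluding activations and framework overhead)
```

Under those assumptions the static training state is about 2 GB, well inside the 8 GB half-card budget.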

submitted by /u/TechnoVoyager