NVIDIA Developer Blog · 9 min read

Train Models Faster with JAX and MaxText Using NVFP4 on NVIDIA Blackwell

Mirrored from NVIDIA Developer Blog for archival readability. Support the source by reading on the original site.



AI-Generated Summary

  • NVFP4 enables high-throughput, 4-bit mixed-precision pre-training on NVIDIA Blackwell and Rubin platforms, achieving up to 1.73x speedup over FP8 baselines with negligible accuracy loss by leveraging subbyte precision and native hardware support in the NVIDIA GB300 Grace Blackwell Ultra Superchip.
  • The NVFP4 training recipe for JAX (as implemented in MaxText) preserves convergence in large-scale LLM training through five core techniques: 16-element micro block scaling, E4M3 block scale factors under a per-tensor FP32 scale, selective Random Hadamard Transform for WGRAD inputs, 2D FP8 scaling per 16x16 weight block, and stochastic rounding for unbiased quantization.
  • GEMM operations in transformer MLP layers are quantized to NVFP4, while attention blocks remain at higher precision to avoid softmax quantization noise amplification; empirical results on Llama 3 8B and Llama 3.1 405B show significant throughput gains on NVIDIA GB200 and GB300 hardware without measurable degradation in final model loss.

AI-generated content may summarize information incompletely. Verify important information.

Pre-training frontier LLMs comes down to throughput. When training spans trillions of tokens across thousands of accelerators, every percentage point of step time can add up to days of training and substantial compute costs. Numerical precision is one of the highest-leverage knobs available, but low-bit mixed-precision pretraining is hard to get right.

To address this, the NVFP4 training recipe in NVIDIA Transformer Engine brings subbyte (4-bit) precision to JAX pretraining. For an end-to-end example, see the recipe in MaxText, a high-performance, scalable LLM framework library. The result is high-throughput, 4-bit mixed-precision pretraining on NVIDIA Blackwell with no measurable accuracy loss compared to the FP8 baseline.

This post explains the NVFP4 format and how it is designed to deliver high performance and accuracy at ultra-low precision. It also shows how to apply the MaxText NVFP4 pretraining recipe and measure the resulting performance gains. For methodology details, see the NVFP4 pretraining paper.

NVFP4 format and benefits

The introductory NVFP4 post explains the format and how its two-level microscaling represents values with less error than other microscaling formats. It also explains how native NVFP4 hardware support on the NVIDIA GB300 Grace Blackwell Ultra Superchip delivers 7x the GEMM throughput of native FP8 on NVIDIA Hopper. That higher throughput, combined with the NVFP4 pretraining recipe, shortens training step time with negligible accuracy loss. AI factories can therefore train more and larger models within the same time budget, or train the same models in less time.
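To make the two-level idea concrete, here is a minimal NumPy sketch (NumPy stands in for jax.numpy, whose array API it mirrors) that emulates NVFP4 and MXFP4 quantize-dequantize by rounding onto tables of representable values rather than packing bits. The scale formulas follow our reading of the published NVFP4 and OCP MX descriptions: NVFP4 uses 16-element blocks, an E4M3 block scale that maps each block's maximum to FP4's largest value (6), and a per-tensor FP32 scale; MXFP4 uses 32-element blocks with a power-of-two scale. All function names are hypothetical, and this is not the Transformer Engine implementation.

```python
import numpy as np

# Positive magnitudes representable in FP4 E2M1 (the largest is 6.0).
FP4_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def e4m3_grid():
    """All non-negative finite FP8 E4M3 values (the 'FN' variant: max 448, no infinities)."""
    subnormals = [m / 8 * 2.0**-6 for m in range(1, 8)]
    normals = [(1 + m / 8) * 2.0**(e - 7) for e in range(1, 16) for m in range(8)
               if not (e == 15 and m == 7)]  # exponent 15 with mantissa 7 encodes NaN
    return np.array([0.0] + subnormals + normals)

E4M3_GRID = e4m3_grid()

def round_to_grid(x, grid):
    """Round to the nearest grid magnitude, saturating at the top. Ties go to the even
    index, which for these sorted grids is the even mantissa (round-to-nearest-even)."""
    dist = np.abs(np.abs(x)[..., None] - grid)
    is_nearest = dist == dist.min(axis=-1, keepdims=True)
    idx = np.argmin(np.where(is_nearest, np.arange(len(grid)) % 2, 2), axis=-1)
    return np.sign(x) * grid[idx]

def nvfp4_fake_quant(x, block=16):
    """Quantize-dequantize a 1-D array with NVFP4-style two-level scaling (sketch)."""
    blocks = x.reshape(-1, block)
    amax = np.abs(x).max()
    # Level 1: per-tensor FP32 scale chosen so every block scale fits E4M3's range (<= 448).
    tensor_scale = amax / (6.0 * 448.0) if amax > 0 else 1.0
    # Level 2: per-block scale mapping the block amax to FP4's max (6.0), rounded to E4M3.
    block_amax = np.abs(blocks).max(axis=1, keepdims=True)
    block_scale = round_to_grid(block_amax / 6.0 / tensor_scale, E4M3_GRID) * tensor_scale
    block_scale = np.where(block_scale == 0, 1.0, block_scale)  # guard all-zero blocks
    return (round_to_grid(blocks / block_scale, FP4_GRID) * block_scale).reshape(x.shape)

def mxfp4_fake_quant(x, block=32):
    """Quantize-dequantize with MXFP4-style power-of-two (E8M0) scales on 32-element blocks."""
    blocks = x.reshape(-1, block)
    block_amax = np.maximum(np.abs(blocks).max(axis=1, keepdims=True), 2.0**-126)
    # Scale 2^(floor(log2(amax)) - 2) puts the block max in [4, 8); values above 6 saturate.
    scale = 2.0 ** (np.floor(np.log2(block_amax)) - 2)
    return (round_to_grid(blocks / scale, FP4_GRID) * scale).reshape(x.shape)

rng = np.random.default_rng(0)
x = rng.standard_normal(4096)
for name, fake_quant in [("NVFP4", nvfp4_fake_quant), ("MXFP4", mxfp4_fake_quant)]:
    print(name, "RMSE:", np.sqrt(np.mean((x - fake_quant(x)) ** 2)))
```

On Gaussian data the NVFP4 reconstruction error should come out lower: the power-of-two scale either leaves the top of the FP4 range unused or saturates values between 6 and 8, while the E4M3 scale places each block maximum close to 6.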

NVFP4 pretraining recipe

The NVFP4 recipe combines several ingredients that together preserve convergence while unlocking the NVFP4 throughput of the NVIDIA Blackwell and NVIDIA Rubin platforms. Each technique was chosen for its balance of performance and accuracy in narrow-precision training.

Five key ingredients work together to maintain the accuracy required for 4-bit pretraining:

  • Micro block scaling uses 16-element blocks, half the size of MXFP4’s 32-element blocks, so a single outlier has less influence on the shared scale.
  • E4M3 block scale factors use mantissa bits instead of MXFP4’s power-of-two E8M0 scaling, layered under a per-tensor FP32 scale. In an 8B-parameter, 1T-token experiment, MXFP4 requires ~36% more tokens to match NVFP4’s final loss.
  • Random Hadamard Transform applies only to WGRAD GEMM inputs to Gaussianize outliers. The recipe skips it on FPROP and DGRAD because transforming those paths would also require transforming the weight, breaking 2D-scale consistency.
  • 2D weight scaling uses one FP8 scale per 16×16 weight block, so FPROP and its transposed DGRAD use the same scale. Activations and gradients keep lower-overhead 1×16 scaling.
  • Stochastic rounding (SR) applies unbiased rounding to gradients so tiny updates are not crushed to zero. Weights and activations stay on round-to-nearest-even, where SR would amplify error instead. Both modes are native to Blackwell FP4 conversion instructions.
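The stochastic rounding ingredient can be illustrated in a few lines. This hypothetical NumPy sketch works on values that have already been divided by their block scale and compares deterministic round-to-nearest with stochastic rounding onto the FP4 grid for a value smaller than half the smallest FP4 step. It illustrates the principle only; on Blackwell both modes are hardware conversion instructions, not NumPy code.

```python
import numpy as np

# Positive magnitudes representable in FP4 E2M1.
FP4_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def round_nearest(x):
    """Deterministic round-to-nearest onto the FP4 grid (values beyond 6 saturate).
    Exact ties resolve to the lower value here, a simplification of round-to-nearest-even."""
    idx = np.abs(np.abs(x)[..., None] - FP4_GRID).argmin(axis=-1)
    return np.sign(x) * FP4_GRID[idx]

def round_stochastic(x, rng):
    """Stochastic rounding: pick the upper neighbor with probability equal to the
    fractional distance from the lower one, so E[round(x)] == x inside the grid's range."""
    mag = np.minimum(np.abs(x), FP4_GRID[-1])
    hi_idx = np.clip(np.searchsorted(FP4_GRID, mag, side="left"), 1, len(FP4_GRID) - 1)
    lo, hi = FP4_GRID[hi_idx - 1], FP4_GRID[hi_idx]
    p_up = (mag - lo) / (hi - lo)
    return np.sign(x) * np.where(rng.random(mag.shape) < p_up, hi, lo)

rng = np.random.default_rng(0)
g = np.full(100_000, 0.1)  # a small gradient value, already divided by its block scale
rn_mean = round_nearest(g).mean()          # every copy rounds to 0.0
sr_mean = round_stochastic(g, rng).mean()  # about 0.1 on average
print(rn_mean, sr_mean)
```

Round-to-nearest maps every copy of 0.1 to zero, so the update disappears, while stochastic rounding produces 0.5 about 20% of the time and averages back to roughly 0.1. The price is extra per-value variance, consistent with the recipe keeping weights and activations on round-to-nearest-even.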

Figure 1 shows the NVFP4 data flow inside one linear layer. 

The three GEMMs, FPROP (forward), DGRAD (activation gradient), and WGRAD (weight gradient), are quantized to NVFP4 only in the transformer’s MLP (feed-forward) layers. The GEMMs inside the attention block (QKV projection, attention output projection, and the score/context matmuls) remain in higher precision.

NVFP4 is applied first to the MLP layers because attention’s softmax exponentially amplifies quantization noise in the QK^T scores, and attention activations carry concentrated outliers that 4-bit precision can’t represent well. Because MLPs account for most training FLOPs, this captures most of the speedup without risking convergence.
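A tiny worked example shows why additive error in the attention scores is harmful: softmax exponentiates each score, so an error of δ on one QK^T entry multiplies that entry's unnormalized weight by e^δ. The scores and the 0.25 error below are arbitrary illustrative values, not measurements.

```python
import numpy as np

def softmax(s):
    """Numerically stable softmax over a 1-D vector of attention scores."""
    e = np.exp(s - s.max())
    return e / e.sum()

scores = np.array([4.0, 2.0, 1.0, 0.5])  # illustrative QK^T scores for one query
noise = np.array([0.25, 0.0, 0.0, 0.0])  # an additive quantization error on one score
p_clean, p_noisy = softmax(scores), softmax(scores + noise)

# Additive error on a score becomes multiplicative error on its weight: the perturbed
# entry's weight relative to any other entry grows by exactly exp(0.25).
ratio = (p_noisy[0] / p_noisy[1]) / (p_clean[0] / p_clean[1])
print(ratio, np.exp(0.25))
```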

Figure 1. Illustration of compute flow for an NVFP4 quantized linear layer. All GEMM operations quantize their inputs to NVFP4 (source: https://arxiv.org/abs/2509.25149)
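Written out for one linear layer, the three GEMMs in Figure 1 and the reason for 2D weight scaling look like this. The NumPy sketch below omits quantization itself; its shapes and helper names are hypothetical, and it assumes blocks run along each GEMM's reduction axis. It computes the block scale each weight element would receive. With 1×16 blocks, FPROP and DGRAD group the weight's elements differently, so the two GEMMs would see differently quantized copies of W; with 16×16 tiles, the scales are identical under transposition.

```python
import numpy as np

def scale_map_1x16(m, bs=16):
    """Per-element scale (block amax / 6) under 1 x bs blocks along the last axis."""
    r, c = m.shape
    blocks = np.abs(m.reshape(r, c // bs, bs))
    amax = blocks.max(axis=2, keepdims=True)
    return np.broadcast_to(amax / 6.0, blocks.shape).reshape(r, c)

def scale_map_16x16(m, bs=16):
    """Per-element scale (tile amax / 6) under bs x bs tiles."""
    r, c = m.shape
    tiles = np.abs(m.reshape(r // bs, bs, c // bs, bs))
    amax = tiles.max(axis=(1, 3), keepdims=True)
    return np.broadcast_to(amax / 6.0, tiles.shape).reshape(r, c)

rng = np.random.default_rng(0)
tokens, d_in, d_out = 32, 64, 48
x = rng.standard_normal((tokens, d_in))    # layer input activations
w = rng.standard_normal((d_out, d_in))     # weight, stored [out, in]
dy = rng.standard_normal((tokens, d_out))  # gradient arriving from the next layer

y = x @ w.T    # FPROP: reduces over d_in
dx = dy @ w    # DGRAD: reduces over d_out, so W is read in transposed orientation
dw = dy.T @ x  # WGRAD: reduces over tokens

# FPROP blocks run along d_in (rows of W); DGRAD blocks run along d_out (rows of W^T).
# Map the DGRAD-side scales back to W's layout and compare them with the FPROP-side scales.
same_1x16 = np.array_equal(scale_map_1x16(w), scale_map_1x16(w.T).T)
same_16x16 = np.array_equal(scale_map_16x16(w), scale_map_16x16(w.T).T)
print("1x16 consistent:", same_1x16, "| 16x16 consistent:", same_16x16)
```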

All three MLP GEMMs consume NVFP4 inputs and emit BF16 outputs; the WGRAD output (the weight gradient) is then applied to the FP32 master weights at the optimizer step. The same path shows where the recipe’s convergence-preserving choices apply: 2D block quantization on the weights (consistent FPROP/DGRAD values across the transpose), a Random Hadamard Transform on the WGRAD inputs (flattening outliers before 4-bit quantization), and stochastic rounding in the gradient quantizers (keeping small updates unbiased).
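The Random Hadamard Transform on the WGRAD inputs is safe because WGRAD reduces over the token dimension: applying the same orthogonal matrix to both operands along that axis cancels in exact arithmetic, while spreading any single outlier across its chunk. This NumPy sketch assumes a 16×16 randomized Hadamard matrix applied to consecutive 16-token chunks; the chunk size, helper names, and injected outlier are illustrative, and the NVFP4 quantization that would follow the transform is omitted.

```python
import numpy as np

def random_hadamard(n, rng):
    """Orthogonal random Hadamard matrix H_n @ diag(random signs) / sqrt(n); n a power of two."""
    h = np.array([[1.0]])
    while h.shape[0] < n:
        h = np.kron(h, np.array([[1.0, 1.0], [1.0, -1.0]]))  # Sylvester construction
    signs = rng.choice([-1.0, 1.0], size=n)
    return h * signs / np.sqrt(n)

def rht_along_tokens(a, q):
    """Apply q to every consecutive chunk of q.shape[0] rows (rows = tokens)."""
    t, d = a.shape
    n = q.shape[0]
    return np.einsum("ij,bjd->bid", q, a.reshape(t // n, n, d)).reshape(t, d)

rng = np.random.default_rng(0)
tokens, d_in, d_out = 64, 32, 24
x = rng.standard_normal((tokens, d_in))
dy = rng.standard_normal((tokens, d_out))
dy[5, 3] = 50.0  # inject a single large outlier into the gradient
q = random_hadamard(16, rng)

# WGRAD reduces over tokens, so transforming both operands along tokens cancels (Q^T Q = I).
dw_ref = dy.T @ x
dw_rht = rht_along_tokens(dy, q).T @ rht_along_tokens(x, q)
print("WGRAD unchanged:", np.allclose(dw_ref, dw_rht))

# The outlier is spread across its 16-token chunk, shrinking the column peak roughly 4x.
print("peak before:", np.abs(dy[:, 3]).max(), "after:", np.abs(rht_along_tokens(dy, q)[:, 3]).max())
```

The same trick does not carry over to FPROP or DGRAD, whose reduction axis runs through the weight: transforming those inputs would require transforming W as well, which is the 2D-scale consistency problem noted in the recipe.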

Enabling NVFP4 in MaxText

The MaxText NVFP4 recipe is available in the JAX-Toolbox GitHub repository; its launch script trains Llama 3 8B with NVFP4 on Blackwell. To switch MaxText onto the NVFP4 path, set the quantization flag to one of two modes:

  • quantization=te_nvfp4: NVFP4 with the Random Hadamard Transform. Recommended when convergence under te_nvfp4_no_rht is not satisfactory.
  • quantization=te_nvfp4_no_rht: NVFP4 without RHT. Lowest overhead, but may degrade convergence quality.

Run the example script from the MaxText repository root inside a container that has JAX, NVIDIA Transformer Engine, and the required NVIDIA CUDA/cuDNN libraries installed. The public NVIDIA MaxText container ghcr.io/nvidia/jax:maxtext is recommended.

The following partial example from the Llama 3 8B MaxText NVFP4 training script selects the NVFP4 path through the Transformer Engine quantization mode te_nvfp4_no_rht:

```shell
# ici_DP, dcn_DP, ici_FSDP, and dcn_FSDP must be set before this line
# (the Llama 3 8B benchmark in Table 1 uses FSDP=4).
RUN_SETTINGS="-m maxtext.trainers.pre_train.train maxtext/configs/base.yml \
  run_name=debug_run base_output_directory=./debug_logs \
  hardware=gpu dataset_type=synthetic model_name=llama3-8b \
  remat_policy='minimal_with_context_and_quantization' scan_layers=False \
  attention='cudnn_flash_te' steps=50 dtype=bfloat16 \
  max_target_length=8192 per_device_batch_size=4 \
  ici_data_parallelism=${ici_DP} dcn_data_parallelism=${dcn_DP} \
  ici_fsdp_parallelism=${ici_FSDP} dcn_fsdp_parallelism=${dcn_FSDP} \
  profiler=nsys enable_checkpointing=false override_model_config=True \
  gradient_accumulation_steps=1 quantization=te_nvfp4_no_rht max_segments_per_seq=32"
```

After launch, MaxText prints step time, TFLOP/s/device, and tokens/s/device. Because profiler=nsys is set, an NVIDIA Nsight Systems trace is also written to base_output_directory for inspection. To produce the FP8 baseline used in the comparison below, run the same script with quantization=te_fp8_delayedscaling.

Performance results

The Llama 3 8B benchmark runs 50 steps of MaxText pretraining inside the public ghcr.io/nvidia/jax:maxtext container, with FSDP=4, a sequence length of 8,192, and a per-device batch size of 4.

Table 1 lists the MaxText pretraining configurations on the NVIDIA GB200 Grace Blackwell Superchip and NVIDIA GB300 Grace Blackwell Ultra Superchip for Llama 3 8B and Llama 3.1 405B. Table 2 compares the per-GPU throughput of the NVFP4 recipe against an FP8 baseline on the same hardware, parallelism, and global batch size. All numbers are measured at sequence length 8,192.

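| Model | Hardware | GPUs | FSDP | Micro batch size (MBS) | Global batch size (GBS) | Sequence length |
|---|---|---|---|---|---|---|
| Llama 3 8B | GB200 | 4 | 4 | 4 | 16 | 8,192 |
| Llama 3 8B | GB300 | 4 | 4 | 4 | 16 | 8,192 |
| Llama 3.1 405B | GB200 | 128 | 128 | 1 | 128 | 8,192 |
| Llama 3.1 405B | GB300 | 128 | 128 | 1 | 128 | 8,192 |

Table 1. Pretraining configurations for Llama 3 8B and Llama 3.1 405B, shared by the NVFP4 and FP8 recipes on GB200 and GB300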
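The global batch sizes in Table 1 follow from the other columns. This small Python check assumes pure FSDP (every GPU acts as a data-parallel rank) and gradient_accumulation_steps=1, as in the Llama 3 8B script; the helper name is hypothetical.

```python
# (model, hardware, GPUs, FSDP, micro batch size, global batch size, sequence length)
TABLE_1 = [
    ("Llama 3 8B", "GB200", 4, 4, 4, 16, 8192),
    ("Llama 3 8B", "GB300", 4, 4, 4, 16, 8192),
    ("Llama 3.1 405B", "GB200", 128, 128, 1, 128, 8192),
    ("Llama 3.1 405B", "GB300", 128, 128, 1, 128, 8192),
]

def global_batch(gpus, micro_batch, grad_accum=1):
    """Sequences per optimizer step when every GPU is a data-parallel (FSDP) rank."""
    return gpus * micro_batch * grad_accum

for model, hw, gpus, fsdp, mbs, gbs, seq in TABLE_1:
    assert fsdp == gpus, "pure FSDP: the FSDP axis spans every GPU"
    assert global_batch(gpus, mbs) == gbs
    print(f"{model} on {hw}: {gbs} x {seq:,} = {gbs * seq:,} tokens per step")
```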
| Model | Hardware | FP8 (TFLOP/s per GPU) | NVFP4 (TFLOP/s per GPU) | Speedup vs. FP8 |
|---|---|---|---|---|
| Llama 3 8B | GB200 | 1,497 | 2,017 | 1.35× |
| Llama 3 8B | GB300 | 1,759 | 2,301 | 1.31× |
| Llama 3.1 405B | GB200 | 1,557 | 2,241 | 1.44× |
| Llama 3.1 405B | GB300 | 2,103 | 3,633 | 1.73× |

Table 2. NVFP4 vs. FP8 pretraining throughput for Llama 3 8B and Llama 3.1 405B, measured on GB200 and GB300
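The speedup column in Table 2 is the ratio of the two throughput columns. This Python snippet recomputes it and, assuming the reported TFLOP/s counts the same model FLOPs for both recipes (same model, batch, and sequence length), converts each ratio into a step-time reduction.

```python
# Per-GPU throughput from Table 2: (FP8 TFLOP/s, NVFP4 TFLOP/s)
TABLE_2 = {
    ("Llama 3 8B", "GB200"): (1497, 2017),
    ("Llama 3 8B", "GB300"): (1759, 2301),
    ("Llama 3.1 405B", "GB200"): (1557, 2241),
    ("Llama 3.1 405B", "GB300"): (2103, 3633),
}

for (model, hw), (fp8, nvfp4) in TABLE_2.items():
    speedup = nvfp4 / fp8
    # Same model FLOPs per step, so the throughput ratio is also the step-time ratio.
    step_time_saved = 1 - fp8 / nvfp4
    print(f"{model} {hw}: {speedup:.2f}x, +{nvfp4 - fp8} TFLOP/s, {step_time_saved:.0%} shorter steps")
```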

Figure 2 shows per-GPU sustained TFLOP/s across the four baseline configurations. NVFP4 adds 520–684 TFLOP/s per GPU on three of the four configurations and about 1,530 TFLOP/s on Llama 3.1 405B on GB300. Because the model, hyperparameters, parallelism, and global batch size are held identical, the 1.31–1.73x speedup over the FP8 baseline comes from changing the GEMM precision alone.

The largest relative gains come from the 405B configurations (1.44x on GB200, 1.73x on GB300), where per-step GEMM work dominates FSDP collective overhead, so a precision-level speedup translates directly into wall-clock savings.

Figure 2. Pretraining throughput NVFP4 vs FP8 baseline on GB200 and GB300

Figure 3 overlays Llama 3 8B training loss for the FP8 baseline and NVFP4 across 10,000 pretraining steps with otherwise identical hyperparameters. Both runs descend the same curve from ≈12.2 nats to ≈3.9 nats, with a converged‑regime mean gap of just +0.026 nats, well inside step‑to‑step noise. The NVFP4 speedups in Figure 2 come with no measurable accuracy cost.

Figure 3. Llama 3 8B pretraining: NVFP4 tracks the FP8 baseline loss curve (C4 dataset, ~10k steps)

Get started

To get started, pull the MaxText container and run nvfp4_example.sh on a Blackwell system.

Acknowledgments 

For their contributions to NVFP4 enablement in JAX, XLA, and TE, special thanks to Jaroslav Sevcik, Ilia Sergachev, Johannes Reifferscheid, Phuong Nguyen, and Jeremy Berchtold.


Tags

Agentic AI / Generative AI | Data Center / Cloud | Developer Tools & Techniques | General | Blackwell | Intermediate Technical | Deep dive | NVFP4

About the Authors

About Max Xu
Max Xu is a senior technical lead at NVIDIA specializing in AI training and inference at scale, performance engineering, and end-to-end application deployment. He brings full-stack GPU expertise spanning from chip design, CUDA and kernel-level development to server and cloud for model training and inference, translating innovations into real-world impact. Before NVIDIA, Max worked in engineering roles across major CSP and semiconductor companies.
About Haixin Liu
Haixin Liu is a senior deep learning architect for JAX performance at NVIDIA, where he focuses on AI training and inference efficiency at scale. Prior to NVIDIA, Haixin was at Meta working on AI System Co-Design for large recommendation models, and he was one of the original contributors of PyTorch Quantization. Haixin holds a PhD in Electrical and Computer Engineering from Purdue University.
About Abhinav Goel
Abhinav Goel is a senior deep learning architect and the technical lead for JAX performance at NVIDIA. He has published over twenty peer-reviewed research articles and holds three patents. He earned a Ph.D. from the Elmore Family School of Electrical and Computer Engineering at Purdue University.
About Tejash Shah
Tejash Shah is a principal product manager within the AI Platform Software group at NVIDIA, responsible for managing JAX and MLX frameworks. Before NVIDIA, Tejash held software engineering roles at semiconductor companies. He holds five patents in a wide range of technological domains. He earned a master's degree in Computer Science from The University of Texas at Dallas and a bachelor’s degree in Information Technology from Gujarat University.
