// LLM knowledge base for Apple's MLX framework + mlx-lm + on-device Gemma 3 4B
Every entry below has an #id, a type badge, a one-paragraph description, and chips linking to neighbors. Load this page as context when an LLM needs to reason about MLX.
┌──────────────────────── mlx ────────────────────────┐
│ core (arrays · ops · dtypes)                        │
│ nn (Module · Linear · Conv · Transformer · ...)     │
│ optimizers (Adam · AdamW · SGD · ...)               │
│ utils · random · fft · linalg · fast                │
└──────────────────────────┬──────────────────────────┘
                           │
           ┌───────────────┴───────────────┐
           ▼                               ▼
        mlx-lm                          mlx-vlm
        ──────                          ───────
 load() → (model, tok)          load() → (model, processor)
 generate() / stream_generate   generate() with images
 convert · lora · fuse · server Gemma 3 multimodal · Llava · ...
target: M-series (M1+) · macOS 13.5+ · unified memory · Metal GPU
The fundamental tensor type. Lazy by default; lives in unified memory accessible from CPU and GPU. Constructed from Python lists, NumPy arrays, or directly via mx.zeros / mx.ones / mx.random.*.
Force evaluation of one or more lazy arrays. Critical to call after every training step; otherwise the computation graph, and the memory backing it, grows without bound.
mx.float32 · mx.float16 · mx.bfloat16 · mx.int8 · mx.int16 · mx.int32 · mx.int64 · mx.bool_. bfloat16 is the inference default for LLMs.
normal · uniform · randint · bernoulli · truncated_normal · categorical. Stateful PRNG, seeded via mx.random.seed(n).
norm · qr · svd · cholesky · inv · solve. Standard linear algebra ops.
FFT routines (fft, ifft, rfft, irfft, 2D variants). Used in audio + signal pipelines.
Fused kernels: mx.fast.scaled_dot_product_attention, mx.fast.layer_norm, mx.fast.rms_norm, mx.fast.rope. The hand-tuned hot paths for transformer inference.
Pin ops to a specific device (mx.cpu, mx.gpu) or stream. Mostly automatic; explicit only for benchmarking.
Base class for all layers. Subclass with __init__ declaring submodules + __call__ defining forward. Auto-registers parameters.
Standard building blocks. Constructor signatures mirror PyTorch: nn.Linear(in_features, out_features), etc.
Normalization layers. RMSNorm is what every modern LLM uses; mx.fast.rms_norm is the kernelized fast path.
Multi-head self-attention. For LLM serving, prefer mx.fast.scaled_dot_product_attention wired into a custom block.
Activation functions. nn.silu (a.k.a. swish) is the LLM default.
Returns a function that computes (loss, grads-wrt-params). Pair with opt.update(model, grads).
The default. AdamW(learning_rate=1e-4, weight_decay=0.01). Decoupled weight decay; what every LLM trainer uses.
Drop-in alternatives. Lion is a single-momentum optimizer that's competitive on transformers with lower memory.
Learning-rate schedules. Pass into the optimizer's learning_rate=.
Wraps a Python function; returns a fn that computes the gradient w.r.t. the first argument. JAX-style.
Like grad but also returns the original output — saves a recomputation.
Vectorizing map. Lifts a single-example function to a batched one without writing a batch dim.
JIT-compiles a function. Enables kernel fusion + reduces dispatch overhead. Worth doing on the hot path of training/inference.
Returns (model, tokenizer) from an HF or local path. Auto-handles MLX-format weights, sharding, dtype.
One-shot text generation. Args: prompt, max_tokens, temp, top_p, top_k, repetition_penalty, prompt_cache. Note that in recent mlx-lm releases the sampling knobs (temp, top_p, top_k) move into a sampler built with mlx_lm.sample_utils.make_sampler.
Yields token chunks as they're generated. Use for live UIs and to back-pressure consumers.
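End-to-end sketch of load/generate/stream_generate. The 4-bit repo id is an assumption (any mlx-community conversion works), downloading it needs network plus HF auth for gated models, and recent mlx-lm versions yield response objects with a .text field from stream_generate:

```python
from mlx_lm import load, generate, stream_generate

# Repo id is a placeholder; gated models need `huggingface-cli login` first.
model, tokenizer = load("mlx-community/gemma-3-4b-it-4bit")

messages = [{"role": "user", "content": "Explain unified memory in one sentence."}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

text = generate(model, tokenizer, prompt=prompt, max_tokens=128)

# Streaming: yields incrementally, suitable for live UIs
for r in stream_generate(model, tokenizer, prompt=prompt, max_tokens=128):
    print(r.text, end="", flush=True)
```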
mlx_lm.convert --hf-path <repo> --mlx-path ./out --quantize --q-bits 4. HF → MLX with optional quantization.
LoRA / DoRA fine-tuning. --train --data ./data --iters 1000 --lora-layers 16. Trainable on M-series with 16 GB+.
Merge an adapter into the base weights. Output is a single self-contained model — no adapter path needed at inference.
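The two commands chain as follows; the model repo and local paths are placeholders:

```shell
# LoRA fine-tune (flags as documented above)
mlx_lm.lora --model mlx-community/gemma-3-4b-it-4bit \
  --train --data ./data --iters 1000 --lora-layers 16

# Merge the resulting adapter into standalone weights
mlx_lm.fuse --model mlx-community/gemma-3-4b-it-4bit \
  --adapter-path ./adapters --save-path ./fused-model
```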
OpenAI-compatible HTTP server. --model <path> --host 0.0.0.0 --port 8080. Drop-in for any OpenAI client.
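A launch-plus-request sketch; the model path is a placeholder:

```shell
# Launch the server (flags as documented above)
mlx_lm.server --model mlx-community/gemma-3-4b-it-4bit --host 0.0.0.0 --port 8080

# Then any OpenAI-style client works:
curl http://localhost:8080/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"messages": [{"role": "user", "content": "Hello"}], "max_tokens": 64}'
```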
Reuse KV across calls when prompts share a prefix (system prompts, RAG context). make_prompt_cache(model); pass to generate().
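A multi-turn sketch; the repo id is a placeholder. The cache object accumulates KV state across calls, so shared prefixes (system prompt, earlier turns) are never re-prefetched:

```python
from mlx_lm import load, generate
from mlx_lm.models.cache import make_prompt_cache

model, tokenizer = load("mlx-community/gemma-3-4b-it-4bit")  # assumed repo id

cache = make_prompt_cache(model)  # one cache, reused across turns

for turn in ["What is MLX?", "And what is unified memory?"]:
    print(generate(model, tokenizer, prompt=turn,
                   max_tokens=64, prompt_cache=cache))
```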
Returns (model, processor). Processor handles image preprocessing alongside tokenization.
Generate text conditioned on one or more images. Uses model's chat template via apply_chat_template.
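A hedged sketch of the mlx-vlm flow. The repo id and image path are placeholders, and the exact signatures (image vs images kwarg, apply_chat_template arguments) vary across mlx-vlm releases:

```python
from mlx_vlm import load, generate
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config

model_path = "mlx-community/gemma-3-4b-it-4bit"   # assumed repo id
model, processor = load(model_path)
config = load_config(model_path)

# Wraps the question in the model's chat template, with one image slot
prompt = apply_chat_template(processor, config,
                             "Describe this image.", num_images=1)
out = generate(model, processor, prompt, image=["frame.jpg"], max_tokens=100)
```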
Google's open-weight models. Sizes: 270M, 1B, 4B, 12B, 27B. Decoder-only transformer; multimodal at 4B+. Open weights under Google's Gemma license (gated on Hugging Face, but freely downloadable once accepted).
The local-on-Mac sweet spot. ~8 GB at bf16, ~2.5 GB at 4-bit. Fast on M2/M3/M4. -it variant is instruction-tuned and multimodal.
gemma-3-4b-it includes vision. Pair with mlx-vlm for image-conditioned generation. Used for moment scoring + frame-level caption polish.
Pre-converted MLX weights for popular open models (Gemma, Llama, Qwen, Mistral, Phi, DeepSeek). bf16 + 4bit + 8bit variants.
Llama 3 / 3.1 / 3.2 · Qwen 2.5 / 3 · Mistral · Phi-3 / 3.5 / 4 · DeepSeek · OLMo · Granite · Yi · Mixtral · Stable LM. Same load/generate API.
Interactive REPL with chat-template handling. --model <path>. Good for sanity-checking a freshly converted model.
Pre-compute a KV cache for a long prompt to reuse across generate calls. Useful for RAG with stable context.
Inspect / clean the local HF cache used by mlx-lm — list models, evict by size or age.
One address space; CPU and GPU read the same buffers. No copy across devices. The architectural reason MLX exists vs adapting PyTorch.
Operations build a graph; computation defers until mx.eval() or a Python-side read. Lets the compiler fuse ops; can also leak memory if you forget to materialize.
Hand-tuned Metal kernels for the LLM hot path: SDPA, RMSNorm, RoPE, layer-norm. Found under mx.fast.*; what makes mlx-lm fast.
4-bit / 8-bit weight compression with group-size scaling. --q-bits 4 --q-group-size 64 is the standard. ~3-4× memory reduction, ~5% quality cost.
Low-rank adapters. Train a small additive matrix per layer, leave base weights frozen. Fits a 4B fine-tune in <16 GB unified memory.
bf16 weights × fp32 input casts everything to fp32. Cast inputs at the model boundary: x.astype(mx.bfloat16).
Stores attention K + V for past tokens. Dominates memory at long contexts. Cap with max_kv_size or compress with kv_bits=8.
mlx_lm.server exposes /v1/chat/completions + /v1/completions. Mostly drop-in for OpenAI SDKs; tool-calling + image-content edge cases differ.
Always mx.eval(model.parameters(), opt.state) after every optimizer step. Otherwise memory and graph debt grow without bound.
Manages the HF cache MLX reads from. huggingface-cli login required for gated models (Gemma).
High-throughput data loaders that don't bottleneck on Python. Used in training pipelines.
Graph neural networks built on MLX.
Swift bindings — for shipping MLX models inside iOS / macOS apps. Same array semantics.
mlx-lm source. Examples, training scripts, server, model registry.
Vision-language models on MLX. Gemma 3 multimodal, Llava, Qwen-VL.
Reference implementations — transformer LM training, MNIST, stable diffusion, whisper. Read these for idiomatic MLX.
Narrative how-to companion to this wiki — install, concepts, arrays, nn, training, mlx-lm, Gemma 3 4B walkthrough, quantization, LoRA, server, integrate.