MLX Wiki

// LLM knowledge base for Apple's MLX framework + mlx-lm + on-device Gemma 3 4B

How to use this wiki (for humans and LLMs)

Every module, function, concept, and CLI in MLX has one node. Each node has a stable #id, a type badge, a one-paragraph description, and chips to neighbors. Load this page as context when an LLM needs to reason about MLX.
   ┌──────────────────────── mlx ────────────────────────┐
   │  core (arrays · ops · dtypes)                       │
   │  nn   (Module · Linear · Conv · Transformer · ...)  │
   │  optimizers (Adam · AdamW · SGD · ...)              │
   │  utils · random · fft · linalg · fast               │
   └──────────────────────────┬──────────────────────────┘
                              │
              ┌───────────────┴───────────────┐
              ▼                               ▼
        mlx-lm                          mlx-vlm
        ──────                          ──────
        load() → (model, tok)           load() → (model, processor)
        generate() / stream_generate    generate() with images
        convert · lora · fuse · server  Gemma 3 multimodal · Llava · ...

   target: M-series (M1+) · macOS 13.5+ · unified memory · Metal GPU

mlx.core — arrays & ops

mx.array [class]

The fundamental tensor type. Lazy by default; lives in unified memory accessible from CPU and GPU. Constructed from Python lists, NumPy arrays, or directly via mx.zeros / mx.ones / mx.random.*.
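
A minimal construction-and-laziness sketch (shapes and values are illustrative):

    import mlx.core as mx

    a = mx.array([1.0, 2.0, 3.0])            # from a Python list
    z = mx.zeros((2, 3), dtype=mx.float16)   # factory constructor
    n = mx.random.normal(shape=(4, 4))       # random init
    b = a * 2 + 1                            # lazy: builds a graph, no compute yet
    mx.eval(b)                               # materializes b in unified memory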

mx.eval(*xs) [function]

Force evaluation of one or more lazy arrays. Call it after every training step; otherwise the lazy graph, and the memory it holds, grows without bound.
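
A small sketch of what "lazy until evaluated" means in practice:

    import mlx.core as mx

    x = mx.random.normal(shape=(1024, 1024))
    y = (x @ x.T).sum()   # builds a lazy graph; no kernels dispatched yet
    mx.eval(y)            # runs the graph on the default device
    print(y.item())       # Python-side reads like .item() also force evaluation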

dtypes [submodule]

mx.float32 · mx.float16 · mx.bfloat16 · mx.int8 · mx.int16 · mx.int32 · mx.int64 · mx.bool_. bfloat16 is the inference default for LLMs.

mx.random [submodule]

normal · uniform · randint · bernoulli · truncated_normal · categorical. Stateful PRNG, seeded via mx.random.seed(n).
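
Typical usage, seeded for reproducibility (shapes are illustrative):

    import mlx.core as mx

    mx.random.seed(42)                                  # global PRNG state
    w = mx.random.normal(shape=(3, 3))                  # standard normal
    i = mx.random.randint(0, 10, shape=(5,))            # ints in [0, 10)
    t = mx.random.categorical(mx.zeros((1, 32000)))     # sample ids from logits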

mx.linalg [submodule]

norm · qr · svd · cholesky · inv · solve. Standard linear algebra ops.

mx.fft [submodule]

FFT routines (fft, ifft, rfft, irfft, 2D variants). Used in audio + signal pipelines.

mx.fast [submodule]

Fused kernels: mx.fast.scaled_dot_product_attention, mx.fast.layer_norm, mx.fast.rms_norm, mx.fast.rope. The hand-tuned hot paths for transformer inference.
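
A hedged sketch of the two most-used fused calls; argument names follow the current MLX docs but may shift between releases:

    import mlx.core as mx

    B, H, L, D = 1, 8, 128, 64
    q = mx.random.normal(shape=(B, H, L, D))
    k = mx.random.normal(shape=(B, H, L, D))
    v = mx.random.normal(shape=(B, H, L, D))
    out = mx.fast.scaled_dot_product_attention(q, k, v, scale=D ** -0.5)  # fused SDPA

    x = mx.random.normal(shape=(L, 512))
    h = mx.fast.rms_norm(x, mx.ones((512,)), eps=1e-5)                    # fused RMSNorm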

mx.stream / mx.set_default_device [functions]

Pin ops to a specific device (mx.cpu, mx.gpu) or stream. Mostly automatic; explicit only for benchmarking.

mlx.nn — neural-network modules

nn.Module [class]

Base class for all layers. Subclass with __init__ declaring submodules + __call__ defining forward. Auto-registers parameters.
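
A minimal subclass sketch (layer sizes are arbitrary):

    import mlx.core as mx
    import mlx.nn as nn

    class MLP(nn.Module):
        def __init__(self, dims: int, hidden: int):
            super().__init__()
            self.up = nn.Linear(dims, hidden)    # submodule params auto-register
            self.down = nn.Linear(hidden, dims)

        def __call__(self, x):
            return self.down(nn.silu(self.up(x)))

    model = MLP(256, 1024)
    y = model(mx.random.normal(shape=(8, 256)))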

nn.Linear · Conv1d · Conv2d · Embedding [layers]

Standard building blocks. Constructor signatures match PyTorch — nn.Linear(in_features, out_features), etc.

nn.LayerNorm · RMSNorm · BatchNorm · GroupNorm [norms]

Normalization layers. RMSNorm is what every modern LLM uses; mx.fast.rms_norm is the kernelized fast path.

nn.MultiHeadAttention [layer]

Multi-head self-attention. For LLM serving, prefer mx.fast.scaled_dot_product_attention wired into a custom block.

nn.relu · gelu · silu · softmax · ... [functions]

Activation functions. nn.silu (a.k.a. swish) is the LLM default.

nn.value_and_grad(model, fn) [transform]

Returns a function that computes (loss, grads-wrt-params). Pair with opt.update(model, grads).
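
A training-step sketch, assuming a model like the MLP above and a hypothetical batches iterable of (x, y) pairs:

    import mlx.core as mx
    import mlx.nn as nn
    import mlx.optimizers as optim

    def loss_fn(model, x, y):
        return nn.losses.cross_entropy(model(x), y, reduction="mean")

    loss_and_grad = nn.value_and_grad(model, loss_fn)
    opt = optim.AdamW(learning_rate=1e-4, weight_decay=0.01)

    for x, y in batches:                          # hypothetical data iterator
        loss, grads = loss_and_grad(model, x, y)
        opt.update(model, grads)
        mx.eval(model.parameters(), opt.state)    # see "eval cadence" below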

mlx.optimizers

optim.AdamW [optimizer]

The default. AdamW(learning_rate=1e-4, weight_decay=0.01). Decoupled weight decay; what every LLM trainer uses.

optim.Adam · SGD · Lion · Adafactor [optimizers]

Drop-in alternatives. Lion is a single-momentum optimizer that is competitive with AdamW on transformers at a lower memory footprint.

optim.cosine_decay · linear_schedule · join_schedules [schedulers]

Learning-rate schedules. Pass into the optimizer's learning_rate=.
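
Sketch of a warmup-then-cosine schedule wired into AdamW (step counts are illustrative):

    import mlx.optimizers as optim

    warmup = optim.linear_schedule(0.0, 1e-4, 500)      # 500-step linear warmup
    decay = optim.cosine_decay(1e-4, 10_000)            # cosine to ~0 over 10k steps
    schedule = optim.join_schedules([warmup, decay], [500])
    opt = optim.AdamW(learning_rate=schedule, weight_decay=0.01)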

Composable transforms

mx.grad(fn) [transform]

Wraps a Python function; returns a fn that computes the gradient w.r.t. the first argument. JAX-style.

mx.value_and_grad(fn) [transform]

Like grad but also returns the original output — saves a recomputation.
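
A tiny sketch of both transforms on a scalar-valued function:

    import mlx.core as mx

    def f(x):
        return (x ** 2).sum()

    x = mx.array([1.0, 2.0, 3.0])
    print(mx.grad(f)(x))                  # [2, 4, 6]
    loss, dx = mx.value_and_grad(f)(x)    # value and gradient in one pass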

mx.vmap(fn) [transform]

Vectorizing map. Lifts a single-example function to a batched one without manually adding a batch dimension.
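
Sketch: a per-example function lifted over a leading batch axis:

    import mlx.core as mx

    def dot(a, b):                        # written for single vectors
        return (a * b).sum()

    batched_dot = mx.vmap(dot)            # maps over axis 0 of both arguments
    A = mx.random.normal(shape=(32, 64))
    B = mx.random.normal(shape=(32, 64))
    out = batched_dot(A, B)               # shape (32,)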

mx.compile(fn) [transform]

JIT-compiles a function. Enables kernel fusion + reduces dispatch overhead. Worth doing on the hot path of training/inference.
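
Sketch of compiling a small forward function; the first call traces and compiles, later calls reuse the cached kernels:

    import mlx.core as mx

    @mx.compile
    def mlp(x, w1, w2):
        return mx.maximum(x @ w1, 0.0) @ w2   # ReLU MLP, fused after compilation

    x = mx.random.normal(shape=(8, 256))
    w1 = mx.random.normal(shape=(256, 1024))
    w2 = mx.random.normal(shape=(1024, 256))
    y = mlp(x, w1, w2)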

mlx-lm — LLM stack

mlx_lm.load(repo) [function]

Returns (model, tokenizer) from an HF or local path. Auto-handles MLX-format weights, sharding, dtype.

mlx_lm.generate(model, tok, ...) [function]

One-shot text generation. Args: prompt, max_tokens, temp, top_p, top_k, repetition_penalty, prompt_cache.
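
A minimal one-shot sketch; the repo name is illustrative (any mlx-community conversion, or a gated HF repo you have access to, will work):

    from mlx_lm import load, generate

    model, tokenizer = load("mlx-community/gemma-3-4b-it-4bit")   # illustrative repo
    messages = [{"role": "user", "content": "Explain unified memory in one sentence."}]
    prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
    print(generate(model, tokenizer, prompt=prompt, max_tokens=128))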

mlx_lm.stream_generate(...) [function]

Yields token chunks as they're generated. Use for live UIs and to back-pressure consumers.
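
Streaming sketch; in recent mlx-lm versions each yielded chunk is a response object with a .text field (older versions yield plain strings), so check your installed version:

    from mlx_lm import load, stream_generate

    model, tokenizer = load("mlx-community/gemma-3-4b-it-4bit")   # illustrative repo
    messages = [{"role": "user", "content": "Write a haiku about Metal."}]
    prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

    for chunk in stream_generate(model, tokenizer, prompt, max_tokens=64):
        print(chunk.text, end="", flush=True)    # assumes the response-object API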

mlx_lm.convert [cli]

mlx_lm.convert --hf-path <repo> --mlx-path ./out --quantize --q-bits 4. HF → MLX with optional quantization.

mlx_lm.lora [cli]

LoRA / DoRA fine-tuning. --train --data ./data --iters 1000 --lora-layers 16. Trainable on M-series with 16 GB+.

mlx_lm.fuse [cli]

Merge an adapter into the base weights. Output is a single self-contained model — no adapter path needed at inference.

mlx_lm.server [cli]

OpenAI-compatible HTTP server. --model <path> --host 0.0.0.0 --port 8080. Drop-in for any OpenAI client.

prompt cache [concept]

Reuse KV across calls when prompts share a prefix (system prompts, RAG context). make_prompt_cache(model); pass to generate().
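
A multi-call sketch; the import path for make_prompt_cache is the one used in recent mlx-lm releases and may differ in older ones:

    from mlx_lm import load, generate
    from mlx_lm.models.cache import make_prompt_cache

    model, tokenizer = load("mlx-community/gemma-3-4b-it-4bit")   # illustrative repo
    cache = make_prompt_cache(model)          # holds KV state across calls

    for turn in ["What is MLX?", "And what is a KV cache?"]:
        prompt = tokenizer.apply_chat_template(
            [{"role": "user", "content": turn}], add_generation_prompt=True
        )
        print(generate(model, tokenizer, prompt, max_tokens=64, prompt_cache=cache))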

mlx-vlm — vision-language

mlx_vlm.load(repo) [function]

Returns (model, processor). Processor handles image preprocessing alongside tokenization.

mlx_vlm.generate(model, processor, prompt, image=...) [function]

Generate text conditioned on one or more images. Uses model's chat template via apply_chat_template.
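
A hedged sketch following the mlx-vlm README; the helper names (load_config, apply_chat_template), the repo name, and the image path are assumptions and may shift between versions:

    from mlx_vlm import load, generate
    from mlx_vlm.prompt_utils import apply_chat_template
    from mlx_vlm.utils import load_config

    path = "mlx-community/gemma-3-4b-it-4bit"             # illustrative repo
    model, processor = load(path)
    config = load_config(path)

    images = ["frame_0001.jpg"]                           # local path or URL
    prompt = apply_chat_template(processor, config, "Describe this image.",
                                 num_images=len(images))
    print(generate(model, processor, prompt, images))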

Models — Gemma 3 4B focus

Gemma 3 family [model family]

Google's open-weight models. Sizes: 270M, 1B, 4B, 12B, 27B. Decoder-only transformer; multimodal at 4B and up. Released as open weights under the Gemma license (gated download).

Gemma 3 4B [model]

The local-on-Mac sweet spot. ~8 GB at bf16, ~2.5 GB at 4-bit. Fast on M2/M3/M4. -it variant is instruction-tuned and multimodal.

Gemma 3 4B (multimodal) [model]

gemma-3-4b-it includes vision. Pair with mlx-vlm for image-conditioned generation. Used for moment scoring + frame-level caption polish.

mlx-community/... [HF org]

Pre-converted MLX weights for popular open models (Gemma, Llama, Qwen, Mistral, Phi, DeepSeek). bf16 + 4bit + 8bit variants.

other supported architectures [family]

Llama 3 / 3.1 / 3.2 · Qwen 2.5 / 3 · Mistral · Phi-3 / 3.5 / 4 · DeepSeek · OLMo · Granite · Yi · Mixtral · Stable LM. Same load/generate API.

CLIs at a glance

mlx_lm.generate [cli]

One-shot. --model <path> --prompt <str> --max-tokens N.

mlx_lm.chat [cli]

Interactive REPL with chat-template handling. --model <path>. Good for sanity-checking a freshly converted model.

mlx_lm.cache_prompt [cli]

Pre-compute a KV cache for a long prompt to reuse across generate calls. Useful for RAG with stable context.

mlx_lm.manage [cli]

Inspect / clean the local HF cache used by mlx-lm — list models, evict by size or age.

Concepts

unified memory [concept]

One address space; CPU and GPU read the same buffers, so there are no copies between devices. This is the architectural reason MLX exists rather than adapting PyTorch.

lazy evaluation [concept]

Operations build a graph; computation defers until mx.eval() or a Python-side read. Lets the compiler fuse ops; can also leak memory if you forget to materialize.

fused kernels [concept]

Hand-tuned Metal kernels for the LLM hot path: SDPA, RMSNorm, RoPE, layer-norm. Found under mx.fast.*; what makes mlx-lm fast.

quantization [concept]

4-bit / 8-bit weight compression with group-size scaling. --q-bits 4 --q-group-size 64 is the standard. ~3-4× memory reduction, ~5% quality cost.

LoRA / DoRA [concept]

Low-rank adapters. Train a small additive matrix per layer, leave base weights frozen. Fits a 4B fine-tune in <16 GB unified memory.

dtype gotcha [concept]

bf16 weights × fp32 input casts everything to fp32. Cast inputs at the model boundary: x.astype(mx.bfloat16).

KV cache [concept]

Stores attention K + V for past tokens. Dominates memory at long contexts. Cap with max_kv_size or compress with kv_bits=8.

OpenAI-compatible serving [concept]

mlx_lm.server exposes /v1/chat/completions + /v1/completions. Mostly drop-in for OpenAI SDKs; tool-calling + image-content edge cases differ.
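
Client-side sketch, assuming mlx_lm.server is already running on port 8080; how the model field is interpreted depends on how the server was launched:

    from openai import OpenAI   # pip install openai

    client = OpenAI(base_url="http://localhost:8080/v1", api_key="not-needed")
    resp = client.chat.completions.create(
        model="local-model",    # name handling depends on server flags
        messages=[{"role": "user", "content": "Say hello from the Mac."}],
    )
    print(resp.choices[0].message.content)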

eval cadence [concept]

Always mx.eval(model.parameters(), opt.state) after every optimizer step. Otherwise memory and graph debt grow without bound.

Companion tools

huggingface-cli [tool]

Manages the HF cache MLX reads from. huggingface-cli login required for gated models (Gemma).

mlx-data [tool]

High-throughput data loaders that don't bottleneck on Python. Used in training pipelines.

mlx-graphs [tool]

Graph neural networks built on MLX.

MLX Swift [tool]

Swift bindings — for shipping MLX models inside iOS / macOS apps. Same array semantics.

Docs & refs

ml-explore.github.io/mlx [docs]

Canonical MLX reference: install, examples, full API.

github.com/ml-explore/mlx-lm [repo]

mlx-lm source. Examples, training scripts, server, model registry.

github.com/Blaizzy/mlx-vlm [repo]

Vision-language models on MLX. Gemma 3 multimodal, Llava, Qwen-VL.

github.com/ml-explore/mlx-examples [repo]

Reference implementations — transformer LM training, MNIST, stable diffusion, whisper. Read these for idiomatic MLX.

mlx-guide (sibling) [guide]

Narrative how-to companion to this wiki — install, concepts, arrays, nn, training, mlx-lm, Gemma 3 4B walkthrough, quantization, LoRA, server, integrate.