X Algorithm Part 4: The Ranking Transformer

This is Part 4 of a 5-part series on the X recommendation algorithm. Part 1 covers the big-picture overview. Part 2 covers the pipeline framework. Part 3 covers candidate sourcing. Part 5 covers scoring, filtering, and the final feed.


After filtering, ~168 candidates enter the scoring stage. Each one needs a score — a prediction of how likely this user is to engage with this post. That score comes from a Grok-based transformer in Phoenix.

This part covers exactly what that transformer sees, how it processes the input, and what it produces.

The Model Is Tiny

When the codebase says "Grok-based transformer", it doesn't mean the 314B parameter LLM. The ranking model is a miniature transformer configured in phoenix/run_ranker.py:45-53:

TransformerConfig(
    emb_size=128,
    widening_factor=2,
    key_size=64,
    num_q_heads=2,
    num_kv_heads=2,
    num_layers=2,
    attn_output_multiplier=0.125,
)

Two layers, 128-dimensional embeddings, 2 attention heads. By modern standards this is extremely small. The "Grok-based" label refers to the architecture pattern — RMSNorm, RoPE, grouped query attention, gated FFN — not the scale.

The design makes sense for serving: a short fixed-length sequence through a 2-layer transformer is fast on GPU, XLA-compiled, and deterministic. Capacity comes from training on enormous engagement data, not model depth.
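To put "tiny" in numbers, here is a back-of-envelope count of the weight matrices in the two decoder layers, derived from the config above. It is only a sketch: it assumes ffn_size = widening_factor * emb_size (the real code may round differently), ignores norm scales and biases, and leaves out the hash embedding tables, which are not part of this config.

```python
# Rough weight count for the decoder layers, from the TransformerConfig above.
# Assumption (not confirmed by the source): ffn_size = widening_factor * emb_size.
emb_size, widening_factor, key_size = 128, 2, 64
num_q_heads, num_kv_heads, num_layers = 2, 2, 2

q_proj = emb_size * num_q_heads * key_size        # query projection
kv_proj = 2 * emb_size * num_kv_heads * key_size  # key and value projections
out_proj = num_q_heads * key_size * emb_size      # attention output projection
attn_params = q_proj + kv_proj + out_proj

ffn_size = widening_factor * emb_size
ffn_params = 3 * emb_size * ffn_size              # value, gate, and output matrices

per_layer = attn_params + ffn_params
total = num_layers * per_layer
print(f"{per_layer:,} per layer, {total:,} total")
```

Under these assumptions the layers hold a few hundred thousand weights, roughly a million times fewer than a 314B-parameter model.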

What Goes Into the Transformer

The input isn't text. There's no post content, no user profile, no follower counts. The model sees only hash IDs and action types — structured into three types of tokens.

Three Token Types

phoenix/recsys_model.py:62-76

class RecsysBatch(NamedTuple):
    user_hashes: ...               # [B, num_user_hashes]
    history_post_hashes: ...       # [B, S, num_item_hashes]
    history_author_hashes: ...     # [B, S, num_author_hashes]
    history_actions: ...           # [B, S, num_actions] (multi-hot engagements)
    history_product_surface: ...   # [B, S] — app surface where interaction happened
    candidate_post_hashes: ...     # [B, C, num_item_hashes]
    candidate_author_hashes: ...   # [B, C, num_author_hashes]
    candidate_product_surface: ... # [B, C]

Notice what's missing from candidates: history_actions. Candidates haven't been engaged with yet — that's what we're trying to predict. The asymmetry between history tokens and candidate tokens is intentional and important.

Each hash ID resolves to a learned embedding vector. Two hash functions are used per entity (num_user_hashes=2, num_item_hashes=2, num_author_hashes=2), then combined via concatenation and projection. This was covered in Part 3.
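As a reminder of what "two hashes, concatenate, project" looks like, here is a dependency-free toy version. The salted SHA-256 hash, the table size, the single shared table, and the random stand-in for the learned projection are all assumptions made for illustration; Phoenix's actual hashing scheme is not shown in this part.

```python
import hashlib
import random

D = 4              # toy embedding size (the real model uses 128)
TABLE_SIZE = 1000  # hypothetical number of rows in the embedding table
NUM_HASHES = 2     # matches num_item_hashes = 2

rng = random.Random(0)
# Random stand-ins for learned parameters.
table = [[rng.gauss(0, 1) for _ in range(D)] for _ in range(TABLE_SIZE)]
proj = [[rng.gauss(0, 1) for _ in range(D)] for _ in range(NUM_HASHES * D)]

def hash_ids(entity_id):
    """Map one entity ID to NUM_HASHES table rows via salted hashes (illustrative)."""
    rows = []
    for salt in range(NUM_HASHES):
        digest = hashlib.sha256(f"{salt}:{entity_id}".encode()).digest()
        rows.append(int.from_bytes(digest[:8], "big") % TABLE_SIZE)
    return rows

def embed(entity_id):
    """Look up each hashed row, concatenate to NUM_HASHES * D values, project to D."""
    concat = [x for row in hash_ids(entity_id) for x in table[row]]
    return [sum(concat[i] * proj[i][j] for i in range(len(concat))) for j in range(D)]

vec = embed(7003)
print(hash_ids(7003), len(vec))
```

The point of using two hashes is that two IDs which collide under one hash will rarely collide under both, so their concatenated embeddings still differ.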

User Token

phoenix/recsys_model.py:79-119

def block_user_reduce(user_hashes, user_embeddings, num_user_hashes, emb_size):
    B = user_embeddings.shape[0]
    # Concatenate 2 hash embeddings: [B, 1, 2D]
    user_embedding = user_embeddings.reshape((B, 1, num_user_hashes * emb_size))
    # Learned projection [2D → D]
    user_embedding = jnp.dot(user_embedding, proj_mat_1)
    return user_embedding  # [B, 1, D]

The user becomes a single token: one vector of dimension D that represents "who this person is" based on their hash ID. The user token is always position 0 in the sequence.

History Tokens

phoenix/recsys_model.py:122-182

Each history item combines four signals into one token:

def block_history_reduce(...):
    # Concatenate: post hashes + author hashes + actions + product_surface
    # Shape: [B, S, (num_item_hashes + num_author_hashes) * D + D + D]
    #      = [B, S, 6D]
    post_author_embedding = jnp.concatenate([
        history_post_embeddings_reshaped,   # [B, S, 2D]
        history_author_embeddings_reshaped, # [B, S, 2D]
        history_actions_embeddings,         # [B, S, D]
        history_product_surface_embeddings, # [B, S, D]
    ], axis=-1)
    # Project [6D → D]
    history_embedding = jnp.dot(post_author_embedding, proj_mat_3)
    return history_embedding  # [B, S, D]

The action encoding deserves attention. Each history item records what action the user took (history_actions is a multi-hot vector). These are converted to signed values before projection:

phoenix/recsys_model.py:314

actions_signed = (2 * actions - 1).astype(jnp.float32)
action_emb = jnp.dot(actions_signed, action_projection)  # [B, S, num_actions] @ [num_actions, D] -> [B, S, D]

A 1 (action taken) maps to +1. A 0 (action not taken) maps to -1. Both directions carry signal. The learned [num_actions, D] matrix converts the signed vector into an embedding.
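A toy example makes the effect of the signed encoding concrete. The four action names and the projection values below are invented for illustration; only the 2 * actions - 1 step comes from the source.

```python
# One history item over 4 hypothetical actions: [like, reply, repost, block]
actions = [1, 0, 1, 0]
actions_signed = [2.0 * a - 1.0 for a in actions]

# Stand-in for the learned [num_actions, D] projection, with D = 2.
action_projection = [
    [0.5, 0.0],
    [0.0, 1.0],
    [0.25, 0.25],
    [-1.0, 0.5],
]
D = len(action_projection[0])

# Signed vector times projection: every action contributes, taken or not.
action_emb = [
    sum(s * row[j] for s, row in zip(actions_signed, action_projection))
    for j in range(D)
]
print(actions_signed)  # [1.0, -1.0, 1.0, -1.0]
print(action_emb)      # [1.75, -1.25]
```

With a plain 0/1 encoding, the rows for reply and block would drop out of the sum entirely. With ±1, "did not block" actively pushes the embedding away from the block direction.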

product_surface is categorical — where the engagement happened (home timeline, explore, notifications, etc.). It uses a small lookup table with vocab_size=16.

The history sequence length is 32 (history_seq_len=32). Longer histories are truncated upstream by UserActionSeqQueryHydrator.

Candidate Tokens

phoenix/recsys_model.py:185-242

def block_candidate_reduce(...):
    # Concatenate: post hashes + author hashes + product_surface
    # NO actions field — candidates haven't been engaged with yet
    # Shape: [B, C, (num_item_hashes + num_author_hashes) * D + D]
    #      = [B, C, 5D]
    post_author_embedding = jnp.concatenate([
        candidate_post_embeddings_reshaped,   # [B, C, 2D]
        candidate_author_embeddings_reshaped, # [B, C, 2D]
        candidate_product_surface_embeddings, # [B, C, D]
    ], axis=-1)
    # Project [5D → D]
    candidate_embedding = jnp.dot(post_author_embedding, proj_mat_2)
    return candidate_embedding  # [B, C, D]

Candidate tokens are 5D → D, history tokens are 6D → D. The difference is the missing action dimension. A candidate is identified purely by what it is (post ID + author ID + surface), not by any prior engagement.

Assembling the Sequence

phoenix/recsys_model.py:428-436

embeddings = jnp.concatenate(
    [user_embeddings, history_embeddings, candidate_embeddings], axis=1
)
# Shape: [B, 1 + S + C, D]
#      = [B, 1 + 32 + 8, 128]
#      = [B, 41, 128]

candidate_start_offset = 1 + S  # = 33

The full sequence is 41 tokens: 1 user + 32 history + 8 candidates. This is the input to the transformer. The candidate_start_offset (33) is passed alongside the embeddings so the attention mask knows where candidates begin.
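The slot arithmetic is worth spelling out once, since later sections refer to positions by number. This is just the layout restated as index ranges; the variable names other than candidate_start_offset are illustrative.

```python
S, C = 32, 8  # history length and candidates per pass, as in the shape comment above

seq_len = 1 + S + C
candidate_start_offset = 1 + S

user_slot = range(0, 1)
history_slots = range(1, candidate_start_offset)
candidate_slots = range(candidate_start_offset, seq_len)

print(seq_len, list(candidate_slots))  # 41 [33, 34, 35, 36, 37, 38, 39, 40]
```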

Candidate Isolation Masking

This is the most architecturally interesting part of the ranking model.

All 8 candidates are scored in a single forward pass — but they can't attend to each other. Each candidate sees only the user token, the history tokens, and itself.

phoenix/grok.py:39-71

def make_recsys_attn_mask(seq_len, candidate_start_offset):
    # Start with a causal mask (lower triangular)
    causal_mask = jnp.tril(jnp.ones((1, 1, seq_len, seq_len)))

    # Zero out candidate-to-candidate attention (bottom-right block)
    attn_mask = causal_mask.at[
        :, :, candidate_start_offset:, candidate_start_offset:
    ].set(0)

    # Add back self-attention for candidates (diagonal only)
    candidate_indices = jnp.arange(candidate_start_offset, seq_len)
    attn_mask = attn_mask.at[
        :, :, candidate_indices, candidate_indices
    ].set(1)

    return attn_mask  # [1, 1, seq_len, seq_len]

The resulting attention pattern for a sequence of length 5 (1 user + 2 history + 2 candidates, for illustration):

          user  hist  hist  cand  cand
user    [  1     0     0     0     0  ]   ← causal: can only see itself
hist1   [  1     1     0     0     0  ]   ← causal: sees user + prior history
hist2   [  1     1     1     0     0  ]   ← causal: sees user + all prior history
cand1   [  1     1     1     1     0  ]   ← sees user + all history + self only
cand2   [  1     1     1     0     1  ]   ← sees user + all history + self only
                         cand1 blocked!

Candidates read from user and history but are invisible to each other.
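The same rule can be written as a plain-Python function and checked against the 5-token picture. This is a re-implementation for illustration, not the Phoenix code: it builds the mask entry by entry instead of using jnp.tril and indexed updates.

```python
def make_mask(seq_len, candidate_start_offset):
    """Row q lists which positions token q may attend to (1 = allowed)."""
    mask = []
    for q in range(seq_len):
        row = []
        for k in range(seq_len):
            if q >= candidate_start_offset and k >= candidate_start_offset:
                row.append(1 if q == k else 0)  # candidate vs candidate: self only
            else:
                row.append(1 if k <= q else 0)  # everything else: causal
        mask.append(row)
    return mask

# 1 user + 2 history + 2 candidates
mask = make_mask(seq_len=5, candidate_start_offset=3)
for row in mask:
    print(row)
```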

Why this matters:

Score consistency. If candidate A's score depended on which other candidates were present, the same post could receive a different score in different request contexts. With isolation, every candidate's score is a pure function of (user + history + self). The score for post 5001 is the same whether it competes with 7 others or 167 others.

Cacheability. Because a candidate's output depends only on the user context and itself, candidate representations could in principle be cached across requests for the same user, as long as that user's history has not changed in between. The isolation mask is what makes this possible.

No positional bias. Without isolation, a candidate at position 33 would have attended to no other candidates (it's first), while a candidate at position 40 would have attended to 7 others. Their scores would encode slot order, not just content. Isolation removes that artifact; the only position effect left is the small RoPE offset discussed under Design Analysis.

The Transformer Architecture

Each of the 2 layers is a standard decoder block with a few specific choices:

phoenix/grok.py:443-497

class DecoderLayer:
    def __call__(self, inputs, mask, padding_mask):
        h = inputs
        # Pre-norm: RMSNorm before attention
        attn_output = MHABlock(...)(layer_norm(h), mask)
        h_attn = layer_norm(attn_output.embeddings)
        h += h_attn  # residual

        # Pre-norm: RMSNorm before FFN
        h_dense = DenseBlock(...)(layer_norm(h))
        h_dense = layer_norm(h_dense)
        h += h_dense  # residual

        return DecoderOutput(embeddings=h)

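RMSNorm. Each sublayer is normalized twice: once on its input (pre-norm) and once on its output before the residual add, which is why layer_norm appears twice per sublayer in the code above. RMSNorm (root mean square normalization) is cheaper than LayerNorm: it skips mean subtraction and only rescales by the root mean square of the activations.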
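For readers who haven't seen the two normalizations side by side, here is a minimal sketch on plain Python lists. The eps value and the function signatures are assumptions for illustration, not taken from phoenix/grok.py.

```python
import math

def rms_norm(x, scale, eps=1e-5):
    """Divide by the root mean square; no mean subtraction, no bias."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [s * v / rms for v, s in zip(x, scale)]

def layer_norm(x, scale, bias, eps=1e-5):
    """Classic LayerNorm for comparison: center, then scale by the std deviation."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [s * (v - mean) / math.sqrt(var + eps) + b
            for v, s, b in zip(x, scale, bias)]

x = [3.0, 4.0]
y_rms = rms_norm(x, scale=[1.0, 1.0])
y_ln = layer_norm(x, scale=[1.0, 1.0], bias=[0.0, 0.0])
print(y_rms, y_ln)
```

RMSNorm keeps the direction of x and only changes its length, while LayerNorm also shifts it to zero mean.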

RoPE. Rotary positional embeddings applied to query and key heads before the dot product. Position information is encoded in the rotation of the embedding vectors rather than added to them.
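The key property of RoPE is that the query-key dot product depends only on the offset between the two positions. A small sketch shows this; the base of 10000 and the pairing of consecutive dimensions are common conventions assumed here, and the Phoenix implementation may differ in both.

```python
import math

def rope(vec, position, base=10000.0):
    """Rotate consecutive pairs of vec by position-dependent angles."""
    half = len(vec) // 2
    out = []
    for i in range(half):
        theta = position * base ** (-i / half)
        x, y = vec[2 * i], vec[2 * i + 1]
        out += [x * math.cos(theta) - y * math.sin(theta),
                x * math.sin(theta) + y * math.cos(theta)]
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

q = [1.0, 0.5, -0.3, 2.0]
k = [0.2, -1.0, 0.7, 0.1]

near = dot(rope(q, 10), rope(k, 7))    # offset 3
far = dot(rope(q, 40), rope(k, 37))    # offset 3, different absolute positions
other = dot(rope(q, 10), rope(k, 2))   # offset 8
print(near, far, other)
```

The first two scores agree up to floating-point error because the offset is the same; the third differs because the offset changed.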

Grouped Query Attention. With num_q_heads=2 and num_kv_heads=2, this model happens to be standard multi-head attention (GQA with ratio 1:1). The infrastructure supports GQA for scaling to larger models.

Gated FFN. The feed-forward network uses a gating mechanism:

phoenix/grok.py:414-440

class DenseBlock:
    def __call__(self, inputs):
        h_v = Linear(ffn_size)(inputs)          # value path
        h_w1 = gelu(Linear(ffn_size)(inputs))   # gating path
        h_dense = Linear(model_size)(h_w1 * h_v) # gated output
        return h_dense

Two parallel linear projections, one passed through GELU and the other not, are multiplied element-wise before the output projection. With GELU as the gate activation this is the GeGLU pattern; SwiGLU, common in modern LLMs, is the same structure with a Swish/SiLU gate.

Attention logit capping. Before softmax, attention logits are soft-capped using tanh:

phoenix/grok.py:341-343

attn_logits *= self.attn_output_multiplier  # scale by 0.125 first
max_attn_val = jnp.array(30.0, dtype=attn_logits.dtype)
attn_logits = max_attn_val * jnp.tanh(attn_logits / max_attn_val)

With attn_output_multiplier=0.125, logits are first scaled down by 8×, which matches the usual 1/√d_k scaling for key_size=64. The tanh cap then smoothly bounds them to the range (-30, 30). Together, these keep any single position from dominating the attention distribution: training is more stable when the softmax does not collapse to a near-one-hot distribution.
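Plugging a few numbers into the two steps shows how the cap behaves. This scalar sketch mirrors the snippet above using math.tanh; the raw logit values are arbitrary.

```python
import math

ATTN_OUTPUT_MULTIPLIER = 0.125
MAX_ATTN_VAL = 30.0

def cap_logit(raw):
    """Scale the raw q·k logit, then squash it smoothly into (-30, 30)."""
    scaled = raw * ATTN_OUTPUT_MULTIPLIER
    return MAX_ATTN_VAL * math.tanh(scaled / MAX_ATTN_VAL)

for raw in (8.0, 80.0, 800.0, 8000.0):
    print(raw, round(cap_logit(raw), 3))
```

Small logits pass through almost unchanged (a raw 8.0 comes out very close to 1.0), while very large ones flatten out just below 30. Unlike a hard clip, tanh stays smooth, so gradients shrink gradually rather than dropping to zero at a threshold.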

bfloat16 forward pass, float32 softmax. All activations use bfloat16 for memory and throughput. The softmax in attention is computed in float32 to avoid numerical issues, then cast back:

phoenix/grok.py:354

attn_weights = jax.nn.softmax(attn_logits).astype(query.dtype)  # float32 → bfloat16

Multi-Action Output

After the transformer runs, candidate embeddings are extracted and projected to logits:

phoenix/recsys_model.py:453-474

def __call__(self, batch, recsys_embeddings):
    embeddings, padding_mask, candidate_start_offset = self.build_inputs(...)

    model_output = self.model(embeddings, padding_mask,
                              candidate_start_offset=candidate_start_offset)

    out_embeddings = layer_norm(model_output.embeddings)
    candidate_embeddings = out_embeddings[:, candidate_start_offset:, :]

    unembeddings = self._get_unembedding()  # [D, num_actions]
    logits = jnp.dot(candidate_embeddings, unembeddings)
    # Shape: [B, C, num_actions]

Then in the inference runner, logits become probabilities via sigmoid:

phoenix/runners.py:343

probs = jax.nn.sigmoid(logits)  # [B, C, num_actions]

Each candidate gets 19 probabilities — one per action type:

phoenix/runners.py:202-222

ACTIONS = [
    "favorite_score",        # like
    "reply_score",           # reply
    "repost_score",          # repost/retweet
    "photo_expand_score",    # tap to expand image
    "click_score",           # click through to post
    "profile_click_score",   # click through to author profile
    "vqv_score",             # video quality/view
    "share_score",           # share externally
    "share_via_dm_score",    # share via DM
    "share_via_copy_link_score",
    "dwell_score",           # pause/dwell
    "quote_score",           # quote-repost
    "quoted_click_score",
    "follow_author_score",   # follow the author
    "not_interested_score",  # explicit negative feedback
    "block_author_score",    # block
    "mute_author_score",     # mute
    "report_score",          # report
    "dwell_time",            # time-based dwell signal
]

The list includes both positive engagement signals (like, reply, share) and negative ones (block, mute, report). The model predicts all 19 simultaneously — a single forward pass, one set of outputs per candidate.
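Because each logit goes through its own sigmoid rather than a shared softmax, the outputs are independent probabilities and do not need to sum to 1. A small sketch with made-up logits for three of the actions:

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

# Hypothetical logits for one candidate; the values are invented.
logits = {"favorite_score": 1.2, "reply_score": -0.4, "block_author_score": -3.0}
probs = {name: sigmoid(z) for name, z in logits.items()}

for name, p in probs.items():
    print(f"{name}: {p:.3f}")
print("sum:", round(sum(probs.values()), 3))
```

This fits the problem: a user can like and reply to the same post, so the actions are not mutually exclusive classes.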

What happens with these 19 probabilities is the job of the next layer: the WeightedScorer. Each action gets a signed weight — high positive weight for replies (the code shows 27×), small positive for likes (1×), large negative for blocks and mutes (-74×). The weighted sum becomes the candidate's final score.
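As a preview, the weighted sum can be sketched with just the weights quoted above. The probabilities are made up, the other actions are omitted, and the function is a hypothetical stand-in, not the actual WeightedScorer.

```python
# Only the weights mentioned in the text; the real scorer covers more actions.
weights = {
    "reply_score": 27.0,
    "favorite_score": 1.0,
    "block_author_score": -74.0,
    "mute_author_score": -74.0,
}

def weighted_score(probs):
    """Sum of weight * probability over the actions that have a weight here."""
    return sum(weights[name] * p for name, p in probs.items() if name in weights)

# Two invented candidates with the same P(like).
engaging = {"reply_score": 0.10, "favorite_score": 0.30,
            "block_author_score": 0.001, "mute_author_score": 0.001}
divisive = {"reply_score": 0.20, "favorite_score": 0.30,
            "block_author_score": 0.05, "mute_author_score": 0.03}

print(round(weighted_score(engaging), 3), round(weighted_score(divisive), 3))
```

In this toy case the second candidate has twice the reply probability, yet a few percent of predicted blocks and mutes is enough to push its score below zero.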

That's Part 5.

Design Analysis

Is Recency Encoded?

Sort of — but not for candidates.

For history tokens, recency is implicitly present because the UAS (user action sequence) service provides engagements in reverse-chronological order (most recent first). RoPE assigns each history token a positional encoding based on its slot in the sequence: slot 1 = most recent engagement, slot 32 = oldest. So the model can learn that "what the user liked an hour ago" is more predictive than "what they liked last week." The position encoding carries the time signal.

For candidates, recency does not factor in at all. A post from 30 minutes ago and a post from 6 hours ago are both just hash IDs with no timestamp field. The pre-scoring AgeFilter is what enforces a recency cutoff — it removes posts older than a configured threshold before they reach the transformer. By the time ranking happens, every candidate has already passed the age check. The transformer doesn't need to care about recency; the filter handled it.

There's a subtlety with candidate positions. All 8 candidates sit at sequence positions 33–40. RoPE encodes them at different positions, so the candidate at slot 33 and the candidate at slot 40 attend to the same history keys with differently rotated queries. This is probably minor noise, since attention should be dominated by content similarity rather than a 7-slot position difference, but it does mean a score is not strictly independent of the slot a candidate lands in. It's a residual artifact of bending a causal transformer into a parallel scorer.

No Content, Only IDs

The model never reads a post. It sees post ID 7003 as a hash index into a learned embedding table — a 128-dim vector that was updated during training based on how users engaged with that post across the whole dataset.

What the embedding captures is co-engagement: "users who engage with post 7003 tend to also engage with posts 4201, 5887, 9012." Over millions of training steps, the embedding for post 7003 drifts toward the neighborhood of posts it clusters with. The model learns that clustering without knowing any words, topics, or content.

This has an interesting implication: two posts with the same text but different IDs can receive very different scores if their engagement histories differ. The model cares about what users did around a post, not what the post says. A factual article and a conspiracy theory with identical engagement patterns would be ranked identically.

That's by design. It means:

  • No language modeling needed — the system is language-agnostic
  • No feature engineering — topic affinity, writing style, and content quality emerge from engagement data
  • No content pipelines — only engagement events need to be processed and stored

The tradeoff is content blindness. The system cannot distinguish between "users engaged with this because it's compelling" and "users engaged with this because it's outrageous." Trust and safety signals have to come from a separate post-selection filter layer that runs after the feed is assembled.

Candidate Isolation: The Deeper Picture

The isolation mask is best understood as a batching trick.

The conceptually simple way to rank 8 candidates would be: for each candidate, build a sequence [user] + [history] + [this candidate] and run a forward pass. Repeat 8 times. Each candidate gets the same context (user + full history) and produces a score.

That's 8 separate forward passes. Instead, the model does all 8 in one pass by concatenating the candidates into a single sequence of length 41, then masking out cross-candidate attention.

The mask makes the 8-candidate batch equivalent to 8 independent forward passes, as long as no candidate attends to any other. The causal mask already prevents candidates from seeing candidates after them. Zeroing the bottom-right candidate-to-candidate block (and restoring the diagonal) takes care of candidates attending to candidates before them.

# Step 1: Start with lower triangular (causal)
[user]  [h1]  [h2]  [c1]  [c2]
  1      0     0     0     0     ← user
  1      1     0     0     0     ← h1
  1      1     1     0     0     ← h2
  1      1     1     1     0     ← c1 sees c1 ✓, doesn't see c2 ✓
  1      1     1     1     1     ← c2 sees c1 ✗ ← problem!

# Step 2: Zero the candidate-to-candidate block
  1      1     1     0     0     ← c1: candidate columns zeroed
  1      1     1     0     0     ← c2: candidate columns zeroed

# Step 3: Restore diagonal (self-attention)
  1      1     1     1     0     ← c1: can see itself
  1      1     1     0     1     ← c2: can see itself

After these three operations every candidate token sees exactly {user, all history, self}, the same context in every slot. The scores are now directly comparable, and a single forward pass on the GPU produces all of them simultaneously.
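The "batching trick" claim can be checked numerically on a toy attention layer. The sketch below is deliberately simplified: one head, identity Q/K/V projections, random vectors, and no RoPE, so it does not model the position effect discussed under "Is Recency Encoded?". All names are illustrative.

```python
import math
import random

def isolation_mask(n, start):
    """Causal mask, except candidates (index >= start) see only themselves among candidates."""
    mask = []
    for q in range(n):
        row = []
        for k in range(n):
            both_candidates = q >= start and k >= start
            row.append(q == k if both_candidates else k <= q)
        mask.append(row)
    return mask

def attend(tokens, mask):
    """Single-head masked attention with identity projections."""
    d = len(tokens[0])
    outputs = []
    for q, row in zip(tokens, mask):
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in tokens]
        top = max(s for s, ok in zip(scores, row) if ok)
        weights = [math.exp(s - top) if ok else 0.0 for s, ok in zip(scores, row)]
        total = sum(weights)
        outputs.append([sum(w * v[j] for w, v in zip(weights, tokens)) / total
                        for j in range(d)])
    return outputs

rng = random.Random(0)

def random_token():
    return [rng.gauss(0, 1) for _ in range(4)]

context = [random_token() for _ in range(3)]     # 1 user + 2 history tokens
candidates = [random_token() for _ in range(3)]

# One pass over all candidates with the isolation mask.
joint = attend(context + candidates, isolation_mask(6, 3))[3:]

# One separate pass per candidate (with a single candidate the mask is plain causal).
separate = [attend(context + [c], isolation_mask(4, 3))[3] for c in candidates]

max_diff = max(abs(a - b) for x, y in zip(joint, separate) for a, b in zip(x, y))
print(max_diff)
```

The equivalence carries over to stacked layers because user and history tokens never attend to candidates, so their representations are identical in the joint and separate runs.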

One consequence worth noting: all candidates see the full history, not a prefix of it. Because every candidate sits after the last history token, the causal mask already exposes all 32 history tokens to each candidate, regardless of which candidate slot it occupies. This is the desired behavior: the model wants each candidate to be evaluated against the complete engagement context.

Transformer Architecture: Pros and Cons

The architecture choices here reflect real engineering tradeoffs. Taking stock of them:

What works well:

Sequence modeling over the history. Attention can model non-obvious patterns across the 32-item history. If a user liked a thread, then replied to a related post three days later, then liked another post in the same cluster — a transformer can connect those dots. A hand-engineered feature system would need to explicitly define what "cluster" means.

Single forward pass for all candidates. The isolation trick means scoring 8 candidates costs roughly the same as scoring 1. The GPU utilization is high; the latency impact per additional candidate is minimal.

Architecture reuse. Using the same Grok transformer family for retrieval, ranking, and (presumably) content modeling means the team maintains one codebase, one set of training infrastructure, one set of deployment patterns. That compounds over time.

What doesn't:

Cold start is a fundamental gap. A new user has no history → all 32 history slots are zero-padded → the model gets almost no signal. It returns near-prior probabilities for every candidate, so the feed is effectively unpersonalized. A new post has no engagement → its hash embedding is fresh from initialization → the model's estimate of P(like) is based on author and surface alone. Neither problem has a transformer solution; they require separate fallback systems that aren't in this open-source release.

A 32-item history is short. User interests evolve. A 32-item window that covers the last 24 hours of behavior might totally miss what the user cared about last week. There's no long-term memory mechanism here. If the user is on a three-day reading binge about a niche topic, that context vanishes as soon as 32 newer items displace it.

The model is content-blind by design. The transformer never reads text. It can't reason about what a post says, detect misinformation, understand context, or learn from new topics it hasn't seen in training. If a new news event breaks and people start engaging with posts about it, the model gradually learns the new post embeddings — but that takes training cycles and deployment. It can't zero-shot generalize from the content.

Opaque failures. When the feed surfaces something unexpected, there's no feature to inspect. The answer lives in attention weights across 128-dimensional embedding vectors, which don't map to human-readable explanations. Debugging a ranking mistake requires tracing embedding distances and training data coverage — significantly harder than checking a feature value.

The two-layer, 128-dim design keeps inference fast and the training loop tight. The bet is that engagement data at X's scale compensates for model depth, and empirically that bet seems to have paid off.

What's Next

  • Part 5: The four-scorer chain, the 10 pre-scoring filters, and how the final ranked feed is assembled from the transformer's output.