Large Lambda Model

2025-02-15


Over the last week I decided to write the inference code for GPT-2 after a many-year hiatus from Neural Networks. Depending on what primitives you start from, say if you wrote this with JAX or PyTorch, this is quite straightforward; otherwise it is somewhat less so. After lamenting the lack of a perfect tensor library in Haskell, I wrote this directly on top of the OpenBLAS bindings in hmatrix. This choice precludes actually training the model, or even doing a single backward pass, without significant effort writing backprop code that would then be useless, but an old ThinkPad CPU is just fast enough to do the forward pass if you get your bits in order. The lack of tensors makes the Multi-Head Attention layer a bit of a brain teaser, but it’s all just GEMM/GEMV in the end, and it makes this a project that goes from daunting to slowly crystallizing into a nice solution over a few days. Ideal.

Preliminaries

If you’d like to implement this yourself, the best place to start is Karpathy’s NanoGPT and llm.c, along with his YouTube videos. Also handy are Brendan Bycroft’s LLM Visualizer and a webapp hosting the tokenizer. OK, now we begin.

We’ve Got Layers

We’re like Onions, we have layers

The GPT-2 Transformer architecture is relatively simple at a high level, with only a few types of layers arranged in a straight shot down the computation graph. The main complexity is the attention head. This model is fully F32 precision, so we start off by defining some type aliases. While Haskell isn’t dependently typed, hmatrix does have a semi-undocumented interface that encodes the sizes of the matrices/vectors in their types. Unfortunately it is not generic and would not work with F32 without replicating most of the internals, so I chose not to use it and will annotate sizes with comments instead.

type Token = Int
type V = Vector Float
type M = Matrix Float

Instead of yet another transformer tutorial blog, I will go through this from the lens of reverse engineering the model and translating it into Haskell. I will leave most of the details to the many existing resources. We start then, by examining what types of layers we must contend with, and what weights lie in binary store.

Embedding Layer

The first layer is what takes us from the token into the model proper: the embedding layer. From here on out we do not see Int again until we emerge from the final logits. In GPT-2 we have a vocabulary size of 50257 tokens and an embedding size of 768; for clarity we will denote N=768. The embedding is not only with respect to the token, but also its position in the sequence of tokens, which we might as well call position in time, and which has a maximum context size of 1024 tokens. It is important to note that while the tokens themselves are not learned, the embedding weights are. The tokens themselves are generated with the Byte Pair Encoding algorithm, though the vocabulary size is a hyper-parameter.
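
In symbols, and matching the embedding function defined later: a token with id t at position p is embedded as wte[:, t] + wpe[:, p], that is, column t of the token embedding plus column p of the position embedding, giving a vector of length N=768.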

newtype TokenEmbedding = TokenEmbedding M -- (N, 50257)
newtype PositionEmbedding = PositionEmbedding M -- (N, 1024)

LayerNorm

The next component of our model is the LayerNorm. It has a simple premise: we should normalize our data (zero mean and unit variance) at various points throughout the model. The weights in this layer are an element-wise affine transformation, ax+b, performed after normalization. This is similar to BatchNorm, but normalized along the layer dimension instead of the batch dimension. Since we are only doing the forward pass and are tensor-poor, we will assume the batch dimension is one and henceforth ignore it entirely.
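
Written out, with a and b the learned weights and 1e-5 the small constant used in the forward pass later:

LayerNorm(x) = a * (x - mean(x)) / sqrt(var(x) + 1e-5) + b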

data LayerNorm = LayerNorm V V -- (N), (N)

Multi Layer Perceptron

If you have any familiarity with ML you will recognize this, the cheeseburger of Neural Networks. There is a linear layer, represented by a matrix and its bias vector, which here scales up before the nonlinearity is applied; then another linear layer matrix and bias vector scale back down to the embedding dimension.

data MLP = MLP M V M V -- (4*N, N), (4*N), (N, 4*N), (N)

Attention

Inside the self attention layer we see a linear transformation N -> 3*N, but this is really an optimization, packing the so-called Q, K, and V matrices together in memory. We then split this further, slicing each 768 into twelve vectors of length 64, one for each attention head. The additional matrix/vector pair is for a linear layer on the end.
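
In numbers: the packed QKV output has length 3*768 = 2304, which splits into Q, K, and V of 768 each, and each of those splits into 12 heads of 768/12 = 64 (this is exactly what attnAtToken does below).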

data Attention = Attention M V M V -- (3*N, N), (3*N), (N, N), (N)

Block

We group the previous layers into a Block, as we essentially stack them on top of each other and then repeat the block twelve times, so it is convenient to group them conceptually.

data Block = Block LayerNorm Attention LayerNorm MLP

GPT

We can then assemble our layers into the complete model, with one more LayerNorm at the end for good measure. Now we are ready to ask how these layers are actually implemented.

data GPT = GPT
  { wpe :: PositionEmbedding,
    wte :: TokenEmbedding,
    blocks :: [Block], -- (12)
    lnf :: LayerNorm
  }

Interlude: Necessary Functions

Before getting into the forward pass, let us define some helper functions. These are things that any modern tensor library would give you, but we will implement them ourselves. Surprisingly few are necessary.

Softmax

This is the venerable softmax, and nothing more; it smoothly turns our vectors into probability distributions.

softmax :: V -> V
softmax v = expv * scalar (1 / sumElements expv)
  where expv = cmap exp v

GELU

The popular choice of nonlinearity at the time was the Gaussian Error Linear Unit, a smoother adaptation of the ReLU intended to avoid getting stuck in the flat region during training. The GELU is defined as GELU(x)=x∗Φ(x), where Φ is the CDF of the Gaussian; technically, we are using the tanh approximation of this. It seems like the “exact” version is now performant in PyTorch, but the approximation is close enough that it doesn’t seem to matter which you use for a single forward pass.

gelu :: V -> V
gelu x = 0.5 * x * (1 + tanh (sqrt (2 / pi) * (x + 0.044715 * x * x * x)))

Tril

This function zeros out the upper triangular portion of the self attention matrix. To be exact, it sets those entries to -Inf, which becomes zero after a softmax is applied. The attention matrix encodes the relation between different token positions, and this masking corresponds to a token only depending on previous tokens. Much research has been done on alterations to this matrix, which in theory is completely general and can be put to various purposes.

tril :: Int -> M
tril n = build (n, n) (\i j -> if j > i then -1 / 0 else 0)
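
For instance, tril 3 is the 3x3 matrix

  0  -Inf  -Inf
  0     0  -Inf
  0     0     0

which, added to an attention matrix, leaves the lower triangle untouched and sends the upper triangle to -Inf.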

Forward Pass

Let’s start by defining a typeclass for our layers, containing the function for the forward pass. This code doesn’t actually generalize, but it’s comfy to do this regardless.

class Layer a where
  forward :: a -> [V] -> [V]

Embedding

We then come to the embedding layer, which does not conform to the typeclass we so hopefully just defined… The important point to note is that the embedding is across two dimensions, the token vocabulary and the token position in time. As we do not have a tensor library, it is convenient to store this as a list of vectors, the size of which cannot grow beyond the context size of 1024, so this should cause no issues. Each element of the list is the embedding of an individual token.

-- the model combines a token indexed and position indexed embedding
embedding :: TokenEmbedding -> PositionEmbedding -> [Token] -> [V]
embedding (TokenEmbedding te) (PositionEmbedding pe) ts =
  zipWith (+) (fmap (sliceColumn te) ts) (toColumns pe)

LayerNorm

As promised, this is just a normalization followed by an affine transformation. The notorious difficulty in implementing LayerNorm and BatchNorm mostly comes down to the backward pass, which we are ignoring. Note that this is an fmap over the input [V], meaning each token embedding is handled independently.

instance Layer LayerNorm where
  forward layer = fmap (forwardLN layer)
    where
      forwardLN :: LayerNorm -> V -> V
      forwardLN (LayerNorm w b) x = y
        where
          n = fromIntegral (size x)
          mean = scalar (sumElements x / n)
          cent = x - mean
          varx = sumElements (cent * cent) / n
          fact = scalar (sqrt (varx + 1e-5))
          y = (cent / fact) * w + b

Attention

We break this up into three parts. First, we apply the QKV linear transformation and break up the result into the individual Q, K, V components, and into the 12 individual heads. Second, we reassemble across the time dimension, so that we can construct the attention matrix for each head, each relating all tokens in time. Third, we flatten everything back out and apply another linear layer, ending back in the same shape we started with.

This splitting and recombining corresponds to reshaping the tensor such that the heads are their own dimension, and then transposing it with the time (token) dimension. We do not have this capability, so we must make do, and this is the trickiest part of the code by far.

-- the first part of the attention head is a linear layer.
-- Q,K,V weights and heads are combined and we have to take them apart here.
attnAtToken :: Attention -> V -> ([V], [V], [V])
attnAtToken (Attention w b _ _) x = (qh, kh, vh)
  where
    y = (w #> x) + b
    -- split apart into Q, K, V components
    (q, k, v) = case takesV [768, 768, 768] y of
      [x1, x2, x3] -> (x1, x2, x3)
      _ -> error "QKV could not be split"
    -- split into individual heads
    qh = takesV (replicate 12 64) q
    kh = takesV (replicate 12 64) k
    vh = takesV (replicate 12 64) v

-- this is the actual attention part where we construct the attention matrix.
attnHead :: (M, M, M) -> M
attnHead (q, k, v) = z
  where
    attnMatrix = tr q <> k * scalar (1 / 8) -- 1 / sqrt 64, the head dimension
    -- mask the upper right triangular to -inf (becomes 0 in softmax)
    attnMasked = tril (rows attnMatrix) + attnMatrix
    -- no tensor library means we have to do this kinda stuff
    attnSoftmax = fromRows (fmap softmax (toRows attnMasked))
    z = attnSoftmax <> tr v

instance Layer Attention where
  forward at@(Attention _ _ w b) xs = z
    where
      (q, k, v) = unzip3 (fmap (attnAtToken at) xs)
      qh = fmap fromColumns (transpose q)
      kh = fmap fromColumns (transpose k)
      vh = fmap fromColumns (transpose v)
      lm = fmap attnHead (zip3 qh kh vh)
      y = fmap vjoin (transpose (fmap toRows lm))
      z = fmap ((+ b) . (w #>)) y
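
To keep the shapes straight: attnAtToken gives, for each of the T tokens in the context, twelve 64-dimensional vectors for each of Q, K, and V; after the transpose and fromColumns, qh, kh, and vh are each twelve (64, T) matrices; each attnHead then builds a (T, T) attention matrix and returns a (T, 64) result; and the final transpose/vjoin stitches the twelve heads back into T vectors of length 768 before the output projection.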

Multi Layer Perceptron

Now we are back to classical Neural Networks, and it feels easy in comparison.

instance Layer MLP where
  forward (MLP wfc bfc wproj bproj) x = x3
    where
      x1 = fmap ((+ bfc) . (wfc #>)) x
      x2 = fmap gelu x1
      x3 = fmap ((+ bproj) . (wproj #>)) x2

Block Layer

Finally we can assemble the Block. Here there is only one thing of note, the pass-through, usually called a residual or skip connection (as in ResNet), a trick that was discovered when looking for ways to successfully train deeper networks.

instance Layer Block where
  forward (Block l1 at l2 mp) xs = x4
    where
      x1 = forward l1 xs
      x2 = zipWith (+) xs (forward at x1)
      x3 = forward l2 x2
      x4 = zipWith (+) x2 (forward mp x3)

GPT

Putting it all together now: we embed, apply the blocks in sequence, apply just one more LayerNorm, and then multiply by the (transposed) token embedding to produce logits, which we will use to sample the next token in the sequence. Since we are doing the forward pass only, there is of course no cross entropy loss at the end.

forwardModel :: GPT -> [Token] -> [V]
forwardModel model tokens = x3
  where
    TokenEmbedding wtew = wte model
    emb = embedding (wte model) (wpe model) tokens
    -- blocks are stored from 11 down to 0 (see constructModel), so foldr applies block 0 first
    x1 = foldr forward emb (blocks model)
    x2 = forward (lnf model) x1
    x3 = fmap (tr wtew #>) x2

The main trick in implementing such a thing is taking apart the reference implementation and inspecting it at every single step of the way: the standard mechanical acts of reverse engineering.
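
In practice that means dumping intermediate activations and diffing them against the reference. A minimal sketch of such a helper, assuming hmatrix’s subVector and the V alias from earlier (the name debugActivations is mine):

-- print the first few elements of each token's activation so they can be
-- compared against the reference implementation's intermediate tensors
debugActivations :: String -> [V] -> IO ()
debugActivations name xs = mapM_ printOne (zip [0 :: Int ..] xs)
  where
    printOne (i, v) =
      putStrLn (name ++ "[" ++ show i ++ "] = " ++ show (subVector 0 5 v))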

They Call Me The Sampler

To actually get something useful from the model, we must take its output predictions and sample from them. This is another scenario where the lack of a surrounding ecosystem in Haskell leaves us to our own devices. Luckily you can make a usable sampler out of leftover bits, and I will show you how.

Maximum Sampler

There is of course the cop-out sampler: simply take the highest scored token at every step. The results are quite bad; in fact this is rather instructive as to the importance of a good sampler, though it does work to test that your model is functioning at all. We lift this into the IO monad only to be consistent with the following sampler, which uses Random IO.

sampleMax :: V -> IO Token
sampleMax x = return $ snd $ maximumBy (comparing fst) (zip (toList x) [0 ..])

Top-K Uniform Sampler

We will use the approach given in the reference implementation: limit ourselves to the top K (they use 200, I chose 50) values and sample them with their softmax probabilities. How do we sample from this list of probabilities? There is a nice way of doing just this: since the probabilities sum to one, we can associate each probability with a disjoint interval contained in (0,1). The order is unimportant, we take the order we are given, and we construct the cumulative probabilities, which correspond to the right endpoints of these intervals. We can then draw a uniform random sample from the unit interval and find the first cumulative probability that exceeds it; the index of that interval is a sample from our original distribution. Neat!
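
A tiny worked example: for probabilities [0.5, 0.3, 0.2] the cumulative sums are [0.5, 0.8, 1.0]; a uniform draw of r = 0.65 exceeds 0.5 but not 0.8, so we pick index 1, which is selected exactly when r falls in the interval (0.5, 0.8], i.e. with probability 0.3.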

topK :: V -> V
topK v = fromList (map f (toList v))
  where
    k = sortBy (comparing Down) (toList v) !! 50
    f x = if x > k then x / 2 else -1 / 0 -- here 2 is the "temperature"
    
sampleLogits :: V -> IO Int
sampleLogits v = do
  r <- randomRIO (0.0, 1.0)
  return $ findIndex r (scanl1 (+) (toList $ softmax $ topK v))
  where
    findIndex r cumProbs = length (takeWhile (<= r) cumProbs)

We are now ready to start generating some fresh tokens.

Running The Model

Using all the pieces we’ve assembled so far, we can run the model with some straightforward IO event loop code.

run :: GPT -> TokenMap -> Natural -> [Token] -> IO [Token]
run model tm iter tokens = do
  next <- sampleLogits $ last $ forwardModel model tokens
  TIO.putStr (token tm [next])
  if iter == 0 || next == 50256 -- 50256 is the <|endoftext|> token
    -- this is short enough that end of list append is fine 
    then return (tokens ++ [next])
    else run model tm (iter - 1) (tokens ++ [next])
    
main :: IO ()
main = do
  hSetBuffering stdout NoBuffering
  putStrLn "λλμ: Now Loading..."
  tensors <- readModel "model.safetensors"
  vocab <- readVocab "vocab.json"
  let model = case tensors of Right gpt -> gpt; Left err -> error err
  let tokenMap = case vocab of Just tm -> tm; Nothing -> error "Couldn't parse vocab"
  putStr "Hello, I am"
  -- Tokens for "Hello, I am"
  generate <- run model tokenMap 50 [15496, 11, 314, 716]
  print generate 

Here is an example output, with typical small-model weirdness. Note this is the smallest 124M parameter GPT-2 model.

λλμ: Now Loading...
Hello, I am not quite sure what this "Theater at Heart of Harry Potter" book of essays to which all children are prone must say to all adults; they, the characters in whom the books "hates him."

That’s it, we wrote a forward pass for GPT-2!

But wait, you might say, where did that readModel function come from, or the token function? For the tokens, I am simply using the vocab.json file provided with the model weights. This is not handled entirely correctly, possibly due to Aeson disagreeing with the encoding specifics of the unicode keys in the JSON, so I will not include it here. I did not even attempt the token encoder. Nobody likes the tokenizer!

For the model loading, I chose to parse the model.safetensors format that Huggingface provides. The details are tedious, so they have been relegated to the appendix.

Performance Considerations

The main performance issue I had was really my own blunder: in my haste to prototype the inference I neglected the loader. hmatrix does not come with a way to load a vector directly from a ByteString, so we must do some work with the lower-level memory interfaces. If you wish to attempt this yourself, the critical point is to map the vectors directly, else suffer the consequences of parsing through intermediaries. Using the available tools in hmatrix and Data.Binary.Get, the obvious solution is parsing a bytestring to a list of floats, then to a vector. This is incredibly slow. Luckily the FFI in Haskell is quite nice, and we can index into the (strict!) bytestring with a pointer that can then be cast to the FFI Storable used by Vectors, without additional allocation (see byteStringToVector in the appendix). This gets the loading time down to a few seconds.
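
To make the contrast concrete, here is a sketch of the intermediary-heavy approach warned against above (the name slowBytesToVector is mine; it uses runGet and getFloatle from Data.Binary.Get, replicateM from Control.Monad, and hmatrix’s fromList, with V and BL as in the code below):

-- decode n floats one at a time into a list, then convert to a vector.
-- correct, but painfully slow for hundreds of megabytes of weights
slowBytesToVector :: Int -> BL.ByteString -> V
slowBytesToVector n bs = fromList (runGet (replicateM n getFloatle) bs)

The byteStringToVector in the appendix instead casts the ForeignPtr underlying the strict ByteString, so no per-element decoding happens at all.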

In terms of inference performance, hmatrix does an admirable job, and BLAS parallelizes well enough to saturate all 8 of my cores without needing to use something like parMap. The main slowdown is the quadratic scaling of self attention, a fundamental time complexity issue that can be somewhat improved by things like FlashAttention and custom kernels. I’ll do none of those things here. I’m not sure there is much benefit in trying to optimize this further, as the hmatrix primitives are not really the right foundation for this work, and something closer to an array combinator DSL like Accelerate or Futhark would be a better direction, though the various options all have their drawbacks. There is also the question of training, for which we would need to think about something like backprop.

Appendix : Data Loading

For the curious, I’ve included the full loading code. The safetensors format is quite simple: a leading uint64 encoding the metadata length, followed by said metadata, which is just JSON. The remainder of the file is the tensors in binary form. The JSON contains a manifest of each tensor and its byte offsets in the file, which we can use to load them.
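
Schematically, the file layout looks like this (an informal sketch of the description above; the leading integer is little-endian, per getWord64le):

[ 8 bytes: uint64 n ][ n bytes: JSON metadata ][ remaining bytes: raw tensor data ]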

module Loader where

import Data.Aeson (FromJSON, ToJSON, Value, eitherDecode, withObject, (.:))
import Data.Aeson.Encode.Pretty (encodePretty)
import qualified Data.Aeson.Key as K
import qualified Data.Aeson.KeyMap as KM
import Data.Aeson.Types (Parser, parseEither)
import Data.Bifunctor (bimap)
import Data.Binary.Get (getWord64le, runGetOrFail)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BS
import qualified Data.ByteString.Lazy as BL
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Storable as VS
import Data.Word (Word64)
import Foreign.ForeignPtr (castForeignPtr)
import Foreign.Storable (Storable, sizeOf)
import GHC.Generics (Generic)
import Model
import Numeric.LinearAlgebra (reshape, tr)
import Prelude hiding ((<>))

-- simple sum type so we can load either vec or mat
-- I could probably use the generic Container from hmatrix but this is easy
data Tensor = T1 V | T2 M

-- generate a keymap based on the safetensor metadata
type TensorMap = KM.KeyMap Tensor

-- metadata for an individual tensor (safetensor format)
data TensorMetadata = TensorMetadata
  { dtype :: String,
    shape :: [Int],
    dataOffsets :: (Int, Int)
  }
  deriving (Show, Generic, FromJSON, ToJSON)

-- entire safetensors file including unmapped raw tensor data
data SafeTensors = SafeTensors
  { metadata :: KM.KeyMap TensorMetadata,
    binaryData :: BS.ByteString
  }

-- we don't want to show the binary data, might as well have a pretty printer
instance Show SafeTensors where
  show safetensors = show $ encodePretty (metadata safetensors)

-- Parse tensor metadata from JSON segment of file
parseTensorMetadata :: Value -> Parser TensorMetadata
parseTensorMetadata = withObject "TensorMetadata" $ \obj -> do
  mdtype <- obj .: "dtype"
  mshape <- obj .: "shape"
  (i, j) <- obj .: "data_offsets"
  return
    ( TensorMetadata
        { shape = mshape,
          dataOffsets = (i, j),
          dtype = mdtype
        }
    )

parseTensors :: BL.ByteString -> Either String SafeTensors
parseTensors bs = do
  -- the first 8 bytes are a uint64 specifying the length of the JSON segment
  numBytes <- parseWord64 (BL.take 8 bs)
  -- the next N bytes can be decoded directly with aeson
  obj <- eitherDecode (BL.take (fromIntegral numBytes) (BL.drop 8 bs))
  -- this is the one key that isn't a tensor, easiest just to remove it
  let tensors = KM.delete (K.fromString "__metadata__") obj
  -- parse tensor metadata objects into our metadata type
  x <- mapM (parseEither parseTensorMetadata) tensors
  -- return metadata keymap along with remaining raw bytes containing tensor data
  return (SafeTensors x (BS.toStrict (BL.drop (8 + fromIntegral numBytes) bs)))

-- parse a Word64 from the head of the file (encodes length of JSON segment)
parseWord64 :: BL.ByteString -> Either String Word64
parseWord64 bs = case runGetOrFail getWord64le bs of
  Right (_, _, w) -> Right w
  Left (_, _, s) -> Left ("Error reading leading uint64: " ++ s)

-- https://stackoverflow.com/questions/18682527/how-to-convert-between-bytestring-and-storable-vector
byteStringToVector :: (Storable a) => BS.ByteString -> VS.Vector a
byteStringToVector bs = vec
  where
    vec = VS.unsafeFromForeignPtr (castForeignPtr fptr) (scale off) (scale len)
    (fptr, off, len) = BS.toForeignPtr bs
    scale = (`div` sizeOfElem vec)
    sizeOfElem vect = sizeOf (undefined `asTypeOf` VS.head vect)

bytesToTensor :: BS.ByteString -> TensorMetadata -> Either String Tensor
bytesToTensor bs meta = case shape meta of
  [n] -> if VG.length vec == n then Right (T1 vec) else errmsg
  [n, m] -> if VG.length vec == n * m then Right (T2 (reshape m vec)) else errmsg
  [1, 1, n, m] -> if VG.length vec == n * m then Right (T2 (reshape m vec)) else errmsg
  _ -> errmsg
  where
    (startpos, endpos) = bimap fromIntegral fromIntegral (dataOffsets meta)
    errmsg = Left ("Wrong size while reading " ++ show meta)
    -- it would maybe be better to load them "in order" with splitAt but
    -- the loading is fast enough with this now that the BS is cast directly
    vec = byteStringToVector (BS.drop startpos (BS.take endpos bs))

-- getting layer weights is straightforward. some matrices need to be transposed.

getMat :: TensorMap -> String -> Either String M
getMat tm s = case KM.lookup (K.fromString s) tm of
  (Just (T2 m)) -> Right m
  _ -> Left ("Error loading " ++ s)

getVec :: TensorMap -> String -> Either String V
getVec tm s = case KM.lookup (K.fromString s) tm of
  (Just (T1 v)) -> Right v
  _ -> Left ("Error loading " ++ s)

getTELayer :: TensorMap -> Either String TokenEmbedding
getTELayer tm = do
  m <- getMat tm "wte.weight"
  return (TokenEmbedding (tr m))

getPELayer :: TensorMap -> Either String PositionEmbedding
getPELayer tm = do
  m <- getMat tm "wpe.weight"
  return (PositionEmbedding (tr m))

getLayerNorm :: TensorMap -> String -> Either String LayerNorm
getLayerNorm tm s = do
  w <- getVec tm (s ++ ".weight")
  b <- getVec tm (s ++ ".bias")
  return (LayerNorm w b)

getAttention :: TensorMap -> String -> Either String Attention
getAttention tm layer = do
  aw <- getMat tm (layer ++ ".attn.c_attn.weight")
  ab <- getVec tm (layer ++ ".attn.c_attn.bias")
  pw <- getMat tm (layer ++ ".attn.c_proj.weight")
  pb <- getVec tm (layer ++ ".attn.c_proj.bias")
  return (Attention (tr aw) ab (tr pw) pb)

getMLP :: TensorMap -> String -> Either String MLP
getMLP tm layer = do
  aw <- getMat tm (layer ++ ".mlp.c_fc.weight")
  ab <- getVec tm (layer ++ ".mlp.c_fc.bias")
  pw <- getMat tm (layer ++ ".mlp.c_proj.weight")
  pb <- getVec tm (layer ++ ".mlp.c_proj.bias")
  return (MLP (tr aw) ab (tr pw) pb)

getBlock :: TensorMap -> Int -> Either String Block
getBlock tm i = do
  let prefix = "h." ++ show i
  le1 <- getLayerNorm tm (prefix ++ ".ln_1")
  le2 <- getLayerNorm tm (prefix ++ ".ln_2")
  at <- getAttention tm prefix
  mp <- getMLP tm prefix
  return (Block le1 at le2 mp)

constructModel :: TensorMap -> Either String GPT
constructModel tm = do
  pe <- getPELayer tm
  te <- getTELayer tm
  -- blocks are listed from 11 down to 0 so that foldr in forwardModel applies block 0 first
  block <- mapM (getBlock tm) [11, 10 .. 0]
  ln <- getLayerNorm tm "ln_f"
  return (GPT pe te block ln)

getTensorMap :: SafeTensors -> Either String TensorMap
getTensorMap ten = mapM (bytesToTensor (binaryData ten)) (metadata ten)

parseModel :: BL.ByteString -> Either String GPT
parseModel bytes = do
  safeTensors <- parseTensors bytes
  tensorMap <- getTensorMap safeTensors
  constructModel tensorMap

readModel :: String -> IO (Either String GPT)
readModel filePath = do
  contents <- BL.readFile filePath
  return (parseModel contents)