Large Lambda Model
2025-02-15
Over the last week I decided to write the inference code for GPT-2 after a many-year hiatus from Neural Networks.
Depending on what primitives you start from, this is either quite straightforward (say, if you wrote it with JAX or PyTorch) or somewhat less so.
After lamenting the lack of a perfect tensor library in Haskell, I wrote this directly on top of the OpenBLAS bindings in hmatrix.
This choice precludes actually training the model, or even doing a single backward pass, without significant effort writing backprop code that would then be useless; but an old ThinkPad CPU is just fast enough to do the forward pass if you get your bits in order.
The lack of tensors makes the Multi-Head Attention layer a bit of a brain teaser, but it's all just GEMM/GEMV in the end, and it makes this a project that goes from daunting to slowly crystallizing into a nice solution over a few days; ideal.
Preliminaries
If you’d like to implement this yourself, the best places to start are Karpathy’s NanoGPT and llm.c, along with his YouTube videos. Also handy are Brendan Bycroft’s LLM Visualizer and a webapp hosting the tokenizer. Ok, now we begin.
We’ve Got Layers
The GPT-2 Transformer architecture is relatively simple at a high level, with only a few types of layers, arranged in a straight shot down the computation graph.
The main complexity is the attention head.
This model is fully F32 precision, so we start off by defining some type aliases.
While Haskell isn’t dependently typed, hmatrix does have a semi-undocumented interface encoding the size of the matrices/vectors in the types, but unfortunately it is not generic and would not work with F32 without replicating most of the internals, so I chose not to do that and will annotate sizes with comments instead.
type Token = Int
type V = Vector Float
type M = Matrix Float
Instead of yet another transformer tutorial blog, I will go through this through the lens of reverse engineering the model and translating it into Haskell, leaving most of the details to the many existing resources. We start, then, by examining what types of layers we must contend with and what weights lie in binary store.
Embedding Layer
The first layer is what takes us from the token into the model proper, the embedding layer; from here on out we do not see Int again until we emerge from the final logits.
In GPT-2 we have a vocabulary size of 50257 tokens, and an embedding size of 768.
For clarity we will denote N=768.
The embedding is not only with respect to the token but also its position in the sequence of tokens, which we might as well call position in time; the maximum context size is 1024 tokens.
It is important to note that while the tokens themselves are not learned, the embedding weights are.
The tokens themselves are generated with the Byte Pair Encoding algorithm, though the vocabulary size is a hyper-parameter.
newtype TokenEmbedding = TokenEmbedding M -- (N, 50257)
newtype PositionEmbedding = PositionEmbedding M -- (N, 1024)
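Concretely, the embedding of token t at position i is just the sum of two columns, wte[:, t] + wpe[:, i], which is exactly what the embedding function computes later in the forward pass.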
LayerNorm
The next component of our model is the LayerNorm.
It has a simple premise: we should normalize our data (zero mean and unit variance) at various points throughout the model.
The weights in this layer form an element-wise affine transformation, ax+b, performed after normalization.
This is similar to BatchNorm, but normalized along the layer dimension instead of the batch dimension.
Since we are only doing the forward pass and are tensor-poor, we will assume the batch dimension is one and henceforth ignore it entirely.
data LayerNorm = LayerNorm V V -- (N), (N)
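Written out, the forward pass later computes y = a * (x - mean(x)) / sqrt(var(x) + 1e-5) + b element-wise, where a and b are the two learned vectors above and 1e-5 is the usual numerical-stability epsilon.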
Multi Layer Perceptron
If you have any familiarity with ML, you will recognize this: the cheeseburger of Neural Networks. There is a linear layer represented by a matrix and its bias vector, which here scales up before the nonlinearity is applied; then another linear layer matrix and bias vector scale back down to the embedding dimension.
data MLP = MLP M V M V -- (4*N, N), (4*N), (N, 4*N), (N)
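In equation form, the forward pass later computes MLP(x) = W_proj · gelu(W_fc · x + b_fc) + b_proj, where W_fc takes the embedding from dimension N up to 4N and W_proj brings it back down.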
Attention
Inside the self-attention layer we see a linear transformation N -> 3*N, but this is really an optimization, packing the so-called Q, K, and V matrices together in memory.
We then split this further, slicing 768 into twelve vectors of length 64, one for each attention head.
The additional matrix/vector pair is for a linear layer on the end.
data Attention = Attention M V M V -- (3*N, N), (3*N), (N, N), (N)
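Per head, this is the familiar scaled dot-product attention, softmax(Q Kᵀ / sqrt(d)) V with d = 64, which is where the 1/8 scaling factor in the code later comes from; the causal mask is added before the softmax.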
Block
We group the previous layers into a Block, as we essentially stack them on top of each other, and then repeat the block twelve times, so it is convenient to conceptually group them.
data Block = Block LayerNorm Attention LayerNorm MLP
GPT
We can then assemble our layers into the complete model, with one more LayerNorm at the end for good measure. Now we are ready to ask how these layers are actually implemented.
data GPT = GPT
  { wpe :: PositionEmbedding,
    wte :: TokenEmbedding,
    blocks :: [Block], -- (12)
    lnf :: LayerNorm
  }
Interlude: Necessary Functions
Before getting into the forward pass, let us define some helper functions. These are things that any modern tensor library would give you, but we will implement them ourselves. Surprisingly few are necessary.
Softmax
This is the venerable softmax and nothing more; it smoothly turns our vectors into probability distributions.
softmax :: V -> V
softmax v = expv * scalar (1 / sumElements expv)
  where
    expv = cmap exp v
GELU
The popular choice of nonlinearity at the time was the Gaussian Error Linear Unit, a smoother adaptation of the ReLU that avoids getting stuck in the flat region during training.
Technically, we are using the tanh approximation of the GELU, which is defined as GELU(x) = x·Φ(x), where Φ is the CDF of the standard Gaussian.
It seems like the “exact” version is now performant in PyTorch, but the approximation is close enough it doesn’t seem to matter which you use for a single forward pass.
gelu :: V -> V
gelu x = 0.5 * x * (1 + tanh (sqrt (2 / pi) * (x + 0.044715 * x * x * x)))
Tril
This function zeros out the upper triangular portion of the self attention matrix.
To be exact, it sets them to -Inf, which becomes zero after the softmax is applied.
The attention matrix encodes the relation between different token positions, and this zeroing corresponds to a token only depending on previous tokens.
Much research has been done on the alterations to this matrix, which in theory is completely general and can be put to various purposes.
tril :: Int -> M
tril n = build (n, n) (\i j -> if j > i then -1 / 0 else 0)
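For example, tril 3 builds the mask for a three-token context, so position i can only attend to positions j <= i:

 0.0  -Inf  -Inf
 0.0   0.0  -Inf
 0.0   0.0   0.0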
Forward Pass
Let’s start by defining a typeclass for our layers, containing the function for the forward pass. This code doesn’t actually generalize, but it’s comfy to do this regardless.
class Layer a where
forward :: a -> [V] -> [V]
Embedding
We then come to the embedding layer, which does not conform to the typeclass we so hopefully just defined… The important point to note is that the embedding is across two dimensions, the token vocabulary and the token position in time. As we do not have a tensor library, it is convenient to store this as a list of vectors, the size of which cannot grow beyond the context size of 1024, so this should cause no issues. Each element of the list is the embedding of an individual token.
-- the model combines a token indexed and position indexed embedding
embedding :: TokenEmbedding -> PositionEmbedding -> [Token] -> [V]
embedding (TokenEmbedding te) (PositionEmbedding pe) ts =
  zipWith (+) (fmap (sliceColumn te) ts) (toColumns pe)
LayerNorm
As promised, this is just a normalization followed by an affine transformation.
The notorious difficulty in implementing LayerNorm and BatchNorm mostly comes down to the backward pass, which we are ignoring.
Note that this is an fmap over the input [V], meaning each token embedding is normalized independently.
instance Layer LayerNorm where
  forward layer = fmap (forwardLN layer)
    where
      forwardLN :: LayerNorm -> V -> V
      forwardLN (LayerNorm w b) x = y
        where
          n = fromIntegral (size x)
          mean = scalar (sumElements x / n)
          cent = x - mean
          varx = sumElements (cent * cent) / n
          fact = scalar (sqrt (varx + 1e-5))
          y = ((x - mean) / fact) * w + b
Attention
We break this up into three parts. First, we apply the QKV linear transformation and break up the result into the individual Q, K, V components, and into the 12 individual heads. Second, we reassemble across the time dimension, so that we can construct the attention matrix for each head, each relating all tokens in time. Third, we flatten everything back out and apply another linear layer, ending back in the same shape we started with.
This splitting and recombining corresponds to reshaping the tensor such that the heads are their own dimension, and then transposing it with the time (token) dimension. We do not have this capability, so we must make do, and this is the trickiest part of the code by far.
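To keep the shapes straight: with T tokens in the context, attnAtToken turns each length-768 token vector into twelve (Q, K, V) triples of length-64 vectors; transposing and applying fromColumns yields one (64, T) matrix per head and per component; attnHead builds a (T, T) attention matrix per head and returns a (T, 64) result; and the final recombination joins the twelve heads back into T vectors of length 768 before the output projection.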
-- the first part of the attention head is a linear layer.
-- Q,K,V weights and heads are combined and we have to take them apart here.
attnAtToken :: Attention -> V -> ([V], [V], [V])
attnAtToken (Attention w b _ _) x = (qh, kh, vh)
  where
    y = (w #> x) + b
    -- split apart into Q, K, V components
    (q, k, v) = case takesV [768, 768, 768] y of
      [x1, x2, x3] -> (x1, x2, x3)
      _ -> error "QKV could not be split"
    -- split into individual heads
    qh = takesV (replicate 12 64) q
    kh = takesV (replicate 12 64) k
    vh = takesV (replicate 12 64) v
-- this is the actual attention part where we construct the attention matrix.
attnHead :: (M, M, M) -> M
attnHead (q, k, v) = z
  where
    attnMatrix = tr q <> k * scalar (1 / 8) -- 1 / sqrt (size k)
    -- mask the upper right triangular to -inf (becomes 0 in softmax)
    attnMasked = tril (rows attnMatrix) + attnMatrix
    -- no tensor library means we have to do this kinda stuff
    attnSoftmax = fromRows (fmap softmax (toRows attnMasked))
    z = attnSoftmax <> tr v
instance Layer Attention where
  forward at@(Attention _ _ w b) xs = z
    where
      (q, k, v) = unzip3 (fmap (attnAtToken at) xs)
      qh = fmap fromColumns (transpose q)
      kh = fmap fromColumns (transpose k)
      vh = fmap fromColumns (transpose v)
      lm = fmap attnHead (zip3 qh kh vh)
      y = fmap vjoin (transpose (fmap toRows lm))
      z = fmap ((+ b) . (w #>)) y
Multi Layer Perceptron
Now we are back to classical Neural Networks, and it feels easy in comparison.
instance Layer MLP where
  forward (MLP wfc bfc wproj bproj) x = x3
    where
      x1 = fmap ((+ bfc) . (wfc #>)) x
      x2 = fmap gelu x1
      x3 = fmap ((+ bproj) . (wproj #>)) x2
Block Layer
Finally we can assemble the Block. Here there is only one thing of note, the pass-through, usually called a residual or skip connection (as in ResNet), a trick that was discovered when looking for ways to successfully train deeper networks.
instance Layer Block where
  forward (Block l1 at l2 mp) xs = x4
    where
      x1 = forward l1 xs
      x2 = zipWith (+) xs (forward at x1)
      x3 = forward l2 x2
      x4 = zipWith (+) x2 (forward mp x3)
GPT
Putting it all together now: we embed, apply the blocks in sequence, do just one more LayerNorm, and then apply the token embedding to produce logits, which we use to sample the next token in the sequence. Since we are only doing the forward pass, there is of course no cross-entropy loss at the end.
forwardModel :: GPT -> [Token] -> [V]
forwardModel model tokens = x3
  where
    TokenEmbedding wtew = wte model
    emb = embedding (wte model) (wpe model) tokens
    x1 = foldr forward emb (blocks model)
    x2 = forward (lnf model) x1
    x3 = fmap (tr wtew #>) x2
The main trick in implementing such a thing is taking apart the reference implementation and inspecting it every single step of the way, the standard mechanical acts of reverse engineering.
They Call Me The Sampler
To actually get something useful from the model, we must take its output predictions and sample from them. This is another scenario where the lack of a surrounding ecosystem in Haskell leaves us to our own devices. Luckily you can make a usable sampler out of leftover bits, and I will show you how.
Maximum Sampler
There is of course the cop-out sampler: simply take the highest-scoring token at every step. The results are quite bad; in fact it is rather instructive as to the importance of a good sampler, though it does work to test that your model is working at all. We lift this into the IO monad only to be consistent with the following sampler, which uses Random IO.
sampleMax :: V -> IO Token
sampleMax x = return $ snd $ maximumBy (comparing fst) (zip (toList x) [0 ..])
Top-K Uniform Sampler
We will use the approach given in the reference implementation, to limit ourselves to the top K (they use 200, I chose 50) values and sample them with their softmax probabilities.
How do we sample from this list of probabilities?
There is a nice way of doing just this: since the probabilities must sum to one, we can associate each probability with a disjoint interval contained in (0,1). The order is unimportant; we take the order we are given and construct the cumulative probabilities, which correspond to the right endpoints of these intervals. We can then take a uniform random sample from the unit interval and find which interval it lands in via the cumulative probabilities; the result corresponds to a sample from our original distribution.
Neat!
topK :: V -> V
topK v = fromList (map f (toList v))
  where
    k = sortBy (comparing Down) (toList v) !! 50
    f x = if x > k then x / 2 else -1 / 0 -- here 2 is the "temperature"
sampleLogits :: V -> IO Int
sampleLogits v = do
  r <- randomRIO (0.0, 1.0)
  return $ findIndex r (scanl1 (+) (toList $ softmax $ topK v))
  where
    findIndex r cumProbs = length (takeWhile (<= r) cumProbs)
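As a quick sanity check with made-up numbers: if the softmax over the surviving logits gives [0.2, 0.5, 0.3], the cumulative sums are [0.2, 0.7, 1.0]; a uniform draw of r = 0.6 exceeds 0.2 but not 0.7, so we land on index 1, which is selected with probability 0.7 - 0.2 = 0.5, exactly as desired.

ghci> length (takeWhile (<= 0.6) (scanl1 (+) [0.2, 0.5, 0.3]))
1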
We are now ready to start generating some fresh tokens.
Running The Model
Using all the pieces we’ve assembled so far, we can run the model with some straightforward IO event-loop code.
run :: GPT -> TokenMap -> Natural -> [Token] -> IO [Token]
run model tm iter tokens = do
  next <- sampleLogits $ last $ forwardModel model tokens
  TIO.putStr (token tm [next])
  if iter == 0 || next == 50256
    -- this is short enough that end of list append is fine
    then return (tokens ++ [next])
    else run model tm (iter - 1) (tokens ++ [next])
main :: IO ()
main = do
  hSetBuffering stdout NoBuffering
  putStrLn "λλμ: Now Loading..."
  tensors <- readModel "model.safetensors"
  vocab <- readVocab "vocab.json"
  let model = case tensors of Right gpt -> gpt; Left err -> error err
  let tokenMap = case vocab of Just tm -> tm; Nothing -> error "Couldn't parse vocab"
  putStr "Hello, I am"
  -- Tokens for "Hello, I am"
  generate <- run model tokenMap 50 [15496, 11, 314, 716]
  print generate
Here is an example output, with typical small model weirdness. Note this is the smallest 124M parameter GPT-2 model.
λλμ: Now Loading...
Hello, I am not quite sure what this "Theater at Heart of Harry Potter" book of essays to which all children are prone must say to all adults; they, the characters in whom the books "hates him."
That’s it, we wrote a forward pass for GPT-2!
But wait, you might say, where did that readModel function come from, or the token function?
For the tokens, I am simply using the vocab.json file provided with the model weights.
This is not handled correctly, possibly due to Aeson disagreeing with the encoding specifics of the Unicode keys in the JSON, so I will not include it here.
I did not even attempt the token encoder.
Nobody likes the tokenizer!
For the model loading, I chose to parse the model.safetensors format that Hugging Face provides.
The details are tedious, so they have been relegated to the appendix.
Performance Considerations
The main performance issue I had was really my own blunder: in my haste to prototype the inference I neglected the loader.
hmatrix does not come with a way to load a vector directly from a ByteString, so we must do some work with the lower-level memory interfaces.
If one wishes to attempt this themselves, the critical point is to map the vectors directly, else suffer the consequences of parsing through intermediaries.
Using the available tools in hmatrix and Data.Binary.Get, the obvious solution is to parse the bytestring into a list of floats and then into a vector.
This is incredibly slow.
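For contrast, a minimal sketch of that intermediary route might look like the following, assuming binary's getFloatle, Control.Monad's replicateM, and hmatrix's fromList (slowLoadVector is a hypothetical helper, not part of the actual loader):

slowLoadVector :: Int -> BL.ByteString -> V
-- parse every float through a Get parser and an intermediate list
slowLoadVector n = fromList . runGet (replicateM n getFloatle)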
Luckily the FFI in Haskell is quite nice, and we can index into the (strict!) bytestring with a pointer that can then be cast to the FFI Storable used by Vectors, without additional allocation.
This gets the loading time down to a few seconds.
In terms of inference performance, hmatrix does an admirable job, and BLAS parallelizes well enough to saturate all 8 of my cores without needing something like parMap.
The main slowdown is the quadratic scaling of self attention, a fundamental time complexity issue that can be somewhat improved by things like FlashAttention and custom kernels.
I’ll do none of those things here.
I’m not sure there is much benefit in trying to optimize this further, as the hmatrix primitives are not really the right foundation for this work; something closer to an array combinator DSL like accelerate or futhark would be a better direction, though the various options all have their drawbacks.
There is also the question of training, and we would need to think about something like backprop.
Appendix: Data Loading
For the curious, I’ve included the full loading code.
The safetensors format is quite simple: a leading uint64 encodes the metadata length, followed by said metadata, which is just JSON.
The remainder of the file is the tensors in binary form.
This JSON contains a manifest of each tensor and its byte offsets within the file, which we can use to load them.
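The entry for "wte.weight", for instance, records its dtype ("F32"), its shape ([50257, 768], transposed relative to our convention, hence the tr in getTELayer below), and the pair of byte offsets delimiting its slice of the binary section.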
module Loader where
import Data.Aeson (FromJSON, ToJSON, Value, eitherDecode, withObject, (.:))
import Data.Aeson.Encode.Pretty (encodePretty)
import qualified Data.Aeson.Key as K
import qualified Data.Aeson.KeyMap as KM
import Data.Aeson.Types (Parser, parseEither)
import Data.Bifunctor (bimap)
import Data.Binary.Get (getWord64le, runGetOrFail)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BS
import qualified Data.ByteString.Lazy as BL
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Storable as VS
import Data.Word (Word64)
import Foreign.ForeignPtr (castForeignPtr)
import Foreign.Storable (Storable, sizeOf)
import GHC.Generics (Generic)
import Model
import Numeric.LinearAlgebra (reshape, tr)
import Prelude hiding ((<>))
-- simple sum type so we can load either vec or mat
-- I could probably use the generic Container from hmatrix but this is easy
data Tensor = T1 V | T2 M
-- generate a keymap based on the safetensor metadata
type TensorMap = KM.KeyMap Tensor
-- metadata for an individual tensor (safetensor format)
data TensorMetadata = TensorMetadata
  { dtype :: String,
    shape :: [Int],
    dataOffsets :: (Int, Int)
  }
  deriving (Show, Generic, FromJSON, ToJSON)
-- entire safetensors file including unmapped raw tensor data
data SafeTensors = SafeTensors
  { metadata :: KM.KeyMap TensorMetadata,
    binaryData :: BS.ByteString
  }
-- we don't want to show the binary data, might as well have a pretty printer
instance Show SafeTensors where
show safetensors = show $ encodePretty (metadata safetensors)
-- Parse tensor metadata from JSON segment of file
parseTensorMetadata :: Value -> Parser TensorMetadata
parseTensorMetadata = withObject "TensorMetadata" $ \obj -> do
  mdtype <- obj .: "dtype"
  mshape <- obj .: "shape"
  (i, j) <- obj .: "data_offsets"
  return
    ( TensorMetadata
        { shape = mshape,
          dataOffsets = (i, j),
          dtype = mdtype
        }
    )
parseTensors :: BL.ByteString -> Either String SafeTensors
parseTensors bs = do
  -- the first 8 bytes are an uint specifying length of JSON segment
  numBytes <- parseWord64 (BL.take 8 bs)
  -- the next N bytes can be decoded directly with aeson
  obj <- eitherDecode (BL.take (fromIntegral numBytes) (BL.drop 8 bs))
  -- this is the one key that isn't a tensor, easiest just to remove it
  let tensors = KM.delete (K.fromString "__metadata__") obj
  -- parse tensor metadata objects into our metadata type
  x <- mapM (parseEither parseTensorMetadata) tensors
  -- return metadata keymap along with remaining raw bytes containing tensor data
  return (SafeTensors x (BS.toStrict (BL.drop (8 + fromIntegral numBytes) bs)))
-- parse a Word64 from the head of the file (encodes length of JSON segment)
parseWord64 :: BL.ByteString -> Either String Word64
parseWord64 bs = case runGetOrFail getWord64le bs of
  Right (_, _, w) -> Right w
  Left (_, _, s) -> Left ("Error reading leading uint64: " ++ s)
-- https://stackoverflow.com/questions/18682527/how-to-convert-between-bytestring-and-storable-vector
byteStringToVector :: (Storable a) => BS.ByteString -> VS.Vector a
byteStringToVector bs = vec
  where
    vec = VS.unsafeFromForeignPtr (castForeignPtr fptr) (scale off) (scale len)
    (fptr, off, len) = BS.toForeignPtr bs
    scale = (`div` sizeOfElem vec)
    sizeOfElem vect = sizeOf (undefined `asTypeOf` VS.head vect)
bytesToTensor :: BS.ByteString -> TensorMetadata -> Either String Tensor
bytesToTensor bs meta = case shape meta of
  [n] -> if VG.length vec == n then Right (T1 vec) else errmsg
  [n, m] -> if VG.length vec == n * m then Right (T2 (reshape m vec)) else errmsg
  [1, 1, n, m] -> if VG.length vec == n * m then Right (T2 (reshape m vec)) else errmsg
  _ -> errmsg
  where
    (startpos, endpos) = bimap fromIntegral fromIntegral (dataOffsets meta)
    errmsg = Left ("Wrong size while reading " ++ show meta)
    -- it would maybe be better to load them "in order" with splitAt but
    -- the loading is fast enough with this now that the BS is cast directly
    vec = byteStringToVector (BS.drop startpos (BS.take endpos bs))
vec
-- getting layer weights is straight forward. some matrices need to be transposed.
getMat :: TensorMap -> String -> Either String M
getMat tm s = case KM.lookup (K.fromString s) tm of
  (Just (T2 m)) -> Right m
  _ -> Left ("Error loading " ++ s)

getVec :: TensorMap -> String -> Either String V
getVec tm s = case KM.lookup (K.fromString s) tm of
  (Just (T1 v)) -> Right v
  _ -> Left ("Error loading " ++ s)
_
getTELayer :: TensorMap -> Either String TokenEmbedding
getTELayer tm = do
  m <- getMat tm "wte.weight"
  return (TokenEmbedding (tr m))

getPELayer :: TensorMap -> Either String PositionEmbedding
getPELayer tm = do
  m <- getMat tm "wpe.weight"
  return (PositionEmbedding (tr m))

getLayerNorm :: TensorMap -> String -> Either String LayerNorm
getLayerNorm tm s = do
  w <- getVec tm (s ++ ".weight")
  b <- getVec tm (s ++ ".bias")
  return (LayerNorm w b)

getAttention :: TensorMap -> String -> Either String Attention
getAttention tm layer = do
  aw <- getMat tm (layer ++ ".attn.c_attn.weight")
  ab <- getVec tm (layer ++ ".attn.c_attn.bias")
  pw <- getMat tm (layer ++ ".attn.c_proj.weight")
  pb <- getVec tm (layer ++ ".attn.c_proj.bias")
  return (Attention (tr aw) ab (tr pw) pb)

getMLP :: TensorMap -> String -> Either String MLP
getMLP tm layer = do
  aw <- getMat tm (layer ++ ".mlp.c_fc.weight")
  ab <- getVec tm (layer ++ ".mlp.c_fc.bias")
  pw <- getMat tm (layer ++ ".mlp.c_proj.weight")
  pb <- getVec tm (layer ++ ".mlp.c_proj.bias")
  return (MLP (tr aw) ab (tr pw) pb)

getBlock :: TensorMap -> Int -> Either String Block
getBlock tm i = do
  let prefix = "h." ++ show i
  le1 <- getLayerNorm tm (prefix ++ ".ln_1")
  le2 <- getLayerNorm tm (prefix ++ ".ln_2")
  at <- getAttention tm prefix
  mp <- getMLP tm prefix
  return (Block le1 at le2 mp)

constructModel :: TensorMap -> Either String GPT
constructModel tm = do
  pe <- getPELayer tm
  te <- getTELayer tm
  block <- mapM (getBlock tm) [11, 10 .. 0]
  ln <- getLayerNorm tm "ln_f"
  return (GPT pe te block ln)

getTensorMap :: SafeTensors -> Either String TensorMap
getTensorMap ten = mapM (bytesToTensor (binaryData ten)) (metadata ten)

parseModel :: BL.ByteString -> Either String GPT
parseModel bytes = do
  safeTensors <- parseTensors bytes
  tensorMap <- getTensorMap safeTensors
  constructModel tensorMap

readModel :: String -> IO (Either String GPT)
readModel filePath = do
  contents <- BL.readFile filePath
  return (parseModel contents)