36 The Baum-Welch Algorithm
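```clojure
(def categories '(N V Adj Adv P stop))
(def vocabulary '(Call me Ishmael))

;; log2 of a sum of probabilities given as log2 values,
;; computed stably by factoring out the largest one.
(defn logsumexp [& log-vals]
  (let [mx (apply max log-vals)]
    (+ mx
       (Math/log2
        (apply +
               (map (fn [z] (Math/pow 2 z))
                    (map (fn [x] (- x mx))
                         log-vals)))))))

;; Return true with probability p.
(defn flip [p]
  (< (rand 1) p))

;; Rescale nonnegative weights so that they sum to 1.
;; Defined before sample-categorical, which calls it.
(defn normalize [params]
  (let [sum (apply + params)]
    (map (fn [x] (/ x sum)) params)))

;; Sample from a categorical distribution by flipping a coin for each
;; outcome in turn, renormalizing the remaining probabilities each time.
(defn sample-categorical [outcomes params]
  (if (flip (first params))
    (first outcomes)
    (sample-categorical (rest outcomes)
                        (normalize (rest params)))))

;; log2 probability of outcome under a categorical distribution.
(defn score-categorical [outcome outcomes params]
  (if (empty? params)
    (throw (ex-info "no matching outcome" {:outcome outcome}))
    (if (= outcome (first outcomes))
      (Math/log2 (first params))
      (score-categorical outcome (rest outcomes) (rest params)))))

;; Gamma(shape, scale) for a positive integer shape: the sum of shape
;; exponential draws -log(U), with U uniform on (0, 1], times scale.
(defn sample-gamma [shape scale]
  (* scale
     (apply + (repeatedly shape
                          (fn [] (- (Math/log (- 1 (rand)))))))))

;; Sample from a Dirichlet by normalizing independent Gamma(pseudo-count, 1) draws.
(defn sample-dirichlet [pseudos]
  (let [gammas (map (fn [sh] (sample-gamma sh 1))
                    pseudos)]
    (normalize gammas)))

;; Slide new-symbol into the context, dropping the oldest symbol once
;; the context already holds order - 1 symbols.
(defn update-context [order old-context new-symbol]
  (if (>= (count old-context) order)
    (throw (ex-info "Context too long!" {:context old-context}))
    (if (= (count old-context) (- order 1))
      (concat (rest old-context) (list new-symbol))
      (concat old-context (list new-symbol)))))

;; Generate [category word] pairs, ending with the stop category itself.
(defn hmm-unfold [transition observation order context current stop?]
  (if (stop? current)
    (list current)
    (let [new-context (update-context order context current)
          nxt (transition new-context)]
      (cons [current (observation current)]
            (hmm-unfold transition observation order new-context nxt stop?)))))

;; All elements of a nonempty list except the last.
(defn all-but-last [l]
  (cond (empty? l) (throw (ex-info "all-but-last: empty list" {}))
        (empty? (rest l)) '()
        :else (cons (first l) (all-but-last (rest l)))))
```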
We first consider the maximum likelihood solution to the HMM parameter estimation problem. The problem is to maximize the probability of the data given the parameters of our HMM.
\[\DeclareMathOperator*{\argmax}{arg\,max} \underset {\boldsymbol{\theta}_{O}, \boldsymbol{\theta}_{T} }\argmax \log \prod_{\vec{w} \in \mathbf{C} } \sum_{\vec{c}} \prod_{i=1}^{|\vec{w}|} \Pr(c^{(i)} | c^{(i-1)}, \vec{\theta}_{T,c^{(i-1)}}) \Pr(w^{(i)} | c^{(i)}, \vec{\theta}_{O,c^{(i)}})\]
The difficulty is the sum over derivations \(\vec{c}\) inside the logarithm. Because the log cannot be pushed through that sum, the objective does not break apart into separate terms for each parameter, so we cannot solve for the maximum in closed form the way we can for a single categorical distribution. What are we to do?
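To see the sum over derivations concretely, here is a small brute-force sketch. It assumes a hypothetical two-category HMM with made-up probabilities, an explicit start distribution in place of a start category, and no stop category; the helper names (`all-derivations`, `joint-prob`, `marginal-prob`, `corpus-log-likelihood`) are illustrative and not part of the chapter's code. It works in probability space and uses the natural log rather than log2, which does not change the argmax.

```clojure
;; Hypothetical two-category HMM; every number is made up for illustration.
(def states [:N :V])
(def start {:N 0.6 :V 0.4})                ; Pr(c1), replaces a start category
(def trans {:N {:N 0.3 :V 0.7}             ; Pr(c' | c)
            :V {:N 0.8 :V 0.2}})
(def emit  {:N {"fish" 0.6 "swim" 0.4}     ; Pr(w | c)
            :V {"fish" 0.3 "swim" 0.7}})

(defn all-derivations
  "Every sequence of n categories: |states|^n of them."
  [n]
  (if (zero? n)
    [[]]
    (for [prefix (all-derivations (dec n))
          c states]
      (conj prefix c))))

(defn joint-prob
  "Pr(words, cats): one transition and one observation factor per position."
  [words cats]
  (reduce * (for [i (range (count words))]
              (let [c (nth cats i)]
                (* (if (zero? i)
                     (start c)
                     (get-in trans [(nth cats (dec i)) c]))
                   (get-in emit [c (nth words i)]))))))

(defn marginal-prob
  "Pr(words): the sum over derivations that blocks a closed-form argmax."
  [words]
  (reduce + (map #(joint-prob words %) (all-derivations (count words)))))

(defn corpus-log-likelihood
  "The objective above for a list of sentences, using the natural log."
  [corpus]
  (reduce + (map #(Math/log (marginal-prob %)) corpus)))
```

Enumeration is only feasible for tiny examples: with the six categories defined at the top of this chapter, a ten-word sentence already has \(6^{10}\), about 60 million, derivations.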
One insight comes from considering what we could do if we knew the “true” derivations of individual sentences. This setting is called the complete data problem, and we consider it first.
The Complete Data Problem
Recall from Hidden Markov Models that it is easy to find the joint probability of a corpus of sentences together with a particular derivation for each sentence. In maximum likelihood estimation, estimating a set of parameters when we know the values of both the observed random variables and the (normally latent) random variables is called the complete data estimation problem. In this view, the derivations are treated as data that just happen to be missing.
If we know the derivations, then our optimization criterion reduces to the following expression.
\[\underset {\boldsymbol{\theta}_{O}, \boldsymbol{\theta}_{T} }\argmax \Pr(\mathbf{C}, \mathbf{D} \mid \boldsymbol{\theta}_T, \boldsymbol{\theta}_O)\]
which can be expanded as
\[\DeclareMathOperator*{\argmax}{arg\,max} \underset {\boldsymbol{\theta}_{O}, \boldsymbol{\theta}_{T} }\argmax \log \prod_{(\vec{w}, \vec{c}) \in (\mathbf{C}, \mathbf{D})} \prod_{i=1}^{|\vec{w}|} \Pr(c^{(i)} | c^{(i-1)}, \vec{\theta}_{T,c^{(i-1)}}) \Pr(w^{(i)} | c^{(i)}, \vec{\theta}_{O,c^{(i)}})\]
Since all of the distributions in an HMM are categorical, the complete data likelihood, that is, the joint probability of the corpus and the derivations, can be written in terms of counts. Let \(n_{w \mid c}\) be the number of times category \(c\) emits word \(w\) and \(n_{c \rightarrow c^\prime}\) the number of transitions from \(c\) to \(c^\prime\), counted across all of \(\mathbf{C}\) and \(\mathbf{D}\). The optimization problem becomes
\[\underset {\boldsymbol{\theta}_{O}, \boldsymbol{\theta}_{T} }\argmax \prod_{c \in \mathcal{S}, w \in V} [\theta_{ w \mid c} ]^{n_{w \mid c}} \prod_{c, c^\prime \in \mathcal{S}} [\theta_{c \rightarrow c^\prime}]^{n_{c \rightarrow c^\prime}}\]
This is an optimization problem of the familiar form we saw in Maximum Likelihood. Following the logic of that unit, the optimal value of each parameter is its renormalized complete-data count. For example, for a transition probability:
\[\hat{\theta}_{c \rightarrow c^\prime} = \frac{n_{c \rightarrow c^\prime}}{\sum_{c^{\prime\prime} \in \mathcal{S}} n_{c \rightarrow c^{\prime\prime}}}\]
Thus, if we knew the derivation of each sentence, estimating the parameters of the model would be trivial. This suggests an idea: perhaps we can alternate between optimizing the values of the missing items (i.e., the derivations) and optimizing the parameters.
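A minimal sketch of complete-data estimation for the transition probabilities, assuming the derivations are simply given as vectors of categories. The corpus and function names are hypothetical. Counting adjacent category pairs and dividing each count by the total for its first category computes \(\hat{\theta}_{c \rightarrow c^\prime}\) directly.

```clojure
;; Hypothetical derivations D, assumed known in the complete data setting.
(def derivations [[:N :V :N]
                  [:N :N :V]])

(defn transition-counts
  "n_{c -> c'}: how often category c is immediately followed by c'."
  [derivs]
  (frequencies (mapcat #(map vec (partition 2 1 %)) derivs)))

(defn mle-transitions
  "theta_{c -> c'} = n_{c -> c'} / sum over c'' of n_{c -> c''}."
  [derivs]
  (let [counts (transition-counts derivs)
        ;; total number of transitions out of each category c
        totals (reduce (fn [acc [[c _] n]] (update acc c (fnil + 0) n))
                       {}
                       counts)]
    (into {} (for [[[c c2] n] counts]
               [[c c2] (/ (double n) (totals c))]))))
```

Emission probabilities are estimated the same way, counting `[c w]` pairs instead of adjacent category pairs.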
The Incomplete Data Estimation Problem: First Pass
The core idea behind the Baum-Welch algorithm is to alternate between optimizing \(\boldsymbol{\theta}_{T}\) and \(\boldsymbol{\theta}_{O}\) on the one hand and \(\mathbf{D}\) on the other. As a first pass, we consider doing this alternation directly, committing to a single best derivation for each sentence. First, we fix \(\boldsymbol{\theta}^{\mathrm{old}}_{T}\) and \(\boldsymbol{\theta}^{\mathrm{old}}_{O}\) (perhaps initializing them randomly or uniformly) and compute the optimal value for \(\mathbf{D}\).
\[\DeclareMathOperator*{\argmax}{arg\,max} \underset {\mathbf{D} }\argmax \log \prod_{(\vec{w}, \vec{c}) \in (\mathbf{C}, \mathbf{D})} \prod_{i=1}^{|\vec{w}|} \Pr(c^{(i)} | c^{(i-1)}, \vec{\theta}^{\mathrm{old}}_{T,c^{(i-1)}}) \Pr(w^{(i)} | c^{(i)}, \vec{\theta}^{\mathrm{old}}_{O,c^{(i)}})\]
Because the derivations of different sentences do not interact, this amounts to finding the single best sequence of categories \(\vec{c}\) for each sentence in the corpus. With the parameters fixed, this can be done using the Viterbi algorithm.
Next, we update our parameters using the discovered values of \(\hat{\mathbf{D}}\).
\[\DeclareMathOperator*{\argmax}{arg\,max} \underset {\boldsymbol{\theta}_{O}, \boldsymbol{\theta}_{T} }\argmax \log \prod_{(\vec{w}, \vec{\hat{c}}) \in (\mathbf{C}, \hat{\mathbf{D}})} \prod_{i=1}^{|\vec{w}|} \Pr(\hat{c}^{(i)} | \hat{c}^{(i-1)}, \vec{\theta}_{T,\hat{c}^{(i-1)}}) \Pr(w^{(i)} | \hat{c}^{(i)}, \vec{\theta}_{O,\hat{c}^{(i)}})\]
Now that our data have been “completed,” this is exactly the complete data problem, so the new parameters are just the renormalized counts \(\hat{n}\) computed from \(\hat{\mathbf{D}}\).
\[\theta^{\mathrm{new}}_{c \rightarrow c^\prime} = \frac{\hat{n}_{c \rightarrow c^\prime}}{\sum_{c^{\prime\prime} \in \mathcal{S}} \hat{n}_{c \rightarrow c^{\prime\prime}}}\]
We repeat this process, alternately updating the derivations and the parameters, until some stopping criterion is reached, for example, until the derivations no longer change.
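The first-pass procedure above is often called hard EM or Viterbi training. The sketch below performs one iteration on a hypothetical two-category HMM with made-up starting parameters: it decodes each sentence with a small Viterbi search (in probability space, so only suitable for short toy sentences), then replaces the start, transition, and emission distributions with renormalized counts. Transition and emission rows for categories that never appear in the decoded derivations keep their old values; that policy, the parameter layout, and all function names are assumptions of this sketch.

```clojure
;; Hypothetical starting parameters theta^old; all numbers are made up.
(def params
  {:states [:N :V]
   :start  {:N 0.6 :V 0.4}
   :trans  {:N {:N 0.3 :V 0.7} :V {:N 0.8 :V 0.2}}
   :emit   {:N {"fish" 0.6 "swim" 0.4} :V {"fish" 0.3 "swim" 0.7}}})

(def corpus [["fish" "swim"] ["swim" "fish"]])

(defn viterbi
  "Most probable category sequence for words. Works in probability
  space, so it is only suitable for short toy sentences."
  [{:keys [states start trans emit]} words]
  (let [t (fn [a b] (get-in trans [a b] 0.0))   ; missing entries count as 0
        e (fn [c w] (get-in emit [c w] 0.0))
        ;; best maps each c to {:p score of the best path ending in c, :path that path}
        init (into {} (for [c states]
                        [c {:p (* (get start c 0.0) (e c (first words)))
                            :path [c]}]))
        step (fn [best w]
               (into {} (for [c states]
                          (let [prev (apply max-key #(* (:p (best %)) (t % c)) states)]
                            [c {:p (* (:p (best prev)) (t prev c) (e c w))
                                :path (conj (:path (best prev)) c)}]))))]
    (:path (apply max-key :p (vals (reduce step init (rest words)))))))

(defn renormalize
  "Turn {k count} into {k probability}."
  [counts]
  (let [total (reduce + (vals counts))]
    (into {} (for [[k n] counts] [k (/ (double n) total)]))))

(defn rows
  "Regroup {[a b] n} as {a {b n}}."
  [pair-counts]
  (reduce (fn [acc [[a b] n]] (assoc-in acc [a b] n)) {} pair-counts))

(defn hard-em-step
  "One iteration: Viterbi-decode every sentence, then re-estimate the
  start, transition, and emission distributions from the decoded counts.
  Transition and emission rows for unseen categories are left unchanged."
  [params corpus]
  (let [derivs       (map #(viterbi params %) corpus)
        trans-counts (frequencies (mapcat #(map vec (partition 2 1 %)) derivs))
        emit-counts  (frequencies (mapcat #(map vector %1 %2) derivs corpus))
        update-rows  (fn [old counts]
                       (merge old (into {} (for [[c row] (rows counts)]
                                             [c (renormalize row)]))))]
    (-> params
        (assoc :start (renormalize (frequencies (map first derivs))))
        (update :trans update-rows trans-counts)
        (update :emit update-rows emit-counts))))
```

Calling `hard-em-step` repeatedly until the decoded derivations stop changing gives the full first-pass algorithm.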
Using Expectations: The Baum-Welch Algorithm
Iterating between optimizing the derivations and optimizing the parameters in this way provides no guarantee that we will find the optimal values for the parameters. If we initialize poorly, or follow a bad path through the optimization space, we can end up stuck at a local optimum. One reason for this is that we use only a single estimate \(\hat{\mathbf{D}}\) of the derivations. If our parameters lead us to choose a bad set of derivations, we will re-estimate the parameters from those bad derivations, and the next round of derivations will tend to repeat the same mistakes.
One way to mitigate this risk is to hedge our guesses about the derivations by using a distribution over derivations instead of a single one. Recall that complete data allows us to compute the count of each transition and observation, so that we can do maximum likelihood estimation on the result. The idea behind the Baum-Welch algorithm is to base our updates on the expected counts instead.
\[\theta^{\mathrm{new}}_{c \rightarrow c^\prime} = \frac{\mathbb{E}[n_{c \rightarrow c^\prime}]}{\sum_{c^{\prime\prime} \in \mathcal{S}} \mathbb{E}[n_{c \rightarrow c^{\prime\prime}}]}\]
These expectations are taken with respect to some distribution over derivations. Which distribution should we use? The most natural choice is the posterior distribution \(\Pr(\mathbf{D} \mid \mathbf{C}, \boldsymbol{\theta}_{O}, \boldsymbol{\theta}_{T})\) under the current parameter values.
Luckily, we have already studied how to compute this posterior efficiently. How can we compute, for example, the expected count of the transition \(c_{1} \rightarrow c_{2}\)? We want the total number of times this transition occurs in all derivations of all strings in our corpus, with each occurrence weighted by the posterior probability of the derivation it occurs in. Another way to think about this: if we were absolutely certain which derivation was the right one, as in the complete data estimation problem, each occurrence of a transition would receive a count of \(1\). Since we are unsure of the derivation, we don't give each occurrence a full count of \(1\) but instead count it fractionally according to its posterior probability, a number between \(0\) and \(1\) in general.
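The fractional-count idea can be checked by brute force on a tiny example. This sketch assumes a hypothetical two-category HMM with made-up numbers, an explicit start distribution, and no stop category; `all-derivations`, `joint-prob`, and `expected-transition-counts` are illustrative names. It enumerates every derivation, computes its posterior \(\Pr(\vec{c} \mid \vec{w})\) as the joint divided by the marginal, and adds that posterior, rather than \(1\), for each transition the derivation contains.

```clojure
;; Hypothetical toy HMM; every number is made up for illustration.
(def states [:N :V])
(def start {:N 0.6 :V 0.4})
(def trans {:N {:N 0.3 :V 0.7} :V {:N 0.8 :V 0.2}})
(def emit  {:N {"fish" 0.6 "swim" 0.4} :V {"fish" 0.3 "swim" 0.7}})

(defn all-derivations [n]
  (if (zero? n)
    [[]]
    (for [prefix (all-derivations (dec n)) c states] (conj prefix c))))

(defn joint-prob [words cats]
  (reduce * (for [i (range (count words))]
              (let [c (nth cats i)]
                (* (if (zero? i) (start c) (get-in trans [(nth cats (dec i)) c]))
                   (get-in emit [c (nth words i)]))))))

(defn expected-transition-counts
  "E[n_{c -> c'}] for one sentence: every occurrence of a transition in a
  derivation adds that derivation's posterior Pr(cats | words), not 1."
  [words]
  (let [derivs (all-derivations (count words))
        joints (map #(joint-prob words %) derivs)
        z      (reduce + joints)]                 ; Pr(words)
    (reduce (fn [acc [cats j]]
              (reduce #(update %1 (vec %2) (fnil + 0.0) (/ j z))
                      acc
                      (partition 2 1 cats)))
            {}
            (map vector derivs joints))))
```

Every derivation of a two-word sentence contains exactly one transition, so its expected counts sum to \(1\). Enumeration is exponential in sentence length, which is what the forward and backward quantities avoid.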
First, let’s define the joint probability of a transition from \(c_{1}\) at time step \(t\) to \(c_2\) at time step \(t+1\) together with a particular string \(\vec{w}\).
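\[\gamma_{t}(c_{1}^{(t)}, c_{2}^{(t+1)}) = \Pr(c_{1}^{(t)}, c_{2}^{(t+1)}, w^{(1)}, \cdots, w^{(|\vec{w}|)}) = \mathbf{fwd}(c_{1}^{(t)})\theta_{c_{1} \rightarrow c_{2}} \theta_{w^{(t+1)} \mid c_{2}}\mathbf{bkwd}({c_{2}^{(t+1)}})\]
Here \(\mathbf{fwd}(c_{1}^{(t)}) = \Pr(w^{(1)}, \cdots, w^{(t)}, c^{(t)} = c_{1})\) is the forward probability and \(\mathbf{bkwd}(c_{2}^{(t+1)}) = \Pr(w^{(t+2)}, \cdots, w^{(|\vec{w}|)} \mid c^{(t+1)} = c_{2})\) is the backward probability. We normalize \(\gamma_{t}\) by the marginal probability of the word sequence to obtain the posterior distribution we want.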
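\[\zeta_{t}(c_{1}^{(t)}, c_{2}^{(t+1)}) = \Pr(c_{1}^{(t)}, c_{2}^{(t+1)} \mid w^{(1)}, \cdots, w^{(|\vec{w}|)}) = \frac{\gamma_{t}(c_{1}^{(t)}, c_{2}^{(t+1)})}{\sum_{i} \sum_{j} \gamma_{t}(c_{i}^{(t)}, c_{j}^{(t+1)} )}\]
The denominator sums the joint probability over every pair of categories at positions \(t\) and \(t+1\), so it equals the marginal probability \(\Pr(\vec{w})\). We compute the expected count by summing \(\zeta_{t}\) over every position of every sentence in the corpus.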
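The forward and backward quantities make \(\zeta_{t}\) cheap to compute without enumerating derivations. A sketch, assuming a hypothetical two-category HMM (made-up numbers, explicit start distribution, no stop category) and 0-based positions, so \(t\) runs from \(0\) to \(|\vec{w}|-2\). `forward`, `backward`, and `zetas` are illustrative names. The code stays in probability space, so long sentences would underflow without log-space arithmetic such as `logsumexp` from the top of the chapter.

```clojure
;; Hypothetical toy HMM; every number is made up for illustration.
(def states [:N :V])
(def start {:N 0.6 :V 0.4})
(def trans {:N {:N 0.3 :V 0.7} :V {:N 0.8 :V 0.2}})
(def emit  {:N {"fish" 0.6 "swim" 0.4} :V {"fish" 0.3 "swim" 0.7}})

(defn forward
  "Vector whose t-th map is fwd_t(c) = Pr(w_0 ... w_t, c_t = c)."
  [words]
  (reduce (fn [alphas w]
            (let [prev (peek alphas)]
              (conj alphas
                    (into {} (for [c states]
                               [c (* (get-in emit [c w])
                                     (reduce + (for [a states]
                                                 (* (prev a) (get-in trans [a c])))))])))))
          [(into {} (for [c states]
                      [c (* (start c) (get-in emit [c (first words)]))]))]
          (rest words)))

(defn backward
  "Vector whose t-th map is bkwd_t(c) = Pr(w_{t+1} ... w_{n-1} | c_t = c)."
  [words]
  (vec (reduce (fn [betas w-next]
                 (let [nxt (first betas)]
                   (cons (into {} (for [c states]
                                    [c (reduce + (for [b states]
                                                   (* (get-in trans [c b])
                                                      (get-in emit [b w-next])
                                                      (nxt b))))]))
                         betas)))
               (list (into {} (for [c states] [c 1.0])))   ; bkwd at the last position
               (reverse (rest words)))))

(defn zetas
  "For t = 0 .. n-2, the map {[c1 c2] zeta_t(c1, c2)}."
  [words]
  (let [fwd  (forward words)
        bkwd (backward words)]
    (for [t (range (dec (count words)))]
      (let [gamma (fn [c1 c2]   ; joint of this transition and the whole string
                    (* ((fwd t) c1)
                       (get-in trans [c1 c2])
                       (get-in emit [c2 (nth words (inc t))])
                       ((bkwd (inc t)) c2)))
            z (reduce + (for [a states b states] (gamma a b)))]  ; = Pr(words)
        (into {} (for [a states b states]
                   [[a b] (/ (gamma a b) z)]))))))
```

A useful sanity check: summing the last forward map, or summing \(\gamma_{t}\) over all category pairs at any \(t\), gives the same value \(\Pr(\vec{w})\).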
\[\mathbb{E}_{\Pr(\mathbf{D} \mid \mathbf{C}, \boldsymbol{\theta}_{O}, \boldsymbol{\theta}_{T})}[n_{c_{1} \rightarrow c_{2}}] = \sum_{\vec{w} \in \mathbf{C}} \sum_{t=1}^{|\vec{w} |-1} \zeta_{t}(c_{1}^{(t)}, c_{2}^{(t+1)})\]
We then divide this expected count by the expected count of all transitions out of \(c_{1}\).
\[\theta^{\mathrm{new}}_{c_{1} \rightarrow c_{2}} = \frac{\mathbb{E}_{\Pr(\mathbf{D} \mid \mathbf{C}, \boldsymbol{\theta}_{O}, \boldsymbol{\theta}_{T})}[n_{c_{1} \rightarrow c_{2}}]}{\sum_{c_{n} \in \mathcal{S}} \mathbb{E}_{\Pr(\mathbf{D} \mid \mathbf{C}, \boldsymbol{\theta}_{O}, \boldsymbol{\theta}_{T})}[n_{c_{1} \rightarrow c_{n}}]}\]
Or, expressed in terms of \(\zeta\):
\[\theta^{\mathrm{new}}_{c_{1} \rightarrow c_{2}} = \frac{ \sum_{\vec{w} \in \mathbf{C}} \sum_{t=1}^{|\vec{w} |-1} \zeta_{t}(c_{1}^{(t)}, c_{2}^{(t+1)}) }{\sum_{c_{n} \in \mathcal{S}} \sum_{\vec{w} \in \mathbf{C}} \sum_{t=1}^{|\vec{w} |-1} \zeta_{t}(c_{1}^{(t)}, c_{n}^{(t+1)}) }\]
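Putting the pieces together, here is a sketch of one Baum-Welch update of the transition probabilities over a whole corpus. It defines its own forward and backward helpers so that it runs on its own, and it assumes a hypothetical two-category HMM with made-up numbers, an explicit start distribution, no stop category, 0-based positions, and probability-space arithmetic. All function names are illustrative. Only \(\boldsymbol{\theta}_T\) is re-estimated here; the emission update has the same shape, with expected emission counts in place of expected transition counts.

```clojure
;; Hypothetical toy HMM; every number is made up for illustration.
(def states [:N :V])
(def start {:N 0.6 :V 0.4})
(def trans {:N {:N 0.3 :V 0.7} :V {:N 0.8 :V 0.2}})
(def emit  {:N {"fish" 0.6 "swim" 0.4} :V {"fish" 0.3 "swim" 0.7}})

(defn forward [words]
  (reduce (fn [alphas w]
            (let [prev (peek alphas)]
              (conj alphas
                    (into {} (for [c states]
                               [c (* (get-in emit [c w])
                                     (reduce + (for [a states]
                                                 (* (prev a) (get-in trans [a c])))))])))))
          [(into {} (for [c states]
                      [c (* (start c) (get-in emit [c (first words)]))]))]
          (rest words)))

(defn backward [words]
  (vec (reduce (fn [betas w-next]
                 (let [nxt (first betas)]
                   (cons (into {} (for [c states]
                                    [c (reduce + (for [b states]
                                                   (* (get-in trans [c b])
                                                      (get-in emit [b w-next])
                                                      (nxt b))))]))
                         betas)))
               (list (into {} (for [c states] [c 1.0])))
               (reverse (rest words)))))

(defn zetas
  "For t = 0 .. n-2, the map {[c1 c2] zeta_t(c1, c2)}."
  [words]
  (let [fwd  (forward words)
        bkwd (backward words)]
    (for [t (range (dec (count words)))]
      (let [gamma (fn [c1 c2]
                    (* ((fwd t) c1)
                       (get-in trans [c1 c2])
                       (get-in emit [c2 (nth words (inc t))])
                       ((bkwd (inc t)) c2)))
            z (reduce + (for [a states b states] (gamma a b)))]
        (into {} (for [a states b states]
                   [[a b] (/ (gamma a b) z)]))))))

(defn baum-welch-transitions
  "theta_new(c1 -> c2): the sum of zeta_t(c1, c2) over every position t of
  every sentence, divided by the same sum over all successors of c1."
  [corpus]
  (let [counts (apply merge-with + (mapcat zetas corpus))            ; E[n(c1 -> c2)]
        totals (apply merge-with + (for [[[a _] n] counts] {a n}))]  ; E[n(c1 -> any)]
    (into {} (for [[[a b] n] counts]
               [[a b] (/ n (totals a))]))))
```

Each call uses the current parameters to compute \(\zeta\); repeating the update with the new parameters, together with the analogous emission update, gives the iteration this chapter describes.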