Monday, April 3, 2017

Getting the top n most probable sentences using beam search

This is a continuation of a previous blog post on single-sentence beam search.

Sometimes it is not enough to generate just the most probable sentence using a language model; you might want the top 3 most probable sentences instead. In that case we need to modify our beam search a bit. We will turn the function into a generator that yields sentences one at a time in decreasing order of probability, instead of returning only the single most probable sentence. Here are the changes we need to make:

In the single-sentence version, we took the most probable prefix in the current beam and checked whether it was complete. If it was, we returned it and stopped there. Now we will not stop until the current beam is empty (or until the caller stops requesting more sentences). After yielding the most probable prefix, we check the second most probable prefix, and we keep yielding complete prefixes until we either find one that is not complete or have yielded the whole beam. If the whole beam is yielded, the algorithm stops there, as nothing is left from which to generate new prefixes. This means that the beam width limits how many sentences can be returned in a single step, and that the search may end before the caller has received as many sentences as it wanted. If we do not yield the whole beam, we continue generating prefixes from the remainder.
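A minimal sketch of this step, assuming (as in the code further down) that beam items are `(prob, complete, prefix)` tuples and that the beam has already been sorted from least to most probable. The helper name `take_complete_from_top` and the example beam are hypothetical:

```python
def take_complete_from_top(sorted_beam):
    """Hypothetical helper: pop complete prefixes off the most probable end of
    an ascending-sorted beam of (prob, complete, prefix) tuples.

    Stops at the first incomplete prefix or when the beam is empty.
    Returns the popped items, most probable first; sorted_beam is modified in place.
    """
    finished = []
    while sorted_beam and sorted_beam[-1][1]:  # [1] is the 'complete' flag
        finished.append(sorted_beam.pop())
    return finished


# Hypothetical beam contents, sorted from least to most probable.
beam = sorted([
    (0.05, False, ['<start>', 'a', 'dog']),
    (0.10, True, ['<start>', 'a']),
    (0.20, False, ['<start>', 'the', 'cat']),
    (0.30, True, ['<start>', 'the']),
    (0.35, True, ['<start>', 'a', 'cat']),
])

done = take_complete_from_top(beam)
print([p for p, _, _ in done])  # [0.35, 0.3]
print([p for p, _, _ in beam])  # [0.05, 0.1, 0.2]
```

The two most probable items are complete, so both are taken off the top; the search would then continue from the three remaining items. If every item had been complete, the list would be left empty and the search would end.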

Any complete sentences that were returned also need to be removed from the beam before we continue generating. The beam is implemented as a min-first heap queue (min-first because we want to quickly pop the least probable prefix whenever the beam grows beyond the beam width), so we cannot also quickly remove the most probable complete sentences. To do this, we first turn the heap into a list sorted by probability and then pop items off the end for as long as they are complete sentences. We then heapify the list back into a min-first heap queue and continue as normal. For a beam width of k, the sort costs O(k log k) and the re-heapify O(k) per step, so this should not affect performance much as long as the beam width is relatively small.
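As a hedged illustration of the sort, pop and re-heapify round trip, here is a small sketch using only Python's standard `heapq` module on a made-up beam of `(prob, complete, prefix)` tuples:

```python
import heapq

# Hypothetical beam contents: (prob, complete, prefix) tuples.
items = [
    (0.2, False, ['<start>', 'the', 'cat']),
    (0.4, True, ['<start>', 'the']),
    (0.1, False, ['<start>', 'a', 'dog']),
    (0.3, False, ['<start>', 'a', 'cat']),
]

# Min-first heap: heap[0] is always the least probable item, which is
# what we need to evict quickly when the beam grows past its width.
heap = []
for item in items:
    heapq.heappush(heap, item)

# The most probable item has no fixed position in a min-heap, so sort first.
sorted_beam = sorted(heap)

# Pop complete sentences off the most probable end.
removed = []
while sorted_beam and sorted_beam[-1][1]:
    removed.append(sorted_beam.pop())

# Turn the remainder back into a min-first heap (linear time).
heapq.heapify(sorted_beam)

print([p for p, _, _ in removed])  # [0.4]
print(sorted_beam[0][0])           # 0.1
```

An ascending sorted list already satisfies the heap invariant (every parent is no larger than its children), so the `heapify` call moves nothing here; it mirrors what constructing a new `Beam` from the sorted list does, and it costs only linear time.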

If the clip length is reached, the whole beam is immediately returned in order of probability. This is because, as soon as one prefix reaches the maximum allowed length, the entire beam consists of:
  1. incomplete sentences that are also of the maximum allowed length (since all the incomplete prefixes grow together, one word per step)
  2. complete sentences that were found earlier but were not the most probable at the time
The incomplete sentences have to be returned since they cannot be extended any further, and the complete sentences are returned along with them, since everything left in the beam is now either complete or at the clip length. The assumption is that if a complete sentence were nonsensical, it would not have survived in the beam, so we might as well return it rather than lose it.
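To illustrate, here is a small sketch of the draining loop with `clip_len = 2`: every incomplete prefix has two words after `<start>`, the complete ones are shorter, and the loop empties the beam in order of probability, interleaving complete and incomplete sentences. The function name `drain_at_clip` and the beam contents are hypothetical:

```python
def drain_at_clip(sorted_beam, clip_len):
    """Hypothetical helper: yield (sentence, prob) pairs from the most probable
    end of an ascending-sorted beam while the top item is complete or has
    reached clip_len words (not counting '<start>'). Pops items in place."""
    while sorted_beam:
        prob, complete, prefix = sorted_beam[-1]
        if not (complete or len(prefix) - 1 == clip_len):
            break  # top item can still be extended, so stop draining
        sorted_beam.pop()
        yield prefix[1:], prob


beam = sorted([
    (0.30, False, ['<start>', 'the', 'cat']),  # incomplete, at the clip length
    (0.25, True, ['<start>', 'the']),          # complete, found earlier
    (0.20, False, ['<start>', 'a', 'dog']),    # incomplete, at the clip length
    (0.05, True, ['<start>', 'a']),            # complete, found earlier
])

for sentence, prob in drain_at_clip(beam, clip_len=2):
    print(sentence, prob)
# ['the', 'cat'] 0.3
# ['the'] 0.25
# ['a', 'dog'] 0.2
# ['a'] 0.05

print(beam)  # []
```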

Here is the modified Python 3 code:

import heapq

class Beam(object):

    def __init__(self, beam_width, init_beam=None):
        if init_beam is None:
            self.heap = list()
        else:
            self.heap = init_beam
            heapq.heapify(self.heap) #make the list a heap
        self.beam_width = beam_width

    def add(self, prob, complete, prefix):
        heapq.heappush(self.heap, (prob, complete, prefix))
        if len(self.heap) > self.beam_width:
            heapq.heappop(self.heap)

    def __iter__(self):
        return iter(self.heap)


def beamsearch(probabilities_function, beam_width=10, clip_len=-1):
    prev_beam = Beam(beam_width)
    prev_beam.add(1.0, False, [ '<start>' ])

    while True:
        curr_beam = Beam(beam_width)

        #Add complete sentences that do not yet have the best probability to the current beam;
        #the incomplete ones are extended with every possible next word.
        for (prefix_prob, complete, prefix) in prev_beam:
            if complete:
                curr_beam.add(prefix_prob, True, prefix)
            else:
                #Get the probability of each possible next word for the incomplete prefix.
                for (next_prob, next_word) in probabilities_function(prefix):
                    if next_word == '<end>': #if the next word is the end token then mark the prefix as complete and leave out the end token
                        curr_beam.add(prefix_prob*next_prob, True, prefix)
                    else: #if the next word is a non-end token then mark the prefix as incomplete
                        curr_beam.add(prefix_prob*next_prob, False, prefix+[next_word])

        sorted_beam = sorted(curr_beam) #get all prefixes in current beam sorted by probability
        any_removals = False
        while True:
            (best_prob, best_complete, best_prefix) = sorted_beam[-1] #get highest probability prefix
            if best_complete or len(best_prefix)-1 == clip_len: #if the most probable prefix is a complete sentence or has reached the clip length (ignoring the start token) then yield it
                yield (best_prefix[1:], best_prob) #yield best sentence without the start token and together with its probability
                sorted_beam.pop() #remove the yielded sentence and check the next highest probability prefix
                any_removals = True
                if len(sorted_beam) == 0: #if there are no more sentences in the beam then stop checking
                    break
            else:
                break

        if any_removals: #if there were any changes in the current beam then...
            if len(sorted_beam) == 0: #if there are no more prefixes in the current beam (everything in it was yielded) then end the beam search
                break
            else: #otherwise set the previous beam to the modified current beam
                prev_beam = Beam(beam_width, sorted_beam)
        else: #if the current beam was left unmodified then assign it to the previous beam as is
            prev_beam = curr_beam
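As a usage sketch, here is a condensed copy of the code above run against a tiny hypothetical model, `toy_model`, whose next-word distribution depends only on the last word of the prefix. The transition table is made up for illustration. `itertools.islice` takes just the first 3 sentences from the generator, and no further search work is done once the caller stops asking. One small addition in this copy: it also stops if the beam is ever empty, where the code above would raise an `IndexError` on `sorted_beam[-1]`.

```python
import heapq
import itertools


class Beam:
    """Min-first heap of (prob, complete, prefix) tuples capped at beam_width."""

    def __init__(self, beam_width, init_beam=None):
        self.heap = [] if init_beam is None else init_beam
        heapq.heapify(self.heap)
        self.beam_width = beam_width

    def add(self, prob, complete, prefix):
        heapq.heappush(self.heap, (prob, complete, prefix))
        if len(self.heap) > self.beam_width:
            heapq.heappop(self.heap)  # evict the least probable item

    def __iter__(self):
        return iter(self.heap)


def beamsearch(probabilities_function, beam_width=10, clip_len=-1):
    prev_beam = Beam(beam_width)
    prev_beam.add(1.0, False, ['<start>'])
    while True:
        curr_beam = Beam(beam_width)
        for prefix_prob, complete, prefix in prev_beam:
            if complete:
                curr_beam.add(prefix_prob, True, prefix)
            else:
                for next_prob, next_word in probabilities_function(prefix):
                    if next_word == '<end>':
                        curr_beam.add(prefix_prob * next_prob, True, prefix)
                    else:
                        curr_beam.add(prefix_prob * next_prob, False, prefix + [next_word])

        sorted_beam = sorted(curr_beam)
        any_removals = False
        # Yield from the top while the best prefix is complete or at the clip length.
        while sorted_beam:
            best_prob, best_complete, best_prefix = sorted_beam[-1]
            if not (best_complete or len(best_prefix) - 1 == clip_len):
                break
            yield best_prefix[1:], best_prob
            sorted_beam.pop()
            any_removals = True

        if not sorted_beam:
            return  # nothing left to extend
        prev_beam = Beam(beam_width, sorted_beam) if any_removals else curr_beam


# Hypothetical model: next-word (prob, word) pairs keyed by the last word.
TRANSITIONS = {
    '<start>': [(0.6, 'the'), (0.4, 'a')],
    'the': [(0.5, 'cat'), (0.3, 'dog'), (0.2, '<end>')],
    'a': [(0.7, 'cat'), (0.3, 'dog')],
    'cat': [(0.9, '<end>'), (0.1, 'sat')],
    'dog': [(1.0, '<end>')],
    'sat': [(1.0, '<end>')],
}


def toy_model(prefix):
    return TRANSITIONS[prefix[-1]]


for sentence, prob in itertools.islice(beamsearch(toy_model, beam_width=10), 3):
    print(' '.join(sentence), round(prob, 3))
# the cat 0.27
# a cat 0.252
# the dog 0.18
```

With `beam_width=3`, `list(beamsearch(toy_model, beam_width=3))` happens to return exactly these three sentences and then stops: less probable hypotheses, such as the one-word sentence "the", were evicted from the beam before they could be returned.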