Text Generation Using a Markov Chain With the Help of Reinforcement Learning

  • Tanesh Balodi
  • Dec 05, 2020
  • Machine Learning
  • Updated on: Dec 05, 2020

What is Reinforcement Learning?


Reinforcement learning has seen a lot of advancement in machine learning and deep learning. What makes it different from supervised learning is that it does not rely on labelled data; instead, it learns from experience. When our model predicts something with good accuracy, it receives a positive reward, and when it does not perform up to the threshold value, it receives a negative reward (a penalty). In this article, we are going to look at the Markov chain, the stochastic model that underlies the Markov decision process (MDP) used in reinforcement learning.

[Figure: Supervised vs. reinforcement learning]

Defining a Markov Chain


A Markov chain is a stochastic model of a sequence of events in which the probability of each event depends only on the state reached in the previous one. As an example, imagine a field that is used for several sports. The probability that the field will be used for football the next day, given that it is being used for cricket today, is something we can represent with a Markov chain.


[Figure: Markov chain representation as a directed graph]

The figure above represents a Markov chain as a directed graph: each node is a sport, and each arrow carries the probability that the field will be used for the next sport given the current one. Let's visualize the same chain in adjacency matrix form.

[Figure: Adjacency matrix representation of the Markov chain]

Let's take one part out of the whole representation:


[Figure: One part of the Markov chain, showing conditional transition probabilities]

Here, there is a 60% probability that the field is used for Handball after Football and a 10% probability that Football is played on it again. Similarly, there is a 70% probability of Football being played after Handball and a 10% probability of Handball being played again; in each case, the remaining probability goes to the other sports in the full chain.
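The part of the chain described above can be written down as a nested dictionary and sampled from directly. This is a minimal sketch: the names `transitions` and `next_sport` are hypothetical, and the `Other` state is an assumption that lumps together the remaining sports of the full chain so that each row sums to 1.

```python
import random

# Transition probabilities for the part of the chain discussed above.
# "Other" is a hypothetical catch-all state for the remaining sports,
# added only so that each row sums to 1.
transitions = {
    "Football": {"Handball": 0.6, "Football": 0.1, "Other": 0.3},
    "Handball": {"Football": 0.7, "Handball": 0.1, "Other": 0.2},
}

def next_sport(current, rng=random):
    """Sample the next sport using only the current one (the Markov property)."""
    row = transitions[current]
    return rng.choices(list(row.keys()), weights=list(row.values()), k=1)[0]

rng = random.Random(0)
samples = [next_sport("Football", rng) for _ in range(10_000)]
print(samples.count("Handball") / len(samples))  # roughly 0.6
```

Over many draws, the share of Handball after Football approaches the 60% edge weight.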


Now, what if our model could estimate such probabilities for every possible event and predict the next move, word, character, and so on from what it has learned? This is how we are going to use the Markov chain for text generation.
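As a small illustration of how such probabilities can be learned from text, the sketch below counts which word follows which in a toy sentence and normalizes the counts into conditional probabilities. It works on words for readability, whereas the implementation later in this article works on characters; the corpus and the function name `bigram_probabilities` are illustrative assumptions.

```python
from collections import Counter, defaultdict

def bigram_probabilities(words):
    """Estimate P(next word | current word) from bigram counts."""
    counts = defaultdict(Counter)
    for current, nxt in zip(words, words[1:]):
        counts[current][nxt] += 1
    # Normalize each row of counts into a probability distribution
    probs = {}
    for current, row in counts.items():
        total = sum(row.values())
        probs[current] = {w: c / total for w, c in row.items()}
    return probs

corpus = "the cat sat on the mat and the cat ran".split()
probs = bigram_probabilities(corpus)
print(probs["the"])  # 'cat' with probability 2/3, 'mat' with 1/3
```

A word that only appears at the very end of the corpus ("ran" here) gets no row, because nothing was ever observed after it.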


[Figure: Formula for the probability of the next word or sentence using a Markov chain]

In simple terms:

$$P(X_{n+1} = x \mid X_n = x_n)$$


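The expression above is the probability that the next state $X_{n+1}$ takes the value $x$, given that the current state $X_n$ is $x_n$. It states that the future state depends only on the current state, so it can be predicted by looking at the current state alone.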
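Written out in full, the Markov property says that conditioning on the entire history gives the same probability as conditioning on the current state alone:

$$P(X_{n+1} = x \mid X_1 = x_1, \ldots, X_n = x_n) = P(X_{n+1} = x \mid X_n = x_n)$$

Assuming these probabilities do not change with $n$ (a time-homogeneous chain), they can be collected into a transition matrix with entries $P_{ij} = P(X_{n+1} = j \mid X_n = i)$, where each row sums to 1.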


What are the Properties of the Markov Chain?


  1. The next state depends only on the current state, not on all of the previous states.


  2. The weights of all the arrows leaving a state must sum to 1, as you can also see in the representation of the Markov chain above.


  3. After a long enough random walk, a well-behaved chain (one that is finite, irreducible, and aperiodic) settles into a stationary distribution, or equilibrium state: a distribution over the states that no longer changes with time.
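The third property can be checked numerically. The sketch below uses a hypothetical three-state transition matrix (not the one from the figures) and repeatedly multiplies a starting distribution by it, a procedure known as power iteration, until the distribution stops changing. The matrix values and the function name `stationary_distribution` are assumptions for illustration; because every entry is positive, this chain is irreducible and aperiodic, so the iteration converges.

```python
import numpy as np

# Hypothetical transition matrix: row i holds P(next state | current state i),
# so every row sums to 1.
P = np.array([
    [0.1, 0.6, 0.3],
    [0.7, 0.1, 0.2],
    [0.4, 0.4, 0.2],
])

def stationary_distribution(P, tol=1e-12, max_steps=10_000):
    """Power iteration: apply pi <- pi @ P until pi no longer changes."""
    pi = np.full(P.shape[0], 1.0 / P.shape[0])  # start from a uniform distribution
    for _ in range(max_steps):
        new_pi = pi @ P
        if np.abs(new_pi - pi).max() < tol:
            return new_pi
        pi = new_pi
    return pi

pi = stationary_distribution(P)
print(pi)           # the equilibrium distribution
print(pi @ P - pi)  # approximately zero: one more step leaves pi unchanged
```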


Code For Text Generation Using a Markov Chain


We will generate text using a Markov chain. As a dataset, we have taken random articles consisting of a few thousand words; you can build your own dataset the same way. The steps involved are as follows:


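  1. Importing important libraries and opening the dataset

import numpy as np

# Read the whole dataset into a single string
with open('../datasets/sherlock.txt') as f:
    text = f.read()

# Skip the first 3433 characters (presumably the header that precedes the text)
blob = text[3433:]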

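  2. Tokenization of words

# Strip each non-empty line and join everything into one string;
# the model below treats each character of this string as a state
blob = [each.strip() for each in blob.split('\n') if each]
blob = ' '.join(blob)

# word_tokenize is imported here but not used by the character-level model below
from nltk.tokenize import word_tokenize

-> 21758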

  3. Creating the vocabulary and printing its size

# The vocabulary is the set of distinct characters in the text
states = set(blob)  # Vocab
print(len(states))

-> 96

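  4. Creating a transition matrix

T = {}  # Transition matrix: n-character context -> {next character: count}
n = 5   # Context length: the current state is the last 5 characters

for i in range(len(blob) - n):
    ngram = blob[i:i+n]        # current state
    next_state = blob[i+n]     # character that follows it
    T_context = T.setdefault(ngram, {})
    T_context[next_state] = T_context.setdefault(next_state, 0) + 1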
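To see what `T` ends up containing, here is the same counting loop run on a short toy string with a smaller context size. The loop is wrapped in a hypothetical helper, `build_counts`; the string and `n = 2` are chosen only to keep the output readable.

```python
def build_counts(blob, n):
    """Count how often each character follows each n-character context."""
    T = {}
    for i in range(len(blob) - n):
        ngram = blob[i:i + n]     # current state: n consecutive characters
        next_state = blob[i + n]  # the character that follows it
        T_context = T.setdefault(ngram, {})
        T_context[next_state] = T_context.setdefault(next_state, 0) + 1
    return T

T = build_counts("the theme then", n=2)
print(T["th"])  # {'e': 3}
print(T["he"])  # {' ': 1, 'm': 1, 'n': 1}
```

Each key is a context and each inner dictionary is a row of the (sparse) transition matrix, still holding raw counts at this point.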

  5. Converting into probabilities

# Divide each count by its row total so that every row sums to 1
for row in T:
    s = sum(T[row].values())
    for val in T[row]:
        T[row][val] = T[row][val] / s
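After this step every row of `T` should be a probability distribution, which is the second property of a Markov chain listed earlier. The sketch below normalizes a small hand-written count table and checks the row sums; `normalize` is a hypothetical helper name. Unlike the in-place loop above, it returns a new dictionary and leaves the counts intact.

```python
import math

def normalize(T):
    """Turn each row of raw counts into probabilities that sum to 1."""
    probs = {}
    for context, row in T.items():
        total = sum(row.values())
        probs[context] = {ch: count / total for ch, count in row.items()}
    return probs

# Hand-written toy counts: context -> {next character: count}
counts = {"th": {"e": 3}, "he": {" ": 1, "m": 1, "n": 1}, "e ": {"t": 2}}
probs = normalize(counts)

# Each row must sum to 1, like the outgoing arrows of a Markov chain state
for context, row in probs.items():
    assert math.isclose(sum(row.values()), 1.0), context
print(probs["he"])
```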

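  6. Sampling from a discrete distribution

# Draw 10,000 uniform numbers and bucket them with the cumulative thresholds
# 0.3, 0.7 and 1, i.e. bucket probabilities 0.3, 0.4 and 0.3.
# Reconstructed: labelling each draw with its bucket's upper threshold is an assumption.
values = []
for _ in range(10000):
    r = np.random.random()
    if r <= 0.3:
        values.append(0.3)
    elif r <= 0.7:
        values.append(0.7)
    else:
        values.append(1)

values = np.array(values)

# Empirical frequency of each bucket: roughly 0.3, 0.4 and 0.3
for f in [0.3, 0.7, 1]:
    print(f, (values == f).mean())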
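The threshold trick in this step is inverse-transform sampling: draw a uniform number in [0, 1) and pick the first bucket whose cumulative probability exceeds it. The sketch below generalizes it to any list of probabilities with `np.cumsum` and `np.searchsorted`; `np.random.choice`, used later in `predict_state`, samples from the same kind of distribution. The function name `sample_index` and the seed are illustrative choices.

```python
import numpy as np

def sample_index(probabilities, rng):
    """Inverse-transform sampling from a discrete distribution."""
    cumulative = np.cumsum(probabilities)  # e.g. [0.3, 0.7, 1.0]
    r = rng.random()                       # uniform draw in [0, 1)
    # Index of the first bucket whose cumulative probability exceeds r
    idx = int(np.searchsorted(cumulative, r, side="right"))
    # Guard against floating-point round-off in the last cumulative value
    return min(idx, len(probabilities) - 1)

rng = np.random.default_rng(42)
probs = [0.3, 0.4, 0.3]
draws = np.array([sample_index(probs, rng) for _ in range(10_000)])
for i in range(len(probs)):
    print(i, (draws == i).mean())  # empirical frequencies near 0.3, 0.4, 0.3
```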


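  7. Temperature sampling

def temperature_sampling(probabilities, temp=1):
    # Rescale a distribution: exp(log(p) / temp), i.e. p ** (1 / temp), then renormalize.
    # temp > 1 flattens the distribution; temp < 1 sharpens it.
    probabilities = np.asarray(probabilities)
    smoothened_probs = np.exp(np.log(probabilities) / temp)
    return list(smoothened_probs / smoothened_probs.sum())

probs = [0.2, 0.4, 0.1, 0.03, 0.07]
sampled = temperature_sampling(probs, 2)

from matplotlib import pyplot as plt

# Compare the original distribution with the smoothened one
plt.plot(probs, 'b-', label='Prior')
plt.plot(sampled, 'g--', label='Smoothened')
plt.legend()
plt.show()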

[Figure: Prior vs. smoothened probabilities after temperature sampling]
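Dividing the log-probabilities by the temperature before renormalizing is the same as raising each probability to the power 1/temp. The sketch below re-declares that formula under a hypothetical name, `apply_temperature`, so it runs on its own, and shows the effect numerically: a temperature above 1 flattens the distribution, below 1 sharpens it, and exactly 1 only renormalizes.

```python
import numpy as np

def apply_temperature(probabilities, temp=1.0):
    """Rescale a distribution: p_i ** (1 / temp), then renormalize."""
    p = np.asarray(probabilities, dtype=float)
    scaled = np.exp(np.log(p) / temp)  # same as p ** (1 / temp)
    return scaled / scaled.sum()

probs = [0.2, 0.4, 0.1, 0.03, 0.07]
for temp in (0.5, 1.0, 2.0):
    out = apply_temperature(probs, temp)
    print(temp, out.round(3), "max prob:", out.max().round(3))
```

Note that the example prior sums to 0.8 rather than 1; because the function renormalizes, its output always sums to 1.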

  8. Predicting the next state

def predict_state(ngram, diversity=1):
    # Unseen context: fall back to a space
    if T.get(ngram) is None:
        return ' '
    mapped_ngram = T[ngram]
    mapped_states = list(mapped_ngram.keys())
    probabilities = list(mapped_ngram.values())
    # Reshape the learned distribution with the temperature ('diversity')
    diversified_probs = temperature_sampling(probabilities, temp=diversity)
    # Sample the next character from the reshaped distribution
    return np.random.choice(mapped_states, p=diversified_probs)
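The behaviour of `predict_state` can be checked on a tiny hand-written table. The sketch below is a variant under a hypothetical name, `predict_next_char`, that takes the table `T` and a seeded random generator `rng` as arguments (neither is part of the original code) so that results are reproducible; like the original, it falls back to a space for a context it has never seen.

```python
import numpy as np

def predict_next_char(T, ngram, rng, diversity=1.0):
    """Sample the character that follows `ngram`, or ' ' for an unseen context."""
    if ngram not in T:
        return ' '
    states = list(T[ngram].keys())
    # Same temperature rescaling as temperature_sampling: p ** (1 / diversity)
    p = np.exp(np.log(list(T[ngram].values())) / diversity)
    return str(rng.choice(states, p=p / p.sum()))

# Hand-written toy table: context -> {next character: probability}
T = {"th": {"e": 1.0}, "he": {" ": 1/3, "m": 1/3, "n": 1/3}}
rng = np.random.default_rng(0)

print(predict_next_char(T, "th", rng))        # 'e' is the only option
print(predict_next_char(T, "he", rng))        # one of ' ', 'm', 'n'
print(repr(predict_next_char(T, "zz", rng)))  # ' ' because "zz" was never seen
```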

  9. Generating text

def generate(initial=None, size=1000, diversity=1):
    sentence = ''
    # Without a seed, start from a random n-character slice of the text
    if initial is None:
        initial = int(np.random.random() * (len(blob) - n))
        initial = blob[initial:initial+n]
    sentence += initial
    for i in range(size):
        pred = predict_state(initial, diversity=diversity)
        sentence += pred
        # The next state is the last n characters generated so far
        initial = sentence[-n:]
    return sentence

print(generate('In th', diversity=0.5))


[Figure: Sample text generated with the Markov chain in Python]

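To get different results, try changing ‘diversity’: values below 1 sharpen the distribution so that the most frequent continuations dominate, while values above 1 flatten it and produce more varied but less coherent text.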
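For experimenting with ‘diversity’ without the original dataset, the sketch below puts all the steps together on a short toy corpus (an assumption; any text works). It follows the same character-level approach with context length `n`, but uses a seeded `numpy.random.Generator` so runs are reproducible and passes the table around explicitly instead of using globals; the names `build_model` and `generate_text` are hypothetical.

```python
import numpy as np

def build_model(text, n):
    """Character n-gram table: context -> {next character: probability}."""
    T = {}
    for i in range(len(text) - n):
        row = T.setdefault(text[i:i + n], {})
        row[text[i + n]] = row.get(text[i + n], 0) + 1
    for row in T.values():
        total = sum(row.values())
        for ch in row:
            row[ch] /= total  # counts -> probabilities
    return T

def generate_text(T, initial, n, size=100, diversity=1.0, seed=0):
    """Extend `initial` one character at a time, looking only at the last n characters."""
    rng = np.random.default_rng(seed)
    sentence = initial
    for _ in range(size):
        row = T.get(sentence[-n:])
        if row is None:
            sentence += ' '  # unseen context: fall back to a space
            continue
        p = np.exp(np.log(list(row.values())) / diversity)  # temperature rescaling
        sentence += str(rng.choice(list(row.keys()), p=p / p.sum()))
    return sentence

corpus = ("the cat sat on the mat. the dog sat on the log. "
          "the cat saw the dog and the dog saw the cat. ") * 3
n = 3
T = build_model(corpus, n)
for diversity in (0.5, 1.0, 2.0):
    print(diversity, generate_text(T, "the", n, size=60, diversity=diversity))
```

With a corpus this small, low diversity tends to repeat the most common phrases, while higher values mix fragments more freely.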




A Markov chain is a simple and efficient way to generate text. Other approaches include neural networks such as RNNs and their LSTM and GRU variants, as well as large pretrained models such as Google's BERT and OpenAI's GPT-2. Those models are more powerful but considerably more complex to understand; the Markov chain approach, by contrast, is simple and easy to implement.