This tutorial presents an example of a one-to-one RNN model applied to text generation using the MXNet R package.

The example is based on Obama’s speech.

Load some packages

library("readr")
library("dplyr")
library("plotly")
library("stringr")
library("stringi")
library("scales")
library("mxnet")

Data preparation

Data preparation is performed by the script: data_preprocessing_one_to_one.R.

The following steps are executed:

# Load the pre-bucketed training and evaluation corpora from disk.
corpus_bucketed_train <- readRDS("../data/train_buckets_one_to_one.rds")
corpus_bucketed_test <- readRDS("../data/eval_buckets_one_to_one.rds")

# Vocabulary size, taken from the dictionary stored with the evaluation buckets.
vocab <- length(corpus_bucketed_test$dic)

### Create iterators
batch.size <- 32

# Training iterator: batches are shuffled; 0 is passed as the mask element.
train.data <- mx.io.bucket.iter(
  buckets = corpus_bucketed_train$buckets,
  batch.size = batch.size,
  data.mask.element = 0,
  shuffle = TRUE
)

# Evaluation iterator: same settings, but the batch order is not shuffled.
eval.data <- mx.io.bucket.iter(
  buckets = corpus_bucketed_test$buckets,
  batch.size = batch.size,
  data.mask.element = 0,
  shuffle = FALSE
)

Model architecture

A one-to-one model configuration is specified since, for each character, we want to predict the next one. For a sequence of length 100, there are also 100 labels, corresponding to the same sequence of characters offset by one position (+1).

# Unroll a small one-to-one LSTM graph: 2 time steps, 1 recurrent layer,
# 96 hidden units, 64-dimensional embeddings, softmax loss over `vocab` classes.
# Logical flags are spelled out as FALSE: the shorthands `T`/`F` are ordinary
# variables that can be reassigned, whereas TRUE/FALSE are reserved words.
rnn_graph <- rnn.graph.unroll(
  seq_len = 2,
  num_rnn_layer = 1,
  num_hidden = 96,
  input_size = vocab,
  num_embed = 64,
  num_decode = vocab,
  masking = FALSE,
  loss_output = "softmax",
  dropout = 0.2,
  ignore_label = -1,
  cell_type = "lstm",
  output_last_state = FALSE,
  config = "one-to-one"
)
%0 1 data data 2 SwapAxis swap_pre 1->2 2X32 3 Embedding embed 2->3 32X2 4 SliceChannel split0 3->4 64X32X2 5 FullyConnected t1.l1.i2h 384 4->5 13 FullyConnected t2.l1.i2h 384 4->13 6 SliceChannel t1.l1.slice 5->6 384X32 7 Activation activation3 sigmoid 6->7 8 Activation activation0 sigmoid 6->8 9 Activation activation1 tanh 6->9 12 elemwise_mul _mul1 7->12 96X32 10 elemwise_mul _mul0 8->10 96X32 9->10 96X32 11 Activation activation4 tanh 10->11 96X32 19 elemwise_mul _mul2 10->19 96X32 11->12 96X32 14 FullyConnected t2.l1.h2h 384 12->14 96X32 26 Concat concat 12->26 96X32 15 elemwise_add _plus0 13->15 384X32 14->15 384X32 16 SliceChannel t2.l1.slice 15->16 384X32 17 Activation activation8 sigmoid 16->17 18 Activation activation7 sigmoid 16->18 20 Activation activation5 sigmoid 16->20 21 Activation activation6 tanh 16->21 25 elemwise_mul _mul4 17->25 96X32 18->19 96X32 23 elemwise_add _plus1 19->23 96X32 22 elemwise_mul _mul3 20->22 96X32 21->22 96X32 22->23 96X32 24 Activation activation9 tanh 23->24 96X32 24->25 96X32 25->26 96X32 27 Reshape rnn_reshape 26->27 96X64 28 _copy mask 27->28 96X32X2 29 SwapAxis swap_post 28->29 96X32X2 30 Reshape reshape0 29->30 96X2X32 31 FullyConnected decode 82 30->31 96X64 33 SoftmaxOutput loss 31->33 82X64 32 Reshape reshape1 32->33 64

Fit an LSTM model

Unroll the RNN to the length of the input sequence.

# Device used for training.
ctx <- mx.cpu()

# Xavier weight initialisation (gaussian, averaged fan, magnitude 3).
initializer <- mx.init.Xavier(
  rnd_type = "gaussian",
  factor_type = "avg",
  magnitude = 3
)

# AdaDelta optimizer; gradients are rescaled by 1 / batch.size and clipped at 5.
optimizer <- mx.opt.create(
  "adadelta",
  rho = 0.9,
  eps = 1e-5,
  wd = 1e-8,
  clip_gradient = 5,
  rescale.grad = 1 / batch.size
)

# The logger is handed to the per-epoch callback (period = 1).
logger <- mx.metric.logger()
epoch.end.callback <- mx.callback.log.train.metric(
  period = 1,
  logger = logger
)
# NOTE(review): this per-batch callback (period = 50) is defined here, but the
# training call in this document passes batch.end.callback = NULL instead.
batch.end.callback <- mx.callback.log.train.metric(period = 50)

# Build a custom evaluation metric object of class "mx.metric".
#
# Arguments:
#   name  - label reported alongside the metric value.
#   feval - function(label, pred) returning the metric value for one batch;
#           label and pred are passed through to it untouched.
#
# Returns a list with three closures:
#   init()                     -> initial state c(0, 0): (batch count, running sum)
#   update(label, pred, state) -> state with the count incremented by one and
#                                 feval(label, pred) added to the running sum
#   get(state)                 -> list(name, value) where value is sum / count
mx.metric.custom_nd <- function(name, feval) {
  structure(
    list(
      init = function() {
        c(0, 0)
      },
      update = function(label, pred, state) {
        batch_value <- feval(label, pred)
        c(state[[1]] + 1, state[[2]] + batch_value)
      },
      get = function(state) {
        list(name = name, value = (state[[2]] / state[[1]]))
      }
    ),
    class = "mx.metric"
  )
}

# Perplexity metric: for each batch, exp() of the mean negative log-probability
# that the model assigned to the true labels. Built on mx.metric.custom_nd, so
# the reported value is the average of the per-batch perplexities.
mx.metric.Perplexity <- mx.metric.custom_nd(
  name = "Perplexity",
  feval = function(label, pred) {
    # Flatten the labels, then pick the predicted probability of each true
    # label (labels are used as 0-based indices by choose.element.0index).
    flat_label <- mx.nd.reshape(label, shape = -1)
    true_probs <- as.array(mx.nd.choose.element.0index(pred, flat_label))
    # Floor probabilities at 1e-15 so log() never receives zero.
    mean_nll <- -sum(log(pmax(1e-15, true_probs))) / length(true_probs)
    exp(mean_nll)
  }
)

# Train for 5 rounds on the bucketed iterators, tracking perplexity, and time
# the whole fit with system.time(). No per-batch callback is used.
# NOTE(review): `symbol` is not assigned in the visible part of this document
# (the graph built earlier is stored as `rnn_graph`) — confirm where the
# training symbol comes from before running this chunk.
system.time(
  model <- mx.model.buckets(
    symbol = symbol,
    train.data = train.data,
    eval.data = eval.data,
    num.round = 5,
    ctx = ctx,
    verbose = TRUE,
    metric = mx.metric.Perplexity,
    initializer = initializer,
    optimizer = optimizer,
    batch.end.callback = NULL,
    epoch.end.callback = epoch.end.callback
  )
)
##    user  system elapsed 
## 1447.64  945.27  459.35
# The checkpoint is written under iteration = 1; the later mx.model.load() call
# in this document reads it back with the same iteration number.
mx.model.save(
  model,
  prefix = "../models/model_one_to_one_lstm_cpu",
  iteration = 1
)
12345456789
traineval

Inference on test data

Set up the inference data: the preprocessing needs to be applied to the inference sequence, which is then converted into an inference data iterator.

The parameter output_last_state is set to TRUE in order to access the state of the RNN cells when performing inference.

# Inference runs on CPU, one sequence at a time.
ctx <- mx.cpu()
batch.size <- 1

# Character <-> index dictionaries stored alongside the training buckets.
corpus_bucketed_train <- readRDS(file = "../data/train_buckets_one_to_one.rds")
dic <- corpus_bucketed_train$dic
rev_dic <- corpus_bucketed_train$rev_dic

# Seed text, split into single characters and mapped to dictionary entries.
infer_raw <- "The United States are"
infer_chars <- unlist(strsplit(infer_raw, ""))

# Indexing `dic` by a name it does not contain would silently yield missing
# entries and feed invalid indices to the network, so fail early instead.
unknown_chars <- setdiff(infer_chars, names(dic))
if (length(unknown_chars) > 0) {
  stop(
    "Seed text contains characters missing from the dictionary: ",
    paste(unknown_chars, collapse = " "),
    call. = FALSE
  )
}

infer_split <- dic[infer_chars]
infer_length <- length(infer_split)

# Graph used to consume the whole seed text in one pass: unrolled to
# `infer_length` steps, 2 LSTM layers, and returning the last cell states
# (output_last_state = TRUE) so generation can continue from them.
# TRUE/FALSE are written in full: `T`/`F` are reassignable variables.
symbol.infer.ini <- rnn.graph.unroll(
  seq_len = infer_length,
  num_rnn_layer = 2,
  num_hidden = 96,
  input_size = vocab,
  num_embed = 64,
  num_decode = vocab,
  masking = FALSE,
  loss_output = "softmax",
  dropout = 0.2,
  ignore_label = -1,
  cell_type = "lstm",
  output_last_state = TRUE,
  config = "one-to-one"
)

# Single-step graph (seq_len = 1) used to generate one character at a time;
# it also returns the last cell states (output_last_state = TRUE) so they can
# be fed back in on the next step.
# TRUE/FALSE are written in full: `T`/`F` are reassignable variables.
symbol.infer <- rnn.graph.unroll(
  seq_len = 1,
  num_rnn_layer = 2,
  num_hidden = 96,
  input_size = vocab,
  num_embed = 64,
  num_decode = vocab,
  masking = FALSE,
  loss_output = "softmax",
  dropout = 0.2,
  ignore_label = -1,
  cell_type = "lstm",
  output_last_state = TRUE,
  config = "one-to-one"
)

Inference with most likely term

Here the predictions are performed by picking the character whose associated probability is the highest.

# Restore the parameters saved under iteration 1.
model <- mx.model.load(prefix = "../models/model_one_to_one_lstm_cpu", iteration = 1)

# Number of characters generated after the first one predicted from the seed.
n_steps <- 100

# Preallocated holder for the predicted (0-based) character indices:
# one from the seed pass plus one per generation step.
predict <- numeric(n_steps + 1)
data <- mx.nd.array(matrix(infer_split))

# Seed pass: run the whole seed text through the fully unrolled graph with no
# initial state.
infer <- mx.infer.rnn.one.unroll(infer.data = data,
                                 symbol = symbol.infer.ini,
                                 num_hidden = 96,
                                 arg.params = model$arg.params,
                                 aux.params = model$aux.params,
                                 init_states = NULL,
                                 ctx = ctx)

# Keep only the output for the last seed position and take its most likely
# class as the first generated character.
pred_prob <- mx.nd.slice.axis(infer[[1]], axis = 0,
                              begin = infer_length - 1, end = infer_length)
pred <- mx.nd.argmax(data = pred_prob, axis = 1, keepdims = TRUE)
predict[1] <- as.numeric(as.array(pred))

for (i in seq_len(n_steps)) {

  # Feed the previous prediction back in; every element of `infer` except the
  # first (the output) is passed on as the initial state of the next step.
  infer <- mx.infer.rnn.one.unroll(infer.data = pred,
                                   symbol = symbol.infer,
                                   num_hidden = 96,
                                   arg.params = model$arg.params,
                                   aux.params = model$aux.params,
                                   init_states = infer[-1],
                                   ctx = ctx)

  pred <- mx.nd.argmax(data = infer$loss_output, axis = 1, keepdims = TRUE)
  predict[i + 1] <- as.numeric(as.array(pred))

}

# Map the indices back to characters and append them to the seed text.
predict_txt <- paste0(rev_dic[as.character(predict)], collapse = "")
predict_txt_tot <- paste0(infer_raw, predict_txt, collapse = "")

Generated sequence: The United States are the street the problem that the problem that the problem that the problem that the problem that the

Key ideas appear somewhat overemphasized.

Inference from random sample

Noise is now inserted in the predictions by sampling each character according to its modeled probability.

# Fix the RNG so the sampled text is reproducible.
set.seed(44)

# Seed text, mapped to dictionary entries character by character.
infer_raw <- "The United States are"
infer_split <- dic[unlist(strsplit(infer_raw, ""))]
infer_length <- length(infer_split)

# Number of characters generated after the first one sampled from the seed.
n_steps <- 100

# Preallocated holder for the sampled (0-based) character indices.
predict <- numeric(n_steps + 1)

# Rebuild the input from this section's own `infer_split`, so that editing
# `infer_raw` here cannot leave the network primed on a stale sequence whose
# length no longer matches `infer_length`.
data <- mx.nd.array(matrix(infer_split))

# Seed pass: run the whole seed text through the fully unrolled graph with no
# initial state.
infer <- mx.infer.rnn.one.unroll(infer.data = data,
                                 symbol = symbol.infer.ini,
                                 num_hidden = 96,
                                 arg.params = model$arg.params,
                                 aux.params = model$aux.params,
                                 init_states = NULL,
                                 ctx = ctx)

# Probabilities for the last seed position; draw one class from them.
# sample() returns a 1-based position, hence the "- 1" to get a 0-based index.
pred_prob <- as.numeric(as.array(mx.nd.slice.axis(
  infer[[1]], axis = 0, begin = infer_length - 1, end = infer_length)))
pred <- sample(length(pred_prob), prob = pred_prob, size = 1) - 1
predict[1] <- pred

for (i in seq_len(n_steps)) {

  # Feed the previously sampled index back in as a 1 x 1 input; every element
  # of `infer` except the first (the output) becomes the next initial state.
  infer <- mx.infer.rnn.one.unroll(infer.data = mx.nd.array(array(pred, dim = c(1, 1))),
                                   symbol = symbol.infer,
                                   num_hidden = 96,
                                   arg.params = model$arg.params,
                                   aux.params = model$aux.params,
                                   init_states = infer[-1],
                                   ctx = ctx)

  pred_prob <- as.numeric(as.array(infer[[1]]))
  pred <- sample(length(pred_prob), prob = pred_prob, size = 1, replace = TRUE) - 1
  predict[i + 1] <- pred
}

# Map the indices back to characters and append them to the seed text.
predict_txt <- paste0(rev_dic[as.character(predict)], collapse = "")
predict_txt_tot <- paste0(infer_raw, predict_txt, collapse = "")

Generated sequence: The United States are not all the politics and accusies wutter of our faith. Let willing idea that improve maving with the

Now we get a more convoluted political speech.