This tutorial presents an example of a one-to-one RNN model applied to text generation using the MXNet R package.
Example based on Obama’s speech.
Load some packages
library("readr")
library("dplyr")
library("plotly")
library("stringr")
library("stringi")
library("scales")
library("mxnet")
Data preparation is performed by the script `data_preprocessing_one_to_one.R`.
The following steps are executed:
# Load the pre-bucketed training and evaluation corpora from disk.
corpus_bucketed_train <- readRDS("../data/train_buckets_one_to_one.rds")
corpus_bucketed_test <- readRDS("../data/eval_buckets_one_to_one.rds")

# Vocabulary size: number of entries in the evaluation set's dictionary.
vocab <- length(corpus_bucketed_test$dic)
### Create iterators
batch.size <- 32

# Training iterator: batches are shuffled.
# NOTE(review): data.mask.element = 0 presumably marks padded positions;
# confirm against the preprocessing script.
train.data <- mx.io.bucket.iter(
  buckets = corpus_bucketed_train$buckets,
  batch.size = batch.size,
  data.mask.element = 0,
  shuffle = TRUE
)

# Evaluation iterator: same settings, but without shuffling.
eval.data <- mx.io.bucket.iter(
  buckets = corpus_bucketed_test$buckets,
  batch.size = batch.size,
  data.mask.element = 0,
  shuffle = FALSE
)
A one-to-one model configuration is specified since, for each character, we want to predict the next one. For a sequence of length 100, there are also 100 labels, corresponding to the same sequence of characters but offset by a position of +1.
# Symbolic graph used for training: a single-layer LSTM in the "one-to-one"
# configuration with a softmax loss. Input and output sizes both equal the
# vocabulary size; the last cell state is not exposed as an output.
rnn_graph <- rnn.graph.unroll(
  seq_len = 2,
  num_rnn_layer = 1,
  num_hidden = 96,
  input_size = vocab,
  num_embed = 64,
  num_decode = vocab,
  masking = FALSE,
  loss_output = "softmax",
  dropout = 0.2,
  ignore_label = -1,
  cell_type = "lstm",
  output_last_state = FALSE,
  config = "one-to-one"
)
Unroll the RNN to the length of the input sequence.
# Device used for training.
ctx <- mx.cpu()

# Xavier weight initialisation (gaussian, averaged factor, magnitude 3).
initializer <- mx.init.Xavier(
  rnd_type = "gaussian",
  factor_type = "avg",
  magnitude = 3
)

# AdaDelta optimizer. Gradients are clipped at 5 and rescaled by the inverse
# of the batch size.
optimizer <- mx.opt.create(
  "adadelta",
  rho = 0.9,
  eps = 1e-5,
  wd = 1e-8,
  clip_gradient = 5,
  rescale.grad = 1 / batch.size
)

# The logger collects the metric reported by the epoch-end callback.
logger <- mx.metric.logger()
epoch.end.callback <- mx.callback.log.train.metric(period = 1, logger = logger)

# Batch-level logging callback with a period of 50.
batch.end.callback <- mx.callback.log.train.metric(period = 50)
# Build a custom evaluation metric object of class "mx.metric".
#
# Arguments:
#   name   Label reported together with the metric value.
#   feval  function(label, pred) returning the metric value for one update.
#
# Returns a list with three functions:
#   init()                      -> initial state c(0, 0)
#   update(label, pred, state)  -> state with the update count incremented by
#                                  one and the feval result added to the sum
#   get(state)                  -> list(name, value), where value is the
#                                  accumulated sum divided by the update count
mx.metric.custom_nd <- function(name, feval) {
  structure(
    list(
      init = function() {
        c(0, 0)
      },
      update = function(label, pred, state) {
        batch_value <- feval(label, pred)
        c(state[[1]] + 1, state[[2]] + batch_value)
      },
      get = function(state) {
        list(name = name, value = state[[2]] / state[[1]])
      }
    ),
    class = "mx.metric"
  )
}
# Perplexity metric: exp() of the mean negative log-probability that the
# prediction assigns to the true labels.
mx.metric.Perplexity <- mx.metric.custom_nd(
  "Perplexity",
  function(label, pred) {
    # Flatten the labels to one dimension before using them as indices.
    flat_label <- mx.nd.reshape(label, shape = -1)
    # Probability predicted for each true label (0-based label indices).
    true_probs <- as.array(mx.nd.choose.element.0index(pred, flat_label))
    # Floor probabilities at 1e-15 so log() never receives zero.
    neg_log_lik <- -sum(log(pmax(1e-15, true_probs))) / length(true_probs)
    exp(neg_log_lik)
  }
)
# Train the model for 5 rounds on the bucketed iterators, timing the whole run.
# The graph passed as `symbol` is `rnn_graph`, the training graph built earlier
# in this script. Batch-level logging is disabled (batch.end.callback = NULL);
# the per-epoch metric is recorded through epoch.end.callback.
system.time(
  model <- mx.model.buckets(
    symbol = rnn_graph,
    train.data = train.data,
    eval.data = eval.data,
    num.round = 5,
    ctx = ctx,
    verbose = TRUE,
    metric = mx.metric.Perplexity,
    initializer = initializer,
    optimizer = optimizer,
    batch.end.callback = NULL,
    epoch.end.callback = epoch.end.callback
  )
)
## user system elapsed
## 1447.64 945.27 459.35
# Save the trained model. The prefix and iteration number are the ones the
# inference section passes to mx.model.load().
mx.model.save(
  model,
  prefix = "../models/model_one_to_one_lstm_cpu",
  iteration = 1
)
Set up the inference data. The same preprocessing needs to be applied to the inference sequence, which is then converted into an inference data iterator.
The parameter `output_last_state` is set to `TRUE` in order to access the state of the RNN cells when performing inference.
# Inference runs on the CPU with a batch size of one.
ctx <- mx.cpu()
batch.size <- 1

# Lookup tables stored with the training buckets: `dic` is indexed by
# character below, and `rev_dic` is indexed by predicted index when decoding.
corpus_bucketed_train <- readRDS("../data/train_buckets_one_to_one.rds")
rev_dic <- corpus_bucketed_train$rev_dic
dic <- corpus_bucketed_train$dic

# Seed text, split into single characters and mapped through `dic`.
infer_raw <- "The United States are"
infer_split <- dic[unlist(strsplit(infer_raw, ""))]
infer_length <- length(infer_split)
# Inference graphs. Both must have the same architecture as the training graph
# whose parameters are loaded for inference; the training graph in this script
# uses a single RNN layer, so num_rnn_layer is 1 here as well.
# output_last_state = TRUE makes the RNN state available as an extra output so
# it can be passed on as init_states to the next inference step.

# Graph unrolled over the whole seed sequence, used for the first pass.
symbol.infer.ini <- rnn.graph.unroll(
  seq_len = infer_length,
  num_rnn_layer = 1,
  num_hidden = 96,
  input_size = vocab,
  num_embed = 64,
  num_decode = vocab,
  masking = FALSE,
  loss_output = "softmax",
  dropout = 0.2,
  ignore_label = -1,
  cell_type = "lstm",
  output_last_state = TRUE,
  config = "one-to-one"
)

# Single-step graph, used to generate one character at a time.
symbol.infer <- rnn.graph.unroll(
  seq_len = 1,
  num_rnn_layer = 1,
  num_hidden = 96,
  input_size = vocab,
  num_embed = 64,
  num_decode = vocab,
  masking = FALSE,
  loss_output = "softmax",
  dropout = 0.2,
  ignore_label = -1,
  cell_type = "lstm",
  output_last_state = TRUE,
  config = "one-to-one"
)
Here the predictions are performed by picking the character whose associated probability is the highest.
# Restore the trained parameters saved by the training section.
model <- mx.model.load(
  prefix = "../models/model_one_to_one_lstm_cpu",
  iteration = 1
)

# Indices of the generated characters, appended one at a time.
predict <- numeric()

# First pass: feed the whole seed sequence through the unrolled graph.
data <- mx.nd.array(matrix(infer_split))
infer <- mx.infer.rnn.one.unroll(
  infer.data = data,
  symbol = symbol.infer.ini,
  num_hidden = 96,
  arg.params = model$arg.params,
  aux.params = model$aux.params,
  init_states = NULL,
  ctx = ctx
)

# Keep only the output for the last seed position and take its arg max.
pred_prob <- mx.nd.slice.axis(
  infer[[1]],
  axis = 0,
  begin = infer_length - 1,
  end = infer_length
)
pred <- mx.nd.argmax(data = pred_prob, axis = 1, keepdims = TRUE)
predict <- c(predict, as.numeric(as.array(pred)))

# Generate 100 more characters, one step at a time. Each step receives the
# previous prediction as input and every output but the first as init_states.
for (i in seq_len(100)) {
  infer <- mx.infer.rnn.one.unroll(
    infer.data = pred,
    symbol = symbol.infer,
    num_hidden = 96,
    arg.params = model$arg.params,
    aux.params = model$aux.params,
    init_states = infer[-1],
    ctx = ctx
  )
  pred <- mx.nd.argmax(data = infer$loss_output, axis = 1, keepdims = TRUE)
  predict <- c(predict, as.numeric(as.array(pred)))
}

# Map the predicted indices back through `rev_dic` and append to the seed.
predict_txt <- paste0(rev_dic[as.character(predict)], collapse = "")
predict_txt_tot <- paste0(infer_raw, predict_txt, collapse = "")
Generated sequence: The United States are the street the problem that the problem that the problem that the problem that the problem that the
Key ideas appear somewhat overemphasized.
Noise is now inserted in the predictions by sampling each character according to its modeled probability.
# Fix the RNG seed so the sampled text is reproducible.
set.seed(44)

# Seed text, split into single characters and mapped through `dic`.
infer_raw <- c("The United States are")
infer_split <- dic[strsplit(infer_raw, "") %>% unlist]
infer_length <- length(infer_split)

# Rebuild the input array from the freshly computed `infer_split`, so this
# section does not depend on a `data` object left over from earlier code.
data <- mx.nd.array(matrix(infer_split))

# Indices of the generated characters, appended one at a time.
predict <- numeric()

# First pass: feed the whole seed sequence through the unrolled graph.
infer <- mx.infer.rnn.one.unroll(
  infer.data = data,
  symbol = symbol.infer.ini,
  num_hidden = 96,
  arg.params = model$arg.params,
  aux.params = model$aux.params,
  init_states = NULL,
  ctx = ctx
)

# Output for the last seed position, as a plain numeric vector.
pred_prob <- as.numeric(as.array(mx.nd.slice.axis(
  infer[[1]],
  axis = 0,
  begin = infer_length - 1,
  end = infer_length
)))

# Draw the next character index in proportion to its probability.
# sample() returns a 1-based position, so subtract 1 to get the 0-based index.
pred <- sample(length(pred_prob), prob = pred_prob, size = 1) - 1
predict <- c(predict, pred)

# Generate 100 more characters, one step at a time. Each step receives the
# previous sample as a 1 x 1 array and every output but the first as
# init_states.
for (i in seq_len(100)) {
  infer <- mx.infer.rnn.one.unroll(
    infer.data = mx.nd.array(array(pred, dim = c(1, 1))),
    symbol = symbol.infer,
    num_hidden = 96,
    arg.params = model$arg.params,
    aux.params = model$aux.params,
    init_states = infer[-1],
    ctx = ctx
  )
  pred_prob <- as.numeric(as.array(infer[[1]]))
  pred <- sample(
    length(pred_prob),
    prob = pred_prob,
    size = 1,
    replace = TRUE
  ) - 1
  predict <- c(predict, pred)
}

# Map the sampled indices back through `rev_dic` and append to the seed.
predict_txt <- paste0(rev_dic[as.character(predict)], collapse = "")
predict_txt_tot <- paste0(infer_raw, predict_txt, collapse = "")
Generated sequence: The United States are not all the politics and accusies wutter of our faith. Let willing idea that improve maving with the
Now we get a more alembicated political speech.