Seq2Seq
Published on April 19, 2020
TensorFlow Addons
バグがある? https://github.com/tensorflow/tensorflow/issues/20067
import tensorflow as tf
# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
# Looping over the detected devices (rather than indexing [0]) avoids an
# IndexError on machines without a GPU and applies the same setting to every
# GPU, which TensorFlow requires when more than one is present.
physical_devices = tf.config.list_physical_devices('GPU')
for gpu_device in physical_devices:
    tf.config.experimental.set_memory_growth(gpu_device, enable=True)
Example
#ENCODER
class EncoderNetwork(tf.keras.Model):
    """Container for the encoder layers: a token embedding and an LSTM.

    No ``call`` method is defined; the layers are exposed as attributes
    (``encoder_embedding`` and ``encoder_rnnlayer``) and are meant to be
    invoked individually.
    """

    def __init__(self, input_vocab_size, embedding_dims, rnn_units):
        """Create the encoder layers.

        Args:
            input_vocab_size: Used as ``input_dim`` of the embedding layer.
            embedding_dims: Used as ``output_dim`` of the embedding layer.
            rnn_units: Number of units of the LSTM layer.
        """
        super().__init__()

        # Maps token ids to dense vectors.
        self.encoder_embedding = tf.keras.layers.Embedding(
            input_dim=input_vocab_size,
            output_dim=embedding_dims,
        )

        # Configured to return the full output sequence as well as its
        # final states.
        self.encoder_rnnlayer = tf.keras.layers.LSTM(
            rnn_units,
            return_sequences=True,
            return_state=True,
        )
#DECODER
class DecoderNetwork(tf.keras.Model):
    """Attention-based decoder assembled from ``tfa.seq2seq`` components.

    In addition to its constructor arguments, this class reads the names
    ``dense_units``, ``BATCH_SIZE`` and ``Tx`` from the surrounding module
    scope and uses the ``tfa`` module alias; all of them must be defined
    before the class is instantiated.
    """

    def __init__(self, output_vocab_size, embedding_dims, rnn_units):
        """Create the decoder layers, attention mechanism and decoder.

        Args:
            output_vocab_size: Used as ``input_dim`` of the embedding layer
                and as the number of units of the output dense layer.
            embedding_dims: Used as ``output_dim`` of the embedding layer.
            rnn_units: Number of units of the LSTM cell.
        """
        super().__init__()

        self.decoder_embedding = tf.keras.layers.Embedding(
            input_dim=output_vocab_size,
            output_dim=embedding_dims,
        )
        # Applied by the decoder as its output layer (see ``self.decoder``).
        self.dense_layer = tf.keras.layers.Dense(output_vocab_size)
        self.decoder_rnncell = tf.keras.layers.LSTMCell(rnn_units)

        # Sampler handed to the BasicDecoder below.
        self.sampler = tfa.seq2seq.sampler.TrainingSampler()

        # The attention mechanism is created with memory=None.
        encoder_lengths = [Tx] * BATCH_SIZE
        self.attention_mechanism = self.build_attention_mechanism(
            dense_units, None, encoder_lengths
        )

        # LSTM cell wrapped with attention, then the decoder built around it.
        self.rnn_cell = self.build_rnn_cell(BATCH_SIZE)
        self.decoder = tfa.seq2seq.BasicDecoder(
            self.rnn_cell,
            sampler=self.sampler,
            output_layer=self.dense_layer,
        )

    def build_attention_mechanism(self, units, memory, memory_sequence_length):
        """Return a ``tfa.seq2seq.LuongAttention`` built from the arguments.

        Args:
            units: Forwarded as the first positional argument.
            memory: Forwarded as ``memory``.
            memory_sequence_length: Forwarded as ``memory_sequence_length``.
        """
        # Alternative noted by the original author:
        # tfa.seq2seq.BahdanauAttention with the same arguments.
        return tfa.seq2seq.LuongAttention(
            units,
            memory=memory,
            memory_sequence_length=memory_sequence_length,
        )

    def build_rnn_cell(self, batch_size):
        """Wrap ``self.decoder_rnncell`` in a ``tfa.seq2seq.AttentionWrapper``.

        Args:
            batch_size: Accepted but not used by this method.
        """
        return tfa.seq2seq.AttentionWrapper(
            self.decoder_rnncell,
            self.attention_mechanism,
            attention_layer_size=dense_units,
        )

    def build_decoder_initial_state(self, batch_size, encoder_state, Dtype):
        """Build the decoder's initial state from the encoder's final state.

        Args:
            batch_size: Forwarded to ``self.rnn_cell.get_initial_state``.
            encoder_state: Installed as ``cell_state`` of the returned state.
            Dtype: Forwarded as ``dtype`` to ``get_initial_state``.

        Returns:
            The wrapper's initial state with ``cell_state`` replaced by
            ``encoder_state``.
        """
        wrapper_state = self.rnn_cell.get_initial_state(
            batch_size=batch_size, dtype=Dtype
        )
        return wrapper_state.clone(cell_state=encoder_state)
# Instantiate both halves of the model; they share the embedding size and
# the number of recurrent units.
encoderNetwork = EncoderNetwork(
    input_vocab_size=input_vocab_size,
    embedding_dims=embedding_dims,
    rnn_units=rnn_units,
)
decoderNetwork = DecoderNetwork(
    output_vocab_size=output_vocab_size,
    embedding_dims=embedding_dims,
    rnn_units=rnn_units,
)
def loss_function(y_pred, y):
    """Return the cross-entropy averaged over non-padding target tokens.

    Positions where ``y == 0`` are treated as padding and contribute nothing
    to the loss. The sum of the remaining per-token losses is divided by the
    number of non-padding tokens, so the result does not shrink as the amount
    of padding in a batch grows (averaging over every position, padded or
    not, would dilute it).

    Args:
        y_pred: Logits. Per the original author's note the shape is
            [batch_size, Ty, output_vocab_size].
        y: Integer target ids. Per the original author's note the shape is
            [batch_size, Ty]; id 0 marks padding.

    Returns:
        A scalar tensor: the mean loss per non-padding token, or 0 when the
        batch contains no non-padding token.
    """
    cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction='none'
    )
    # reduction='none' keeps one loss value per target position.
    per_token_loss = cross_entropy(y_true=y, y_pred=y_pred)

    # 1.0 where the target is a real token, 0.0 where it is padding.
    mask = tf.cast(tf.math.not_equal(y, 0), dtype=per_token_loss.dtype)
    masked_loss = per_token_loss * mask

    # divide_no_nan yields 0 instead of NaN if every position is padding.
    return tf.math.divide_no_nan(
        tf.reduce_sum(masked_loss), tf.reduce_sum(mask)
    )
def train_step(input_batch, output_batch, encoder_initial_cell_state):
    """Run one optimisation step on a single batch and return its loss.

    Relies on the module-level names ``encoderNetwork``, ``decoderNetwork``,
    ``optimizer``, ``loss_function``, ``BATCH_SIZE`` and ``Ty``.

    Args:
        input_batch: Batch fed to the encoder embedding.
        output_batch: Target batch; all but the last column are the decoder
            inputs and all but the first column are the expected outputs.
        encoder_initial_cell_state: Passed as ``initial_state`` to the
            encoder LSTM.

    Returns:
        The value returned by ``loss_function`` for this batch.
    """
    with tf.GradientTape() as tape:
        # ---- Encoder ----
        source_embedded = encoderNetwork.encoder_embedding(input_batch)
        encoder_outputs, final_h, final_c = encoderNetwork.encoder_rnnlayer(
            source_embedded, initial_state=encoder_initial_cell_state
        )

        # ---- Decoder inputs / targets (shifted by one step) ----
        decoder_input = output_batch[:, :-1]   # drop the last column (<end>)
        decoder_target = output_batch[:, 1:]   # drop the first column (<start>)
        target_embedded = decoderNetwork.decoder_embedding(decoder_input)

        # Give the attention mechanism the encoder outputs as its memory,
        # then start the decoder from the encoder's final states.
        decoderNetwork.attention_mechanism.setup_memory(encoder_outputs)
        decoder_initial_state = decoderNetwork.build_decoder_initial_state(
            BATCH_SIZE,
            encoder_state=[final_h, final_c],
            Dtype=tf.float32,
        )

        # ---- Decode and score ----
        decoder_outputs, _, _ = decoderNetwork.decoder(
            target_embedded,
            initial_state=decoder_initial_state,
            sequence_length=[Ty - 1] * BATCH_SIZE,
        )
        loss = loss_function(decoder_outputs.rnn_output, decoder_target)

    # Differentiate the loss with respect to every trainable weight of both
    # networks and apply the update.
    variables = (
        encoderNetwork.trainable_variables + decoderNetwork.trainable_variables
    )
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    return loss
#RNN LSTM hidden and memory state initializer
def initialize_initial_state():
    """Return two zero tensors of shape (BATCH_SIZE, rnn_units) in a list.

    ``BATCH_SIZE`` and ``rnn_units`` are read from module scope.
    """
    state_shape = (BATCH_SIZE, rnn_units)
    return [tf.zeros(state_shape) for _ in range(2)]
# Training loop: a fixed number of epochs with a progress line every 20th batch.
epochs = 15
for epoch_number in range(1, epochs + 1):
    # A fresh zero state is created at the start of every epoch.
    encoder_initial_cell_state = initialize_initial_state()
    total_loss = 0.0

    for batch, (input_batch, output_batch) in enumerate(dataset.take(steps_per_epoch)):
        batch_loss = train_step(input_batch, output_batch, encoder_initial_cell_state)
        total_loss += batch_loss

        batch_number = batch + 1
        if batch_number % 20 == 0:
            # The label reads "total loss" but the value is this batch's loss.
            print("total loss: {} epoch {} batch {} ".format(
                batch_loss.numpy(), epoch_number, batch_number))

    # NOTE(review): the original indentation was lost, so the intended
    # placement of this save is unclear; it is written here once per epoch.
    # Confirm against the original script.
    checkpoint.save(file_prefix=chkpoint_prefix)
If you like it, share it!