#!/usr/bin/env python3
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from sklearn.metrics import mean_squared_error

# Daniel Lersch (dlersch@jlab.org)
# Simple script to train a variational autoencoder to generate y=x^2 data.
# Note: Most of the lines here have been taken from the TensorFlow / Keras
# tutorials, see details here:
# Tensorflow: https://www.tensorflow.org/tutorials/generative/cvae?hl=en
# Keras: https://keras.io/examples/generative/vae/

print(" ")
print("Generate y=x^2 Data with a Variational Autoencoder Neural Network")
print(" ")

K = tf.keras

# Seed NumPy (used for the data generation below) AND TensorFlow (weight
# initialisation and the latent sampling in get_vae_encoder_response).
# np.random.seed alone does not affect TensorFlow's random ops.
np.random.seed(123)
tf.random.set_seed(123)

# Create the data set:
print("Create y=x^2 data...")
n_channels = 50
x = np.linspace(-5.0, 5.0, n_channels)
n_events = 100
rel_noise = 0.05

# One row per event: the curve sampled at the n_channels x-positions,
# multiplied element-wise by Gaussian noise of relative width rel_noise.
sig_data = (np.ones((n_events, 1)) * x * x) * np.random.normal(1.0, rel_noise, (n_events, n_channels))
# y = 0.1*x^3 curves with the same noise model; not used further in this script.
bkg_data = (np.ones((n_events, 1)) * x * x * x * 0.1) * np.random.normal(1.0, rel_noise, (n_events, n_channels))
print("...done!")
print(" ")

# Set up the ae-model:
print("Set up variational autoencoder model and define helper functions...")
latent_dim = 4
n_training_epochs = 300
coding_activation = 'linear'
output_activation = 'linear'
layer_activation = 'relu'

# Encoder: maps one event (n_channels values) to 2 * latent_dim numbers,
# which get_vae_encoder_response splits into latent mean and log-variance.
encoder = K.Sequential([
    K.layers.Input(shape=(n_channels,)),
    K.layers.Dense(100, activation=layer_activation),
    K.layers.Dense(50, activation=layer_activation),
    K.layers.Dense(25, activation=layer_activation),
    K.layers.Dense(2 * latent_dim, activation=coding_activation),
])

# Decoder: mirrors the encoder, mapping a latent vector back to n_channels values.
decoder = K.Sequential([
    K.layers.Input(shape=(latent_dim,)),
    K.layers.Dense(25, activation=layer_activation),
    K.layers.Dense(50, activation=layer_activation),
    K.layers.Dense(100, activation=layer_activation),
    K.layers.Dense(n_channels, activation=output_activation),
])

vae_input = K.layers.Input(shape=(n_channels,))


def get_vae_encoder_response(x):
    """Encode ``x`` and draw one latent sample per row (reparameterisation trick).

    The encoder output (2 * latent_dim values per row) is split along axis 1
    into the latent mean and the latent log-variance log(sigma^2). The
    log-variance is used instead of sigma for numerical stability. The sample
    is ``mean + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)``, which keeps
    it differentiable with respect to mean and logvar.

    Returns:
        [mean, logvar, variation], where ``variation`` is the sampled latent vector.
    """
    mean, logvar = tf.split(encoder(x), num_or_size_splits=2, axis=1)
    variation = tf.random.normal(shape=tf.shape(mean)) * tf.exp(logvar * .5) + mean
    return [mean, logvar, variation]


# End-to-end model. In this script it serves to collect the trainable
# variables of encoder + decoder and to hold the optimizer; the loss itself
# is computed by vae_loss() below.
_, _, variation = get_vae_encoder_response(vae_input)
vae_output = decoder(variation)
vae_model = K.Model(vae_input, vae_output)
vae_model.compile(optimizer=K.optimizers.Adam(1e-2))

# Per-event mean squared error (no reduction over the batch axis).
mse = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)


def log_normal_pdf(sample, mean, logvar, raxis=1):
    """Return log N(sample | mean, exp(logvar)) of a diagonal Gaussian.

    The per-component log densities are summed over axis ``raxis``.
    Used to calculate the latent loss.
    """
    log2pi = tf.math.log(2. * np.pi)
    return tf.reduce_sum(
        -.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi),
        axis=raxis)


def vae_loss(data):
    """Compute the VAE loss for a batch: a single-sample negative ELBO estimate.

    loss = mean over the batch of [ MSE(data, decoded) + log q(z|x) - log p(z) ],
    where z is one reparameterised sample and p(z) = N(0, 1).
    log q - log p is a one-sample Monte Carlo estimate of KL(q(z|x) || p(z)).
    """
    z_mean, z_log_var, coding = get_vae_encoder_response(data)
    # log p(z) - log q(z|x) for the drawn sample (prior: mean 0, logvar 0).
    latent_loss = log_normal_pdf(coding, 0., 0.) - log_normal_pdf(coding, z_mean, z_log_var)
    decoder_out = decoder(coding)
    rec_loss = -mse(data, decoder_out)
    return -tf.reduce_mean(rec_loss + latent_loss)


# Running mean of every loss value passed to update_state().
training_loss = K.metrics.Mean(name="training_loss")
# Update the weights, using a gradient...
@tf.function
def train_step(data):
    """Run one optimizer step of the VAE on ``data`` and return that step's loss.

    The loss is also fed into the ``training_loss`` metric. That metric is
    never reset, so it holds the mean over all steps so far.

    Returns:
        The scalar VAE loss of this step, evaluated before the weight update.
    """
    with tf.GradientTape() as tape:
        current_loss = vae_loss(data)
    gradients = tape.gradient(current_loss, vae_model.trainable_variables)
    vae_model.optimizer.apply_gradients(zip(gradients, vae_model.trainable_variables))
    training_loss.update_state(current_loss)
    # Return the step loss rather than training_loss.result(). The metric is
    # a cumulative mean over all epochs, so plotting it per epoch would show
    # a smoothed running average instead of the actual training loss.
    return current_loss


print("...done!")
print(" ")

# Train the variational autoencoder (full batch: one step per epoch).
# Convert the float64 NumPy data once to float32 (Keras' default float dtype)
# instead of letting every call cast it implicitly.
sig_data_tensor = tf.convert_to_tensor(sig_data, dtype=tf.float32)
vae_training_curve = []
#++++++++++++++++++++++++++++
for epoch in range(1, n_training_epochs + 1):
    tr_err = train_step(sig_data_tensor)
    vae_training_curve.append(float(tr_err))
#++++++++++++++++++++++++++++

# Generate new events by decoding draws from the N(0, 1) latent prior.
random_noise = tf.random.normal(shape=(n_events, latent_dim))
gen_data = decoder(random_noise).numpy()

print("Visualize results...")

plt.rcParams.update({'font.size': 30})

fig_t, ax_t = plt.subplots(figsize=(10, 8))
ax_t.plot(vae_training_curve, '-', linewidth=2.0)
ax_t.set_xlabel('Epoch')
ax_t.set_ylabel('Loss')
fig_t.savefig('vae_training_curve.png')
plt.close(fig_t)

fig_res, ax_res = plt.subplots(1, 2, figsize=(20, 8))

# Histogram every sampled latent value. Averaging over the latent dimensions
# first would shrink the spread by sqrt(latent_dim) and no longer show the
# N(0, 1) values the axis label refers to.
ax_res[0].hist(random_noise.numpy().ravel(), 10)
ax_res[0].set_xlabel('Random Normal Values')
ax_res[0].set_ylabel('Counts')

#++++++++++++++++++++++++++++++++++++
for ev, event in enumerate(sig_data):
    # Label only the first event so the legend shows a single 'Sig. Data' entry.
    ax_res[1].plot(x, event, 'ko', label='Sig. Data' if ev == 0 else '_nolegend_')
#++++++++++++++++++++++++++++++++++++

ax_res[1].plot(x, np.mean(gen_data, 0), "r-", linewidth=3.0, label='Avg. Gen. Data')
ax_res[1].set_xlabel('x')
ax_res[1].set_ylabel('y')
ax_res[1].legend()
fig_res.savefig('vae_predictions.png')
plt.close(fig_res)

print("...done! Have a great day!")
print(" ")