# HW3 Deep Learning in Genomics
'''
Task:
1) Train an Autoencoder to learn the distribution of RNAseq counts
    i) Use latent embedding sizes of 5, 10, 50, and 100. This is the size of the output of the encoder, and the input to the decoder.
        Also, follow the implementation notes in the assignment.py comments.

    ii) Report and compare their reconstruction Mean Squared Error (MSE). How does the size of the latent space affect the reconstruction MSE?

2) Analyze the learned encodings
    We have provided code to produce a PCA & t-SNE plot of the raw RNAseq counts
    i) Adapt the provided code to make PCA & t-SNE plots of the autoencoder reconstructions

    ii) Compare and report the plots of the original data with the plots of the reconstructions.
        How do the PCA & t-SNE plots of the reconstructions compare with those of the original data?

    iii) Take the trained encoders of the previous autoencoders and make PCA & t-SNE plots of the latent embedding vectors.
         Remember to make and submit plots for all 4 autoencoders (latent embedding sizes - 5, 10, 50, and 100).
         How does the size of the latent space affect the latent vector plots?

    iv) Compare and report the plots of the original data and the plots of the latent vectors.
        How do the various embedding sizes change the quality of clustering?


Bonus:
Convert your auto-encoder to a de-noising auto-encoder. (Reconstructions are still trained to be the un-noised input).
i) Add Gaussian noise to the input data and train the denoising autoencoder to remove the added artificial noise.
ii) Add negative binomial distribution (zero-inflated negative binomial) noise to the input data and train the denoising autoencoder to remove the added artificial noise.
iii) For each of the above, compare and report the PCA plots of the original data and the plots of the latent vectors.
     Is the de-noising autoencoder able to remove the artificial noise in your data?

Food for thought (Not even a bonus):
    Why do you think the PCA & t-SNE plots vary so much in how useful of a clustering they provide over
    the data representations you have worked with in this assignment?
'''
import os
import sys
import numpy as np
import argparse

import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

# import torch
# from torch import nn
# from torch.utils.data import DataLoader

# import tensorflow as tf
# from tensorflow import keras


'''
Implementation Notes:
    Make sure to use a couple dropout layers to help prevent overfitting on this small dataset. No validation set is necessary for this task.
    For AE's, a good rule of thumb is to have the Encoder's architecture be symmetric to the Decoder's. (Ex. Enc=1000->100->10, Dec=10->100->1000)
    This assignment was tested with ~200 epochs. Your network may take more or less to produce some good clusters.

    For best results, you need to mask the outputs of the autoencoder that correspond to 0's (dropped counts) in the input.
    The autoencoder learns best when it is not forced to also learn to reconstruct the zeros in the input. This technique is known as gradient masking,
    and is commonly used in situations like this to prevent a training signal to reconstruct unimportant data.

    To achieve this, the reconstruction output of the autoencoder must be multiplied by a mask that zeros the corresponding
    dropped counts of the inputs, which ensures the gradient from that particular output is zero. Use MSE loss, and only average over
    the number of non-zero elements in the batch.
    You have been provided with a helper method in order to do so that you may adapt or use as is.

    Feel free to experiment with the size of the network, hyperparameters, and plotting parameters. More training information is better information!
'''

# Terminology Note: Latent vector = Latent Representation = Encoding = Output of Encoder

# Command-line hyperparameters. Defaults match the assignment's tested setup
# (~200 epochs; -latent_size should be run at 5, 10, 50, and 100).
parser = argparse.ArgumentParser()
for flag, arg_type, default in (
    ('-epochs', int, 200),
    ('-batch_size', int, 128),
    ('-lr', float, 1e-4),
    ('-latent_size', int, 10),
):
    parser.add_argument(flag, type=arg_type, default=default)
args = parser.parse_args()

# Unpack into module-level names used by the training code below.
num_epochs = args.epochs
batch_size = args.batch_size
lr = args.lr
latent_size = args.latent_size


# Helper method as an example to return the output mask
# Helper method as an example to return the output mask
def return_out_mask(batch_data: np.ndarray) -> np.ndarray:
    '''
    Build a gradient mask for dropped (zero) counts.

    Input: NxM np.ndarray of counts data
    Output: NxM np.ndarray (same dtype as the input) filled with 1
            corresponding to non-zero inputs and 0 corresponding to
            dropped (zero) inputs
    '''
    # (batch_data != 0) is the boolean mask; cast back to the input dtype so
    # the mask multiplies cleanly against the autoencoder's reconstruction,
    # zeroing the gradient for dropped counts.
    return (batch_data != 0).astype(batch_data.dtype)


# Load Data
# counts.npy  : NxM array of RNAseq counts (N samples x M genes) — the
#               autoencoder's input. Shape assumed from usage below; TODO confirm.
# labels.txt  : one numeric label per sample, used only to color the
#               PCA / t-SNE scatter plots (passed as `c=labels`).
# NOTE(review): both files are expected in the current working directory.
dataset = np.load('counts.npy')
labels = np.loadtxt("labels.txt")

# Pytorch Dataloader
# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


# Define Model, Optimizer, and Loss



# Training Loop




# Save model weights (Optional)



# Obtain Encoded Representations & Reconstructions
#   Run just the trained encoder across dataset to obtain latent vectors



# T-SNE & PCA Plot of Counts Data
def _scatter_save(points_2d, out_path):
    # Scatter a 2D projection colored by cluster label, write it to disk,
    # and close the figure so consecutive plots don't bleed into each other.
    plt.scatter(points_2d[:, 0], points_2d[:, 1], c=labels, s=3)
    plt.savefig(out_path)
    plt.close()


_scatter_save(TSNE(n_components=2).fit_transform(dataset), 'data_tsne.png')
_scatter_save(PCA(n_components=2).fit_transform(dataset), 'data_pca.png')

# T-SNE & PCA Plot of Encodings
# TODO(student): run the trained encoder over the full dataset and assign the
# resulting latent vectors here as an np.array of shape (N, latent_size).
# NOTE(review): the original template read `encoded_data = ?`, which is a
# syntax error and prevented the entire file from being parsed; `...`
# (Ellipsis) keeps an obvious placeholder while leaving the file importable.
encoded_data = ...  # np.array

tsne_latent = TSNE(n_components=2).fit_transform(encoded_data)
plt.scatter(tsne_latent[:,0],tsne_latent[:,1],c=labels,s=3)
plt.savefig('encoded_tsne.png')
plt.close()

pca = PCA(n_components=2)
pca_data = pca.fit_transform(encoded_data)
plt.scatter(pca_data[:,0],pca_data[:,1],c=labels,s=3)
plt.savefig('encoded_pca.png')
plt.close()

# T-SNE & PCA Plot of Reconstructions
# TODO(student): run the full autoencoder (encoder + decoder) over the dataset
# and assign the reconstructions here as an np.array of shape (N, M).
# NOTE(review): the original template read `reconstructions = ?`, which is a
# syntax error and prevented the entire file from being parsed; `...`
# (Ellipsis) keeps an obvious placeholder while leaving the file importable.
reconstructions = ...  # np.array

tsne_latent = TSNE(n_components=2).fit_transform(reconstructions)
plt.scatter(tsne_latent[:,0],tsne_latent[:,1],c=labels,s=3)
plt.savefig('reconstructed_tsne.png')
plt.close()

pca = PCA(n_components=2)
pca_data = pca.fit_transform(reconstructions)
plt.scatter(pca_data[:,0],pca_data[:,1],c=labels,s=3)
plt.savefig('reconstructed_pca.png')
plt.close()
