# HW2 Deep Learning in Genomics
"""
Background: Hi-C data can be used to identify the 3D conformation of DNA in cells 
by highlighting which regions of DNA are located close together. The 3D conformation 
of DNA can be determined more accurately when higher resolution Hi-C data is available.

Data: The provided dataset was used by Zhang et al. for the development of HiCPlus.

Tasks:
    
1. Implement a convolutional neural network (CNN) to transform patches of 
   low-resolution Hi-C data into patches of high-resolution Hi-C data.
   
    a) Implement remove_borders() to crop the borders of the training labels so that they match the size of the model's output
   
    b) Implement train_model() to train your model using a cross-validation scheme 
       with a mean squared error loss function
      
    c) Determine the mean squared error and Pearson correlation of the final 
       model on the entire training set
      
2. For several matrix patches in the training set, visualize and compare the 
   training input, training label, and predicted label. 
   
    a) Implement make_prediction() to generate model predictions. If you are
       using Keras, you may find the model.predict() function helpful.
      
    b) Produce the visualizations. You may find matplotlib's
       imshow() function and "Reds" colormap useful for this part.
   
3. For several matrix patches in the test set, visualize and compare the 
   test input and predicted label (no high-resolution labels are provided for the test set).
   
    a) Use make_prediction() to generate model predictions.
    
    b) Produce the visualizations.
"""

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt

def _load_channels_last(path):
    """
    Load a .npy array of patches as float32 and move its channel axis last.

    param path: path to a .npy file holding a 4-D array shaped
                (samples, channels, rows, cols)
    return: a float32 array shaped (samples, rows, cols, channels)
    raises ValueError: if the stored array is not 4-D
    """
    patches = np.load(path).astype("float32")
    if patches.ndim != 4:
        raise ValueError(
            f"expected a 4-D (samples, channels, rows, cols) array in {path}, "
            f"got shape {patches.shape}"
        )
    # For a single channel this is identical to reshaping to
    # (samples, rows, cols, 1); unlike reshape it also keeps multi-channel
    # data correctly ordered instead of failing on a size mismatch.
    return np.moveaxis(patches, 1, -1)


def get_data(data_dir=".",
             train_x_file="GM12878_replicatedown16_chr19_22.npy",
             train_y_file="GM12878_replicate_original_chr19_22.npy",
             test_x_file="GM12878_replicate_down16_chr17_17.npy"):
    """
    Load the low-resolution training inputs, high-resolution training labels,
    and low-resolution test inputs, converted to float32 channels-last arrays.

    NOTE(review): the default training-input file name lacks the underscore
    used by the other two files ("replicatedown16" vs "replicate_down16").
    Confirm it matches the file on disk; pass train_x_file to override.

    param data_dir: directory containing the .npy files (default: current dir)
    param train_x_file: file name of the low-resolution training inputs
    param train_y_file: file name of the high-resolution training labels
    param test_x_file: file name of the low-resolution test inputs
    return: (train_x, train_y, test_x), each shaped (samples, rows, cols, 1)
            for the single-channel arrays stored as (samples, 1, rows, cols)
    """
    import os

    train_x, train_y, test_x = (
        _load_channels_last(os.path.join(data_dir, name))
        for name in (train_x_file, train_y_file, test_x_file)
    )
    return train_x, train_y, test_x

def remove_borders(train_y, border_size):
    """
    Crop border_size pixels from every spatial edge of the training labels.

    With 'valid' convolutions the model output is smaller than its input, so
    the labels must be cropped to the same size before computing the loss.
    For example, valid convolutions with kernel sizes 9, 1 and 5 shrink each
    side by 8 + 0 + 4 = 12 pixels, giving border_size = 6 (40x40 -> 28x28).

    param train_y: a numpy array of training labels with dimensions (samples, 40, 40, 1)
                   (any spatial size works)
    param border_size: the number of pixels to remove from each edge (>= 0)
    return: a numpy array of training labels with dimensions
            (samples, 40-(2*border_size), 40-(2*border_size), 1); this is a
            view into train_y, not a copy
    raises ValueError: if train_y is not 4-D, or border_size is negative or
                       would leave no pixels
    """
    if train_y.ndim != 4:
        raise ValueError(
            f"expected a 4-D (samples, rows, cols, channels) array, got shape {train_y.shape}"
        )
    rows, cols = train_y.shape[1], train_y.shape[2]
    if border_size < 0 or 2 * border_size >= min(rows, cols):
        raise ValueError(
            f"border_size={border_size} is invalid for {rows}x{cols} patches"
        )
    # Explicit end indices (rather than -border_size) so that border_size == 0
    # keeps the whole patch instead of producing an empty slice.
    return train_y[:, border_size:rows - border_size, border_size:cols - border_size, :]

def _build_model(input_shape, learning_rate):
    """
    Build and compile a fresh three-layer CNN with MSE loss.

    The layer sizes follow the HiCPlus network: 8 filters of 9x9, 8 filters
    of 1x1, then 1 filter of 5x5. All use 'valid' padding, so each spatial
    side shrinks by (9 - 1) + (1 - 1) + (5 - 1) = 12 pixels: a 40x40 input
    yields a 28x28 output, matching labels cropped with remove_borders(y, 6).

    param input_shape: per-sample input shape, e.g. (40, 40, 1)
    param learning_rate: Adam learning rate
    return: a compiled tf.keras.Sequential model
    """
    model = tf.keras.Sequential([
        tf.keras.Input(shape=input_shape),
        tf.keras.layers.Conv2D(8, 9, padding="valid", activation="relu"),
        tf.keras.layers.Conv2D(8, 1, padding="valid", activation="relu"),
        tf.keras.layers.Conv2D(1, 5, padding="valid"),
    ])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss="mse",
    )
    return model


def train_model(train_x, train_y, n_folds=5, epochs=10, batch_size=256,
                learning_rate=1e-3, seed=0, verbose=0):
    """
    Implements and trains the model using a cross-validation scheme with MSE loss

    The samples are shuffled and split into n_folds folds (k-fold
    cross-validation). For each fold a fresh model is trained on the other
    folds and scored on the held-out fold; per-fold and mean +/- std held-out
    MSE are printed as an estimate of generalization error. A final model is
    then trained on all of the data with the same settings and returned.

    param train_x: the training inputs, shaped (samples, rows, cols, 1)
    param train_y: the training labels with borders already removed, so their
                   per-sample shape equals the model's output shape
    param n_folds: number of cross-validation folds (2 <= n_folds <= samples)
    param epochs: training epochs for every fit (>= 1)
    param batch_size: mini-batch size
    param learning_rate: Adam learning rate
    param seed: seed for fold shuffling and TensorFlow weight initialization
    param verbose: Keras verbosity passed to fit()
    return: a trained tf.keras model fit on the entire training set
    raises ValueError: on mismatched sample counts, invalid n_folds/epochs,
                       or labels whose shape does not match the model output
    """
    n_samples = train_x.shape[0]
    if train_y.shape[0] != n_samples:
        raise ValueError(
            f"train_x has {n_samples} samples but train_y has {train_y.shape[0]}"
        )
    if not 2 <= n_folds <= n_samples:
        raise ValueError(f"n_folds must be in [2, {n_samples}], got {n_folds}")
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    tf.random.set_seed(seed)
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(n_samples), n_folds)
    input_shape = train_x.shape[1:]

    fold_mse = []
    for k, val_idx in enumerate(folds):
        fit_idx = np.concatenate([fold for j, fold in enumerate(folds) if j != k])
        model = _build_model(input_shape, learning_rate)
        # Fail fast (before any training) if the labels were not cropped to
        # the model's output size.
        if tuple(model.output_shape[1:]) != tuple(train_y.shape[1:]):
            raise ValueError(
                f"label shape {train_y.shape[1:]} does not match model output "
                f"shape {model.output_shape[1:]}; crop labels with remove_borders()"
            )
        history = model.fit(
            train_x[fit_idx], train_y[fit_idx],
            validation_data=(train_x[val_idx], train_y[val_idx]),
            epochs=epochs, batch_size=batch_size, verbose=verbose,
        )
        fold_mse.append(history.history["val_loss"][-1])
        print(f"Fold {k + 1}/{n_folds}: held-out MSE = {fold_mse[-1]:.6f}")

    print(f"Cross-validation MSE: {np.mean(fold_mse):.6f} +/- {np.std(fold_mse):.6f}")

    final_model = _build_model(input_shape, learning_rate)
    final_model.fit(train_x, train_y, epochs=epochs, batch_size=batch_size, verbose=verbose)
    return final_model

def make_prediction(model, input_data, batch_size=256):
    """
    Run a trained model on input patches.

    param model: a trained model exposing a Keras-style
                 predict(x, batch_size=..., verbose=...) method
    param input_data: model inputs, e.g. shaped (samples, rows, cols, 1);
                      converted to float32 to match the training data
    param batch_size: number of patches per inference batch
    return: the model's predictions for the provided input data, as returned
            by model.predict (a NumPy array for tf.keras models)
    """
    inputs = np.asarray(input_data, dtype="float32")
    # verbose=0 suppresses the per-call progress bar Keras prints by default.
    return model.predict(inputs, batch_size=batch_size, verbose=0)
        
def _top_signal_indices(patches, n):
    """
    Return indices of the n patches with the largest total signal.

    param patches: array shaped (samples, ...)
    param n: maximum number of indices to return
    return: integer index array, strongest patch first
    """
    totals = patches.reshape(patches.shape[0], -1).sum(axis=1)
    return np.argsort(totals)[::-1][:n]


def _plot_patches(title, panels, indices):
    """
    Draw a grid with one row per patch and one column per panel.

    param title: figure title
    param panels: list of (column name, array) pairs; each array is shaped
                  (len(indices), rows, cols, 1) and its row i is drawn in row i
    param indices: original patch indices, used only to label the rows
    return: the matplotlib Figure, or None when there is nothing to draw
    """
    n_rows, n_cols = len(indices), len(panels)
    if n_rows == 0 or n_cols == 0:
        return None
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False
    )
    fig.suptitle(title)
    for row, patch_idx in enumerate(indices):
        for col, (name, patches) in enumerate(panels):
            ax = axes[row, col]
            ax.imshow(patches[row, :, :, 0], cmap="Reds")
            ax.set_title(f"{name} (patch {patch_idx})")
            ax.axis("off")
    fig.tight_layout()
    return fig


def main():
    """Run the HW2 pipeline: load data, train, evaluate, and visualize patches."""
    # border_size must equal (input side - model output side) / 2 for the
    # network built by train_model. 6 corresponds to valid convolutions that
    # remove 12 pixels per side (e.g. 9x9, 1x1 and 5x5 kernels, as in HiCPlus).
    border_size = 6
    n_examples = 3

    # Task 1: read the data, crop the labels to the model output size, train.
    train_x, train_y, test_x = get_data()
    train_y = remove_borders(train_y, border_size)
    model = train_model(train_x, train_y)

    # Task 1c: MSE and Pearson correlation on the entire training set.
    train_pred = make_prediction(model, train_x)
    mse = float(np.mean((train_pred - train_y) ** 2))
    pearson = float(np.corrcoef(train_pred.ravel(), train_y.ravel())[0, 1])
    print(f"Training set MSE: {mse:.6f}")
    print(f"Training set Pearson correlation: {pearson:.4f}")

    # Task 2: training input vs. label vs. prediction. Patches with the
    # largest total label signal are shown so the examples are not near-empty.
    # Inputs are cropped like the labels so every panel covers the same region.
    train_idx = _top_signal_indices(train_y, n_examples)
    _plot_patches(
        "Training set",
        [
            ("Input", remove_borders(train_x[train_idx], border_size)),
            ("Label", train_y[train_idx]),
            ("Prediction", train_pred[train_idx]),
        ],
        train_idx,
    )

    # Task 3: test input vs. prediction (get_data provides no test labels).
    test_idx = _top_signal_indices(test_x, n_examples)
    test_pred = make_prediction(model, test_x[test_idx])
    _plot_patches(
        "Test set",
        [
            ("Input", remove_borders(test_x[test_idx], border_size)),
            ("Prediction", test_pred),
        ],
        test_idx,
    )

    plt.show()


if __name__ == '__main__':
    main()