Introduction to CREsted with Peak Regression#
In this introductory notebook, we will train a peak regression model on the mouse BICCN data and inspect the results to get a feel for the capabilities of the CREsted package.
Import Data#
For this tutorial, we will use the mouse BICCN dataset, which is available through the get_dataset() function.
To train a CREsted peak regression model on your data, you need:
A consensus regions BED file containing all the regions of interest across cell types.
A folder containing the bigwig files per cell type. Each file should be named according to the cell type: {cell type name}.bw.
A genome fasta file and optionally a chromosome sizes file.
You could use a tool like SnapATAC2 to generate the consensus regions and bigwig files from your own data.
from pathlib import Path
import numpy as np
%matplotlib inline
import matplotlib
from sklearn.metrics import pairwise
import anndata as ad
import keras
import crested
# Set the font type to ensure text is saved as whole words
matplotlib.rcParams["pdf.fonttype"] = 42 # Use TrueType fonts instead of Type 3 fonts
matplotlib.rcParams["ps.fonttype"] = 42 # For PostScript as well, if needed
Download the tutorial data.
For this tutorial we will train on the ‘cut sites’, but the ‘coverage’ data are also available (an older version of this tutorial trained on the coverage).
bigwigs_folder, regions_file = crested.get_dataset("mouse_cortex_bigwig_cut_sites")
Or if you have the data already available:
bigwigs_folder = "../../../../mouse/biccn/bigwigs/fragments_bws/"
regions_file = "../../../../mouse/biccn/consensus_peaks_inputs.bed"
chromsizes_file = "../../../../mouse/biccn/mm.chrom.sizes"
By loading our genome into the crested.Genome class and registering it with register_genome(), the genome is automatically used in all functions throughout CREsted. If you don’t provide the chromosome sizes, they are calculated automatically from the FASTA file.
Note
Any function or class that expects a genome object can still accept a genome object as explicit input even if one was already registered. In that case, the input will be used instead of the registered genome.
# Set the genome
genome_dir = Path("../../../../mouse/biccn/")
genome = crested.Genome(genome_dir / "mm10.fa", genome_dir / "mm10.chrom.sizes")
crested.register_genome(genome)
2024-12-12T10:18:24.298297+0100 INFO Genome mm10 registered.
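When no sizes file is given, the chromosome sizes come from the FASTA itself, and chromosome sizes can also be used to discard regions that run past a chromosome end. The plain-Python sketch below illustrates both ideas on a hypothetical toy FASTA. chrom_sizes_from_fasta and filter_regions are illustrative helpers, not CREsted functions, and CREsted's own parsing may differ (for example in how header lines are handled).

```python
import io


def chrom_sizes_from_fasta(handle):
    """Return {chromosome name: length} by counting bases in a FASTA stream.

    Sketch only: the name is the first whitespace-separated token of each
    '>' header; sequence lines are stripped and their lengths summed.
    """
    sizes = {}
    name = None
    for line in handle:
        line = line.strip()
        if not line:
            continue
        if line.startswith(">"):
            name = line[1:].split()[0]
            sizes[name] = 0
        elif name is not None:
            sizes[name] += len(line)
    return sizes


def filter_regions(regions, sizes):
    """Keep (chrom, start, end) regions that lie fully inside a known chromosome."""
    return [
        (chrom, start, end)
        for chrom, start, end in regions
        if chrom in sizes and 0 <= start < end <= sizes[chrom]
    ]


# Hypothetical toy genome, not the mm10 FASTA used in this tutorial
toy_fasta = io.StringIO(">chr1 toy\nACGTACGTAC\nGTAC\n>chr2\nNNNNAC\n")
sizes = chrom_sizes_from_fasta(toy_fasta)
print(sizes)  # {'chr1': 14, 'chr2': 6}
print(filter_regions([("chr1", 0, 10), ("chr2", 2, 8), ("chr3", 0, 5)], sizes))  # [('chr1', 0, 10)]
```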
We can use the import_bigwigs() function to import the bigwigs per cell type and a consensus regions BED file into an anndata.AnnData object, with the imported cell types as the AnnData.obs and the consensus peak regions as the AnnData.var.
Optionally, provide a chromsizes file to filter out regions that fall outside the chromosome boundaries.
adata = crested.import_bigwigs(
bigwigs_folder=bigwigs_folder,
regions_file=regions_file,
target_region_width=1000, # optionally, use a different width than the consensus regions file (500bp) for the .X values calculation
target="count", # or "max", "mean", "logcount" --> what we will be predicting
)
adata
2024-12-12T10:18:30.706855+0100 INFO Extracting values from 19 bigWig files...
AnnData object with n_obs × n_vars = 19 × 546993
obs: 'file_path'
var: 'chr', 'start', 'end'
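The target argument decides how the per-base bigWig signal inside each (optionally resized) region is collapsed into a single value in adata.X. Below is a minimal plain-Python sketch, assuming 'count' is the sum of the signal, 'mean' its average, 'max' its maximum, and 'logcount' log(1 + count). The helper summarize_window is hypothetical, and CREsted reads the signal through its own bigWig backend rather than from a Python list.

```python
import math


def summarize_window(values, start, end, target_region_width=None, target="count"):
    """Collapse per-base signal over [start, end) into one number.

    If target_region_width is given, the window is first re-centred on the
    region midpoint with that width. Assumed target semantics (sketch only):
    "count" = sum, "mean" = average, "max" = maximum, "logcount" = log(1 + sum).
    """
    if target_region_width is not None:
        center = (start + end) // 2
        start = center - target_region_width // 2
        end = start + target_region_width
    window = values[max(start, 0):end]
    total = sum(window)
    if target == "count":
        return total
    if target == "mean":
        return total / len(window)
    if target == "max":
        return max(window)
    if target == "logcount":
        return math.log1p(total)
    raise ValueError(f"unknown target: {target!r}")


# Hypothetical per-base signal along a tiny chromosome
signal = list(range(10))
print(summarize_window(signal, 4, 6))  # 9, the sum of positions 4 and 5
print(summarize_window(signal, 4, 6, target_region_width=4))  # 18, positions 3 to 6
```

This is why target_region_width can differ from the width of the BED regions: only the window used to compute the value changes, not the regions themselves.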
To train a model, we always need to add a split column to our dataset, which we can do using crested.pp.train_val_test_split().
This adds a column to AnnData.var with the split type for each region (train, val, or test).
# Choose the chromosomes for the validation and test sets
crested.pp.train_val_test_split(
adata, strategy="chr", val_chroms=["chr8", "chr10"], test_chroms=["chr9", "chr18"]
)
# Alternatively, we can split randomly on the regions
# crested.pp.train_val_test_split(
# adata, strategy="region", val_size=0.1, test_size=0.1, random_state=42
# )
print(adata.var["split"].value_counts())
adata.var
split
train 440993
val 56064
test 49936
Name: count, dtype: int64
region | chr | start | end | split
---|---|---|---|---
chr1:3093998-3096112 | chr1 | 3093998 | 3096112 | train
chr1:3094663-3096777 | chr1 | 3094663 | 3096777 | train
chr1:3111367-3113481 | chr1 | 3111367 | 3113481 | train
chr1:3112727-3114841 | chr1 | 3112727 | 3114841 | train
chr1:3118939-3121053 | chr1 | 3118939 | 3121053 | train
... | ... | ... | ... | ...
chrX:169878506-169880620 | chrX | 169878506 | 169880620 | train
chrX:169879374-169881488 | chrX | 169879374 | 169881488 | train
chrX:169924670-169926784 | chrX | 169924670 | 169926784 | train
chrX:169947743-169949857 | chrX | 169947743 | 169949857 | train
chrX:169950171-169952285 | chrX | 169950171 | 169952285 | train
546993 rows × 4 columns
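Conceptually, the two split strategies work as follows: with strategy="chr" every region inherits the label of its chromosome, while strategy="region" shuffles the regions with a fixed seed and carves off the validation and test fractions. This is a simplified sketch with hypothetical helper names, not CREsted's implementation; it reads the chromosome from region names of the form chr:start-end, as in the index above (the chr8 region is made up).

```python
import random


def split_by_chromosome(regions, val_chroms, test_chroms):
    """Label each "chr:start-end" region as train, val, or test by its chromosome."""
    labels = {}
    for region in regions:
        chrom = region.split(":")[0]  # exact match, so "chr1" never matches "chr10"
        if chrom in val_chroms:
            labels[region] = "val"
        elif chrom in test_chroms:
            labels[region] = "test"
        else:
            labels[region] = "train"
    return labels


def split_randomly(regions, val_size=0.1, test_size=0.1, random_state=42):
    """Shuffle regions with a fixed seed, then slice off the val and test fractions."""
    shuffled = list(regions)
    random.Random(random_state).shuffle(shuffled)
    n_val = int(len(shuffled) * val_size)
    n_test = int(len(shuffled) * test_size)
    labels = {r: "val" for r in shuffled[:n_val]}
    labels.update({r: "test" for r in shuffled[n_val:n_val + n_test]})
    labels.update({r: "train" for r in shuffled[n_val + n_test:]})
    return labels


regions = ["chr1:3093998-3096112", "chr8:1000-3114", "chr9:65586049-65588163"]
print(split_by_chromosome(regions, val_chroms=["chr8", "chr10"], test_chroms=["chr9", "chr18"]))
# {'chr1:3093998-3096112': 'train', 'chr8:1000-3114': 'val', 'chr9:65586049-65588163': 'test'}
```

The chromosome strategy matters because neighbouring consensus regions overlap (compare the first two rows of the table above). A random split can put overlapping sequence in both train and test, which inflates test scores; holding out whole chromosomes avoids that.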
Preprocessing#
Region Width#
For this example we’re interested in training on wider regions than our consensus regions file (500bp) to also include some sequence information from the tails of our peaks.
We change it to 2114 bp regions, since that is the input width the ChromBPNet architecture was originally trained on, and that is the architecture we’ll be using. This width is not fixed and can be adapted to your preference, as long as it is compatible with the model architecture.
Wider regions mean that you include sequence information beyond the center of the peaks. This can effectively increase your dataset size if the peak tails carry meaningful information, but it can also introduce noise if they do not.
Wider regions will also increase the computational cost of training the model.
crested.pp.change_regions_width(
adata,
2114,
) # change the adata width of the regions to 2114bp
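Assuming the resize keeps each region's midpoint and expands it symmetrically, the arithmetic looks like the sketch below. resize_region is a hypothetical helper; CREsted's exact rounding and its handling of regions near chromosome ends may differ.

```python
def resize_region(chrom, start, end, width):
    """Return a region of `width` bp centred on the (floored) midpoint of [start, end)."""
    center = (start + end) // 2
    new_start = center - width // 2
    return chrom, new_start, new_start + width


# Hypothetical 500 bp consensus peak widened to 2114 bp
chrom, start, end = resize_region("chr1", 1000, 1500, 2114)
print(f"{chrom}:{start}-{end}")  # chr1:193-2307
```

Near the start of a chromosome new_start can become negative, so real code has to clip or drop such regions.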
Peak Normalization#
Additionally, we can normalize our peak values based on the variability of the top peak heights per cell type using the crested.pp.normalize_peaks() function.
This function applies a normalization scalar to each cell type, obtained by comparing the distribution of peak heights per cell type for the most accessible regions that are not specific to any cell type.
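To make the idea concrete, here is a simplified sketch: compute a Gini index per region, keep the low-Gini (shared) regions, take the top fraction by mean height, and scale every cell type so that its mean over those regions matches the highest one. The Gini cutoff, treating top_k_percent as a fraction, the ratio-to-maximum rule, and all function names are assumptions for illustration; consult the CREsted documentation for the exact procedure.

```python
def gini(values):
    """Gini index of non-negative values: 0 when uniform, approaching 1 when concentrated."""
    xs = sorted(values)
    n, total = len(xs), sum(xs)
    if total == 0:
        return 0.0
    weighted = sum(i * x for i, x in enumerate(xs, start=1))
    return 2 * weighted / (n * total) - (n + 1) / n


def normalization_weights(heights, top_k_percent=0.01, gini_cutoff=0.3):
    """Per-cell-type scaling factors estimated from the strongest shared peaks.

    heights: regions x cell types (list of lists of peak heights).
    gini_cutoff: hypothetical threshold below which a region counts as non-specific.
    """
    shared = [row for row in heights if gini(row) <= gini_cutoff]
    shared.sort(key=lambda row: sum(row) / len(row), reverse=True)
    n_top = max(1, round(len(shared) * top_k_percent))
    top = shared[:n_top]
    means = [sum(row[c] for row in top) / len(top) for c in range(len(heights[0]))]
    reference = max(means)
    # Scale each cell type so its mean over the shared top peaks matches the highest one
    return [reference / m for m in means]


# Hypothetical data: cell type A reads systematically half as high as B and C
heights = [
    [5, 10, 10],
    [10, 20, 20],
    [0, 0, 30],  # specific to C, so the Gini filter excludes it
]
print(normalization_weights(heights, top_k_percent=1.0))  # [2.0, 1.0, 1.0]
```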
crested.pp.normalize_peaks(
adata, top_k_percent=0.03
) # Tune top_k_percent (default 0.01) if the weights look biased towards certain cell types; if some weights overcompensate, consider increasing it.
2024-12-12T10:19:01.511397+0100 INFO Filtering on top k Gini scores...
2024-12-12T10:19:05.058074+0100 INFO Added normalization weights to adata.obsm['weights']...
region | chr | start | end | split
---|---|---|---|---
chr5:76656624-76658738 | chr5 | 76656624 | 76658738 | train
chr13:30900787-30902901 | chr13 | 30900787 | 30902901 | train
chr9:65586049-65588163 | chr9 | 65586049 | 65588163 | test
chr9:65586556-65588670 | chr9 | 65586556 | 65588670 | test
chr9:65587095-65589209 | chr9 | 65587095 | 65589209 | test
... | ... | ... | ... | ...
chr9:65459289-65461403 | chr9 | 65459289 | 65461403 | test
chr9:65459852-65461966 | chr9 | 65459852 | 65461966 | test
chr5:76587680-76589794 | chr5 | 76587680 | 76589794 | train
chr9:65523082-65525196 | chr9 | 65523082 | 65525196 | test
chr19:18337654-18339768 | chr19 | 18337654 | 18339768 | train
48308 rows × 4 columns
We can visualize the normalization factor for each cell type using the crested.pl.bar.normalization_weights() function to inspect which cell types’ peaks were up- or down-weighted.
crested.pl.bar.normalization_weights(
adata, title="Normalization Weights per Cell Type", x_label_rotation=90
)
There is no single best way to preprocess your data, so we recommend experimenting with different preprocessing steps to see what works best for your data.
Likewise there is no single best training approach, so we recommend experimenting with different training strategies.
# Save the final preprocessing results
adata.write_h5ad("mouse_cortex.h5ad")
Model Training#
The entire CREsted workflow is built around the crested.tl.Crested() class.
Everything that requires a model (training, evaluation, prediction) is done through this class.
For training, this class requires three arguments:
data: the crested.tl.data.AnnDataModule object containing all the data (anndata, genome) and the dataloaders that specify how to load the data.
model: the keras.Model object containing the model architecture.
config: the crested.tl.TaskConfig object containing the optimizer, loss function, and metrics to use in training.
In practice, you would usually run these steps in a Python script rather than in a notebook, so that you can run them on a cluster or in the background.
Data#
We’ll start by initializing the crested.tl.data.AnnDataModule object with our data.
This tells the model how to load the data and which data to load during fitting and evaluation.
The main arguments to supply are the adata object, the genome (which can be omitted here because we registered one above), and the batch_size.
Other optional arguments relate to how the training data are loaded (e.g. shuffling, whether to load the sequences into memory, …).
You need to provide the genome file yourself, as it is not included in the CREsted package.
# read in your preprocessed data
adata = ad.read_h5ad("mouse_cortex.h5ad")
datamodule = crested.tl.data.AnnDataModule(
adata,
batch_size=256, # lower this if you encounter OOM errors
max_stochastic_shift=3, # optional augmentation
always_reverse_complement=True, # default True. Will double the effective size of the training dataset.
)
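The two augmentations configured above can be pictured like this: each region is cut out at a random offset of up to max_stochastic_shift bases, and with always_reverse_complement=True every example is also presented as its reverse complement (with the same targets), which is why the effective dataset size doubles. This is a plain-Python sketch with hypothetical helpers and a toy sequence, not CREsted's data loader.

```python
import random

# Complement table; N (unknown base) maps to itself
_COMPLEMENT = str.maketrans("ACGTNacgtn", "TGCANtgcan")


def reverse_complement(seq):
    """Reverse-complement a DNA string."""
    return seq.translate(_COMPLEMENT)[::-1]


def shifted_window(chrom_seq, start, end, max_shift, rng):
    """Cut [start, end) after moving it by a random offset in [-max_shift, max_shift]."""
    shift = rng.randint(-max_shift, max_shift)
    return chrom_seq[start + shift:end + shift]


def augmented_examples(chrom_seq, regions, max_shift, rng, always_reverse_complement=True):
    """Yield each shifted region, followed by its reverse complement if requested."""
    for start, end in regions:
        seq = shifted_window(chrom_seq, start, end, max_shift, rng)
        yield seq
        if always_reverse_complement:
            yield reverse_complement(seq)


rng = random.Random(0)
toy_chrom = "ACGT" * 10  # hypothetical 40 bp sequence
examples = list(augmented_examples(toy_chrom, [(5, 15), (20, 30)], max_shift=3, rng=rng))
print(len(examples))  # 4: two regions, each seen forward and reverse-complemented
print(reverse_complement("AACGTN"))  # NACGTT
```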
Model definition#
Next, we’ll define the model architecture. This is a standard Keras model definition, so you can provide your own model if you like.
Alternatively, there are a couple of ready-to-use models available in the crested.tl.zoo module.
Each of them requires the width of the input sequences and the number of output classes (the number of cell types in AnnData.obs) as arguments.
# Load chrombpnet architecture for a dataset with 2114bp regions and 19 cell types
model_architecture = crested.tl.zoo.chrombpnet(
seq_len=2114, num_classes=len(list(adata.obs_names))
)
TaskConfig#
The TaskConfig object specifies the optimizer, loss function, and metrics to use in training (we call this our ‘task’).
Default configurations are available for some common tasks, such as ‘topic_classification’ and ‘peak_regression’, which you can load using the crested.tl.default_configs() function.
# Load the default configuration for training a peak regression model
config = crested.tl.default_configs(
"peak_regression"
) # or "topic_classification" for topic classification
print(config)
# If you want to change some small parameters to an existing config, you can do it like this
# For example, the default learning rate is 0.001, but you can change it to 0.0001
# config.optimizer.learning_rate = 0.0001
TaskConfig(optimizer=<keras.src.optimizers.adam.Adam object at 0x7f4a1ace2510>, loss=<crested.tl.losses._cosinemse_log.CosineMSELogLoss object at 0x7f4a1ae592d0>, metrics=[<MeanAbsoluteError name=mean_absolute_error>, <MeanSquaredError name=mean_squared_error>, <CosineSimilarity name=cosine_similarity>, <PearsonCorrelation name=pearson_correlation>, <ConcordanceCorrelationCoefficient name=concordance_correlation_coefficient>, <PearsonCorrelationLog name=pearson_correlation_log>, <ZeroPenaltyMetric name=zero_penalty_metric>])
Alternatively, you can create your own TaskConfig object and specify the optimizer, loss function, and metrics yourself if you want to do something completely custom.
# Create your own configuration
# I recommend trying this for peak regression with a weighted cosine mse log loss function
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
loss = crested.tl.losses.CosineMSELogLoss(max_weight=100, multiplier=1)
metrics = [
keras.metrics.MeanAbsoluteError(),
keras.metrics.MeanSquaredError(),
keras.metrics.CosineSimilarity(axis=1),
crested.tl.metrics.PearsonCorrelation(),
crested.tl.metrics.ConcordanceCorrelationCoefficient(),
crested.tl.metrics.PearsonCorrelationLog(),
crested.tl.metrics.ZeroPenaltyMetric(),
]
alternative_config = crested.tl.TaskConfig(optimizer, loss, metrics)
print(alternative_config)
TaskConfig(optimizer=<keras.src.optimizers.adam.Adam object at 0x7f4a1ad14410>, loss=<crested.tl.losses._cosinemse_log.CosineMSELogLoss object at 0x7f4a1ad14a10>, metrics=[<MeanAbsoluteError name=mean_absolute_error>, <MeanSquaredError name=mean_squared_error>, <CosineSimilarity name=cosine_similarity>, <PearsonCorrelation name=pearson_correlation>, <ConcordanceCorrelationCoefficient name=concordance_correlation_coefficient>, <PearsonCorrelationLog name=pearson_correlation_log>, <ZeroPenaltyMetric name=zero_penalty_metric>])
Training#
Now we’re ready to train our model.
We’ll create a Crested object with the data, model, and config objects we just created.
Then, we can call the fit() method to train the model.
Read the documentation for more information on all available arguments to customize your training (e.g. augmentations, early stopping, checkpointing, …).
By default:
The model continues training until the validation loss stops decreasing for 10 epochs, up to a maximum of 100 epochs.
Every new best model, judged by validation loss, is saved as a checkpoint.
The learning rate is reduced by a factor of 0.25 if the validation loss stops decreasing for 5 epochs.
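These stopping and learning-rate rules can be replayed on a list of validation losses to see how the patience values interact. The sketch below is a simplified stand-in for Keras-style EarlyStopping and ReduceLROnPlateau callbacks (no min_delta or cooldown), using a hypothetical loss curve and the patience values from the trainer.fit() call in this tutorial; simulate_callbacks is a hypothetical helper.

```python
def simulate_callbacks(val_losses, lr=1e-3, lr_factor=0.25, lr_patience=5, es_patience=10):
    """Replay validation losses through simplified plateau and early-stopping rules.

    Returns the epoch at which training stops and the learning rate after each epoch.
    """
    best = float("inf")
    since_best = 0  # epochs without improvement (early-stopping counter)
    lr_wait = 0     # same, but reset after every learning-rate reduction
    lrs = []
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best, since_best, lr_wait = loss, 0, 0  # new best: a checkpoint would be saved
        else:
            since_best += 1
            lr_wait += 1
            if lr_wait >= lr_patience:
                lr *= lr_factor
                lr_wait = 0
        lrs.append(lr)
        if since_best >= es_patience:
            return epoch, lrs
    return len(val_losses), lrs


# Hypothetical loss curve that plateaus after epoch 3
losses = [1.0, 0.9, 0.8, 0.85, 0.86, 0.87, 0.88, 0.89, 0.90, 0.91]
stop_epoch, lrs = simulate_callbacks(losses, lr_patience=3, es_patience=6)
print(stop_epoch)  # 9
```

With these values the learning rate drops once at epoch 6 and again at epoch 9, at which point early stopping ends the run; the checkpoint from epoch 3 remains the best one.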
Note
If you specify the same project_name and run_name as a previous run, then CREsted will assume that you want to continue training and will load the last available model checkpoint from the {project_name}/{run_name} folder and continue from that epoch.
# setup the trainer
trainer = crested.tl.Crested(
data=datamodule,
model=model_architecture,
config=alternative_config,
project_name="mouse_biccn", # change to your liking
run_name="basemodel", # change to your liking
logger="wandb", # or 'tensorboard'
seed=7, # For reproducibility
)
2024-12-12T10:19:08.451626+0100 WARNING Output directory mouse_biccn/basemodel/checkpoints already exists. Will continue training from epoch 18.
# train the model
trainer.fit(
epochs=60,
learning_rate_reduce_patience=3,
early_stopping_patience=6,
)
View project at: https://wandb.ai/kemp/mouse_biccn
Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
./wandb/run-20241209_201427-t1jx172q/logs
/data/projects/c04/cbd-saerts/nkemp/software/CREsted/docs/tutorials/wandb/run-20241209_201453-jciwxmaz
2024-12-09T20:14:57.994147+0100 WARNING Model does not have an optimizer. Please compile the model before training.
None
2024-12-09T20:14:58.159088+0100 INFO Loading sequences into memory...
2024-12-09T20:15:05.095428+0100 INFO Loading sequences into memory...
Epoch 1/60
3445/3446 ━━━━━━━━━━━━━━━━━━━━ 0s 223ms/step - concordance_correlation_coefficient: 0.5750 - cosine_similarity: 0.8166 - loss: -0.3449 - mean_absolute_error: 3.2040 - mean_squared_error: 45.8918 - pearson_correlation: 0.6825 - pearson_correlation_log: 0.5407 - zero_penalty_metric: 1136.6290
3446/3446 ━━━━━━━━━━━━━━━━━━━━ 833s 232ms/step - concordance_correlation_coefficient: 0.5750 - cosine_similarity: 0.8166 - loss: -0.3449 - mean_absolute_error: 3.2039 - mean_squared_error: 45.8899 - pearson_correlation: 0.6825 - pearson_correlation_log: 0.5407 - zero_penalty_metric: 1136.6272 - val_concordance_correlation_coefficient: 0.4512 - val_cosine_similarity: 0.6731 - val_loss: 0.5619 - val_mean_absolute_error: 9.3106 - val_mean_squared_error: 198.5345 - val_pearson_correlation: 0.5837 - val_pearson_correlation_log: 0.4806 - val_zero_penalty_metric: 908.3172 - learning_rate: 0.0010
Epoch 2/60
3446/3446 ━━━━━━━━━━━━━━━━━━━━ 783s 227ms/step - concordance_correlation_coefficient: 0.6789 - cosine_similarity: 0.8460 - loss: -0.4591 - mean_absolute_error: 2.8753 - mean_squared_error: 36.1717 - pearson_correlation: 0.7676 - pearson_correlation_log: 0.6078 - zero_penalty_metric: 1115.6266 - val_concordance_correlation_coefficient: 0.6851 - val_cosine_similarity: 0.8106 - val_loss: -0.3190 - val_mean_absolute_error: 3.2301 - val_mean_squared_error: 40.4674 - val_pearson_correlation: 0.7418 - val_pearson_correlation_log: 0.5651 - val_zero_penalty_metric: 1070.3921 - learning_rate: 0.0010
Epoch 3/60
3446/3446 ━━━━━━━━━━━━━━━━━━━━ 783s 227ms/step - concordance_correlation_coefficient: 0.7293 - cosine_similarity: 0.8592 - loss: -0.4970 - mean_absolute_error: 2.7146 - mean_squared_error: 31.4138 - pearson_correlation: 0.8026 - pearson_correlation_log: 0.6242 - zero_penalty_metric: 1111.0103 - val_concordance_correlation_coefficient: 0.7375 - val_cosine_similarity: 0.8421 - val_loss: -0.3265 - val_mean_absolute_error: 3.3975 - val_mean_squared_error: 35.2583 - val_pearson_correlation: 0.7845 - val_pearson_correlation_log: 0.5985 - val_zero_penalty_metric: 1198.2017 - learning_rate: 0.0010
Epoch 4/60
2191/3446 ━━━━━━━━━━━━━━━━━━━━ 4:40 223ms/step - concordance_correlation_coefficient: 0.7562 - cosine_similarity: 0.8663 - loss: -0.5173 - mean_absolute_error: 2.6118 - mean_squared_error: 28.6017 - pearson_correlation: 0.8187 - pearson_correlation_log: 0.6330 - zero_penalty_metric: 1106.6896
Model: "functional"
Layer (type) | Output Shape | Param # | Connected to
---|---|---|---
sequence (InputLayer) | (None, 2114, 4) | 0 | -
conv1d (Conv1D) | (None, 2114, 512) | 10,240 | sequence[0][0]
batch_normalization (BatchNormalizatio…) | (None, 2114, 512) | 2,048 | conv1d[0][0]
activation (Activation) | (None, 2114, 512) | 0 | batch_normalizat…
dropout (Dropout) | (None, 2114, 512) | 0 | activation[0][0]
bpnet_1conv (Conv1D) | (None, 2110, 512) | 786,432 | dropout[0][0]
bpnet_1bn (BatchNormalizatio…) | (None, 2110, 512) | 2,048 | bpnet_1conv[0][0]
bpnet_1activation (Activation) | (None, 2110, 512) | 0 | bpnet_1bn[0][0]
bpnet_1crop (Cropping1D) | (None, 2110, 512) | 0 | dropout[0][0]
add (Add) | (None, 2110, 512) | 0 | bpnet_1activatio…, bpnet_1crop[0][0]
bpnet_1dropout (Dropout) | (None, 2110, 512) | 0 | add[0][0]
bpnet_2conv (Conv1D) | (None, 2102, 512) | 786,432 | bpnet_1dropout[0…
bpnet_2bn (BatchNormalizatio…) | (None, 2102, 512) | 2,048 | bpnet_2conv[0][0]
bpnet_2activation (Activation) | (None, 2102, 512) | 0 | bpnet_2bn[0][0]
bpnet_2crop (Cropping1D) | (None, 2102, 512) | 0 | bpnet_1dropout[0…
add_1 (Add) | (None, 2102, 512) | 0 | bpnet_2activatio…, bpnet_2crop[0][0]
bpnet_2dropout (Dropout) | (None, 2102, 512) | 0 | add_1[0][0]
bpnet_3conv (Conv1D) | (None, 2086, 512) | 786,432 | bpnet_2dropout[0…
bpnet_3bn (BatchNormalizatio…) | (None, 2086, 512) | 2,048 | bpnet_3conv[0][0]
bpnet_3activation (Activation) | (None, 2086, 512) | 0 | bpnet_3bn[0][0]
bpnet_3crop (Cropping1D) | (None, 2086, 512) | 0 | bpnet_2dropout[0…
add_2 (Add) | (None, 2086, 512) | 0 | bpnet_3activatio…, bpnet_3crop[0][0]
bpnet_3dropout (Dropout) | (None, 2086, 512) | 0 | add_2[0][0]
bpnet_4conv (Conv1D) | (None, 2054, 512) | 786,432 | bpnet_3dropout[0…
bpnet_4bn (BatchNormalizatio…) | (None, 2054, 512) | 2,048 | bpnet_4conv[0][0]
bpnet_4activation (Activation) | (None, 2054, 512) | 0 | bpnet_4bn[0][0]
bpnet_4crop (Cropping1D) | (None, 2054, 512) | 0 | bpnet_3dropout[0…
add_3 (Add) | (None, 2054, 512) | 0 | bpnet_4activatio…, bpnet_4crop[0][0]
bpnet_4dropout (Dropout) | (None, 2054, 512) | 0 | add_3[0][0]
bpnet_5conv (Conv1D) | (None, 1990, 512) | 786,432 | bpnet_4dropout[0…
bpnet_5bn (BatchNormalizatio…) | (None, 1990, 512) | 2,048 | bpnet_5conv[0][0]
bpnet_5activation (Activation) | (None, 1990, 512) | 0 | bpnet_5bn[0][0]
bpnet_5crop (Cropping1D) | (None, 1990, 512) | 0 | bpnet_4dropout[0…
add_4 (Add) | (None, 1990, 512) | 0 | bpnet_5activatio…, bpnet_5crop[0][0]
bpnet_5dropout (Dropout) | (None, 1990, 512) | 0 | add_4[0][0]
bpnet_6conv (Conv1D) | (None, 1862, 512) | 786,432 | bpnet_5dropout[0…
bpnet_6bn (BatchNormalizatio…) | (None, 1862, 512) | 2,048 | bpnet_6conv[0][0]
bpnet_6activation (Activation) | (None, 1862, 512) | 0 | bpnet_6bn[0][0]
bpnet_6crop (Cropping1D) | (None, 1862, 512) | 0 | bpnet_5dropout[0…
add_5 (Add) | (None, 1862, 512) | 0 | bpnet_6activatio…, bpnet_6crop[0][0]
bpnet_6dropout (Dropout) | (None, 1862, 512) | 0 | add_5[0][0]
bpnet_7conv (Conv1D) | (None, 1606, 512) | 786,432 | bpnet_6dropout[0…
bpnet_7bn (BatchNormalizatio…) | (None, 1606, 512) | 2,048 | bpnet_7conv[0][0]
bpnet_7activation (Activation) | (None, 1606, 512) | 0 | bpnet_7bn[0][0]
bpnet_7crop (Cropping1D) | (None, 1606, 512) | 0 | bpnet_6dropout[0…
add_6 (Add) | (None, 1606, 512) | 0 | bpnet_7activatio…, bpnet_7crop[0][0]
bpnet_7dropout (Dropout) | (None, 1606, 512) | 0 | add_6[0][0]
bpnet_8conv (Conv1D) | (None, 1094, 512) | 786,432 | bpnet_7dropout[0…
bpnet_8bn (BatchNormalizatio…) | (None, 1094, 512) | 2,048 | bpnet_8conv[0][0]
bpnet_8activation (Activation) | (None, 1094, 512) | 0 | bpnet_8bn[0][0]
bpnet_8crop (Cropping1D) | (None, 1094, 512) | 0 | bpnet_7dropout[0…
add_7 (Add) | (None, 1094, 512) | 0 | bpnet_8activatio…, bpnet_8crop[0][0]
bpnet_8dropout (Dropout) | (None, 1094, 512) | 0 | add_7[0][0]
global_average_poo… (GlobalAveragePool…) | (None, 512) | 0 | bpnet_8dropout[0…
dense_out (Dense) | (None, 19) | 9,747 | global_average_p…
Total params: 6,329,875 (24.15 MB)
Trainable params: 6,320,659 (24.11 MB)
Non-trainable params: 9,216 (36.00 KB)
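The shapes in this summary follow from simple arithmetic. Assuming, as the shapes and parameter counts suggest, that each bpnet_*conv layer is an unpadded Conv1D with kernel size 3, no bias, and dilation 2**i for layer i, each layer trims (3 - 1) * 2**i positions. The sketch below reproduces the listed output lengths and parameter counts; the kernel size and dilations are inferred from the table, not read from the CREsted source, and the helper names are hypothetical.

```python
def dilated_stack_lengths(seq_len=2114, n_layers=8, kernel_size=3):
    """Output length after each unpadded dilated conv with dilation 2**i (i = 1..n_layers).

    An unpadded convolution with kernel k and dilation d trims (k - 1) * d positions.
    """
    lengths = []
    for i in range(1, n_layers + 1):
        seq_len -= (kernel_size - 1) * 2**i
        lengths.append(seq_len)
    return lengths


def conv_params(kernel_size, in_channels, out_channels, use_bias=False):
    """Weights in a Conv1D layer: kernel_size * in_channels * out_channels (+ biases)."""
    return kernel_size * in_channels * out_channels + (out_channels if use_bias else 0)


print(dilated_stack_lengths())  # [2110, 2102, 2086, 2054, 1990, 1862, 1606, 1094]
print(conv_params(3, 512, 512))  # 786432, as listed for every bpnet_*conv layer
print(512 * 19 + 19)  # 9747: weights plus biases of dense_out
```

With these settings the stack trims 1020 positions in total, so any seq_len you choose must be larger than that. This is one concrete way in which the input width has to be compatible with the architecture.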
Finetuning on cell type-specific regions#
Subsetting the consensus peak set#
For peak regression models, we recommend continuing to train the model (trained on all consensus peaks) on a subset of cell type-specific regions. Since we are interested in the enhancer code that uniquely identifies each cell type in the dataset, finetuning on specific regions lets the model focus on exactly those sequences. We define specific regions as regions with a high Gini index, meaning that their peak heights are skewed towards one or more cell types.
Read the documentation of the crested.pp.filter_regions_on_specificity() function for more information on how the filtering is done.
crested.pp.filter_regions_on_specificity(
adata, gini_std_threshold=1.0
) # All regions with a Gini index 1 std above the mean across all regions will be kept
adata
2024-12-12T10:19:15.431415+0100 INFO After specificity filtering, kept 91475 out of 546993 regions.
AnnData object with n_obs × n_vars = 19 × 91475
obs: 'file_path'
var: 'chr', 'start', 'end', 'split'
obsm: 'weights'
adata.write_h5ad("mouse_cortex_filtered.h5ad")
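The rule from the code comment above (keep regions whose Gini index lies more than gini_std_threshold standard deviations above the mean) can be sketched as follows. The helper names and toy data are hypothetical. Whether CREsted computes the Gini index on raw or normalized values, and whether it uses the population or sample standard deviation, is not shown in this tutorial; the sketch assumes the population standard deviation.

```python
from statistics import mean, pstdev


def gini(values):
    """Gini index of non-negative values: 0 when uniform, approaching 1 when concentrated."""
    xs = sorted(values)
    n, total = len(xs), sum(xs)
    if total == 0:
        return 0.0
    weighted = sum(i * x for i, x in enumerate(xs, start=1))
    return 2 * weighted / (n * total) - (n + 1) / n


def filter_on_specificity(region_names, heights, gini_std_threshold=1.0):
    """Keep regions whose Gini index exceeds mean + gini_std_threshold * std over all regions."""
    scores = [gini(row) for row in heights]
    cutoff = mean(scores) + gini_std_threshold * pstdev(scores)
    return [name for name, score in zip(region_names, scores) if score > cutoff]


# Hypothetical peak heights (regions x cell types)
names = ["r1", "r2", "r3", "r4"]
heights = [[1, 1, 1], [2, 2, 2], [1, 1, 1], [0, 0, 9]]
print(filter_on_specificity(names, heights))  # ['r4']
```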
Loading the pretrained model on all consensus peaks and finetuning with a lower learning rate#
datamodule = crested.tl.data.AnnDataModule(
adata,
batch_size=64, # Recommended to go for a smaller batch size than in the basemodel
max_stochastic_shift=3,
always_reverse_complement=True,
)
# First load the pretrained model on all peaks
model_architecture = keras.models.load_model(
"mouse_biccn/basemodel/checkpoints/17.keras", # Choose the base model checkpoint with the best validation loss/performance metrics
compile=False,
)
# Use the same config as for the pretrained model, EXCEPT for the learning rate: make sure it is lower than it was at the epoch you selected the model from.
optimizer = keras.optimizers.Adam(learning_rate=1e-4) # Lower LR!
loss = crested.tl.losses.CosineMSELogLoss(max_weight=100, multiplier=1)
metrics = [
keras.metrics.MeanAbsoluteError(),
keras.metrics.MeanSquaredError(),
keras.metrics.CosineSimilarity(axis=1),
crested.tl.metrics.PearsonCorrelation(),
crested.tl.metrics.ConcordanceCorrelationCoefficient(),
crested.tl.metrics.PearsonCorrelationLog(),
crested.tl.metrics.ZeroPenaltyMetric(),
]
alternative_config = crested.tl.TaskConfig(optimizer, loss, metrics)
print(alternative_config)
TaskConfig(optimizer=<keras.src.optimizers.adam.Adam object at 0x7f4a1665c410>, loss=<crested.tl.losses._cosinemse_log.CosineMSELogLoss object at 0x7f4a1adafc90>, metrics=[<MeanAbsoluteError name=mean_absolute_error>, <MeanSquaredError name=mean_squared_error>, <CosineSimilarity name=cosine_similarity>, <PearsonCorrelation name=pearson_correlation>, <ConcordanceCorrelationCoefficient name=concordance_correlation_coefficient>, <PearsonCorrelationLog name=pearson_correlation_log>, <ZeroPenaltyMetric name=zero_penalty_metric>])
# setup the trainer
trainer = crested.tl.Crested(
data=datamodule,
model=model_architecture,
config=alternative_config,
project_name="mouse_biccn", # change to your liking
run_name="finetuned_model", # change to your liking
logger="wandb", # or 'tensorboard'
)
2024-12-12T10:19:16.196631+0100 WARNING Output directory mouse_biccn/finetuned_model/checkpoints already exists. Will continue training from epoch 9.
trainer.fit(
epochs=60,
learning_rate_reduce_patience=3,
early_stopping_patience=6,
)
Evaluate the model#
After training, we can evaluate the model on the test set using the test() method.
If we’re still in the same session, we can simply continue using the same object.
If not, we can load the model from disk using the load_model() method.
This means that we have to create a new Crested object first.
However, this time, since the task config and architecture are saved in the .keras file, we only have to provide our datamodule.
adata = ad.read_h5ad("mouse_cortex_filtered.h5ad")
datamodule = crested.tl.data.AnnDataModule(
adata,
batch_size=256, # lower this if you encounter OOM errors
)
# load an existing model
evaluator = crested.tl.Crested(data=datamodule)
evaluator.load_model(
"mouse_biccn/finetuned_model/checkpoints/02.keras", # Load your model
compile=True,
)
If you experimented with many different hyperparameters for your model, chances are that you will start overfitting on your validation dataset.
It’s therefore always a good idea to evaluate your model on the test set after getting good results on your validation data to see how well it generalizes to unseen data.
# evaluate the model on the test set
evaluator.test()
3/33 ━━━━━━━━━━━━━━━━━━━━ 1s 60ms/step - concordance_correlation_coefficient: 0.6685 - cosine_similarity: 0.8724 - loss: 0.5137 - mean_absolute_error: 1.6471 - mean_squared_error: 13.7273 - pearson_correlation: 0.7443 - pearson_correlation_log: 0.6123 - zero_penalty_metric: 2443.7920
33/33 ━━━━━━━━━━━━━━━━━━━━ 14s 116ms/step - concordance_correlation_coefficient: 0.7091 - cosine_similarity: 0.8744 - loss: 0.5054 - mean_absolute_error: 1.6253 - mean_squared_error: 13.6106 - pearson_correlation: 0.7706 - pearson_correlation_log: 0.6340 - zero_penalty_metric: 2470.7397
2024-12-12T10:20:51.416834+0100 INFO Test concordance_correlation_coefficient: 0.7221
2024-12-12T10:20:51.417376+0100 INFO Test cosine_similarity: 0.8707
2024-12-12T10:20:51.417635+0100 INFO Test loss: 0.5097
2024-12-12T10:20:51.417896+0100 INFO Test mean_absolute_error: 1.6882
2024-12-12T10:20:51.418280+0100 INFO Test mean_squared_error: 14.4946
2024-12-12T10:20:51.418600+0100 INFO Test pearson_correlation: 0.7813
2024-12-12T10:20:51.418892+0100 INFO Test pearson_correlation_log: 0.6345
2024-12-12T10:20:51.419171+0100 INFO Test zero_penalty_metric: 2302.9001
Predict#
After training, we can also use the predict() method to predict the labels for new data and add them as a layer to the AnnData object.
A common use case is to compare the predictions of multiple trained models against the true labels.
We can create a new Crested object (if you have different data) or reuse the existing one.
Here we reuse the existing one, since we predict on the same data we trained on.
# add predictions for model checkpoint to the adata
evaluator.predict(
adata, model_name="biccn_model"
) # adds the predictions to the adata.layers["biccn_model"]
358/358 ━━━━━━━━━━━━━━━━━━━━ 25s 66ms/step
2024-12-12T10:21:16.424896+0100 INFO Adding predictions to anndata.layers[biccn_model].
adata.layers
Layers with keys: biccn_model
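Comparing several models then amounts to comparing their layers against the ground truth. The sketch below assumes, as in this tutorial, that the ground truth is a dense (n_classes, n_regions) matrix (adata.X, with classes as observations and regions as variables) and that each model's predictions are stored in adata.layers under its model_name. The helper compare_models is hypothetical and not part of CREsted.

```python
import numpy as np


def compare_models(truth, predictions, region_mask=None):
    """Hypothetical helper: per-model agreement with the ground truth.

    truth: (n_classes, n_regions) array, e.g. adata.X in this tutorial.
    predictions: dict of model name -> array of the same shape,
        e.g. {"biccn_model": adata.layers["biccn_model"]}.
    region_mask: optional boolean array over regions, e.g. the test split.
    """
    truth = np.asarray(truth, dtype=float)
    if region_mask is not None:
        truth = truth[:, region_mask]
    results = {}
    for name, pred in predictions.items():
        pred = np.asarray(pred, dtype=float)
        if region_mask is not None:
            pred = pred[:, region_mask]
        # Pearson r per class (row), then averaged over classes
        per_class_r = [np.corrcoef(t, p)[0, 1] for t, p in zip(truth, pred)]
        results[name] = {
            "mean_pearson": float(np.mean(per_class_r)),
            "mse": float(np.mean((truth - pred) ** 2)),
        }
    return results


# toy data: 3 classes x 50 regions, the last 10 regions play the role of the test split
toy_rng = np.random.default_rng(0)
toy_truth = toy_rng.random((3, 50))
toy_test_mask = np.arange(50) >= 40
toy_results = compare_models(
    toy_truth,
    {"perfect": toy_truth.copy(), "noisy": toy_truth + toy_rng.normal(0, 0.5, toy_truth.shape)},
    region_mask=toy_test_mask,
)
print(toy_results)
```

With real data you would call something like compare_models(adata.X, {"biccn_model": adata.layers["biccn_model"]}, region_mask=(adata.var["split"] == "test").to_numpy()); if adata.X is stored as a sparse matrix, convert it with .toarray() first.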
If you don’t want to predict on the entire dataset, you can also predict on a given sequence or region using the predict_sequence() or predict_regions() methods.
Many of the plotting functions in the crested.pl module can be used to visualize these model predictions.
Example predictions on test set regions#
It is always interesting to see how the model performs on unseen test set regions. We recommend always looking at a few examples to spot potential biases or unexpected trends.
# Define a dataframe with test set regions
test_df = adata.var[adata.var["split"] == "test"]
test_df
| region | chr | start | end | split |
|---|---|---|---|---|
| chr18:3269690-3271804 | chr18 | 3269690 | 3271804 | test |
| chr18:3350307-3352421 | chr18 | 3350307 | 3352421 | test |
| chr18:3451398-3453512 | chr18 | 3451398 | 3453512 | test |
| chr18:3463977-3466091 | chr18 | 3463977 | 3466091 | test |
| chr18:3488308-3490422 | chr18 | 3488308 | 3490422 | test |
| ... | ... | ... | ... | ... |
| chr9:124125533-124127647 | chr9 | 124125533 | 124127647 | test |
| chr9:124140961-124143075 | chr9 | 124140961 | 124143075 | test |
| chr9:124142793-124144907 | chr9 | 124142793 | 124144907 | test |
| chr9:124477280-124479394 | chr9 | 124477280 | 124479394 | test |
| chr9:124479548-124481662 | chr9 | 124479548 | 124481662 | test |

8198 rows × 4 columns
# plot predictions vs ground truth for an example region in the test set, selected by index
idx = 21
region = test_df.index[idx]
print(region)
crested.pl.bar.region_predictions(adata, region, title="Predictions vs Ground Truth")
chr18:3892771-3894885
2024-12-12T10:21:16.459431+0100 INFO Plotting bar plots for region: chr18:3892771-3894885, models: ['biccn_model']
Example predictions on manually defined regions#
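# 500 bp region, padded by 807 bp on each side to match the 2114 bp model input
# (alternative example: chr18:61107770-61109884, the FIRE enhancer)
chrom = "chr3"
start = 72535878 - 807
end = 72536378 + 807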
sequence = genome.fetch(chrom, start, end).upper()
prediction = evaluator.predict_sequence(sequence)
1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 863ms/step
crested.pl.bar.prediction(prediction, classes=list(adata.obs_names))
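The padding above (807 bp on each side of a 500 bp region) generalizes to a small helper that centers any region on the model input length. This is a hypothetical helper for illustration. It does not check chromosome ends; for that you would also need the chromosome sizes.

```python
def center_region(chrom: str, start: int, end: int, target_len: int = 2114):
    """Resize a region to target_len bp around its midpoint (chromosome ends not checked)."""
    midpoint = (start + end) // 2
    new_start = midpoint - target_len // 2
    new_end = new_start + target_len  # always exactly target_len bp
    if new_start < 0:
        raise ValueError("Resized region would start before position 0")
    return chrom, new_start, new_end


# same coordinates as the padded chr3 region in the cell above
toy_region = center_region("chr3", 72535878, 72536378)
print(toy_region)  # -> ('chr3', 72535071, 72537185)
```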
Prediction on gene locus#
We can also score a gene locus by sliding a window over a predefined genomic range. We can then compare these predictions to the bigwig track of the class we predicted for (here Sst), to see whether the measured profile matches the CREsted predictions.
chrom = "chr4"
start = 91209533
end = 91374781
scores, coordinates, min_loc, max_loc, tss_position = evaluator.score_gene_locus(
chr_name=chrom,
gene_start=start,
gene_end=end,
class_name="Sst",
strand="-",
upstream=50000,
downstream=10000,
step_size=100,
)
bigwig = "../../../../mouse/biccn/bigwigs/bws/Sst.bw"
bw_values, midpoints = crested.utils.extract_bigwig_values_per_bp(bigwig, coordinates)
2024-12-12T10:21:22.649225+0100 WARNING extract_bigwig_values_per_bp() is deprecated. Please use crested.utils.read_bigwig_region(bw_file, (chr, start, end)) instead.
%matplotlib inline
crested.pl.hist.locus_scoring(
scores,
(min_loc, max_loc),
gene_start=start,
gene_end=end,
title="CREsted prediction around Elavl2 gene locus for Sst",
bigwig_values=bw_values,
bigwig_midpoints=midpoints,
)
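Conceptually, locus scoring slides a model-sized window along the range and records one prediction per window. The sketch below is a simplified illustration under stated assumptions, not score_gene_locus itself: predict_fn is a hypothetical callback standing in for "fetch the sequence and predict the class of interest", each score is assigned to its window midpoint, and strand is ignored (for strand="-", as passed above, "upstream" lies at higher genomic coordinates).

```python
import numpy as np


def sliding_window_scores(predict_fn, region_start, region_end, window_len=2114, step_size=100):
    """Score every full window of window_len bp inside [region_start, region_end).

    predict_fn(start, end) -> float is a hypothetical stand-in for the model.
    Returns window starts, window midpoints and one score per window.
    """
    # the last start still leaves room for a full window inside the range
    window_starts = np.arange(region_start, region_end - window_len + 1, step_size)
    midpoints = window_starts + window_len // 2
    scores = np.array([predict_fn(s, s + window_len) for s in window_starts])
    return window_starts, midpoints, scores


# toy locus: gene at 10,000-20,000 with 5 kb upstream and 2 kb downstream (strand ignored)
toy_gene_start, toy_gene_end = 10_000, 20_000
toy_starts, toy_mids, toy_locus_scores = sliding_window_scores(
    lambda s, e: float(s % 1000),  # dummy predictor
    toy_gene_start - 5_000,
    toy_gene_end + 2_000,
)
print(len(toy_starts), toy_mids[0])
```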
Model performance on the entire test set#
After inspecting individual examples, we can now look at model performance on a larger scale.
First, for each cell type (class), we can check the correlation between the predictions and the peak heights across the peaks in the test set.
adata.layers
Layers with keys: biccn_model
classn = "L2_3IT"
crested.pl.scatter.class_density(
adata,
class_name=classn,
model_names=["biccn_model"],
split="test",
log_transform=True,
width=5,
height=5,
)
2024-12-12T10:21:23.275696+0100 INFO Plotting density scatter for class: L2_3IT, models: ['biccn_model'], split: test
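The density plot summarizes one number of interest: the correlation between predicted and measured peak heights of a single class over the test regions. Below is a minimal sketch of that computation, assuming dense (n_classes, n_regions) matrices as in this tutorial (adata.X and adata.layers["biccn_model"], with row order list(adata.obs_names)) and assuming that log_transform corresponds to log1p; CREsted's exact transform may differ. The helper class_correlation is hypothetical.

```python
import numpy as np


def class_correlation(truth, pred, class_names, class_name, region_mask=None, log_transform=True):
    """Pearson r between measured and predicted peak heights for one class."""
    row = list(class_names).index(class_name)
    t = np.asarray(truth[row], dtype=float)
    p = np.asarray(pred[row], dtype=float)
    if region_mask is not None:
        t, p = t[region_mask], p[region_mask]
    if log_transform:
        # assumption: log1p, which keeps zero peak heights finite (values must be > -1)
        t, p = np.log1p(t), np.log1p(p)
    return float(np.corrcoef(t, p)[0, 1])


toy_class_names = ["Astro", "L2_3IT", "Pvalb"]
toy_rng = np.random.default_rng(1)
toy_truth = toy_rng.random((3, 100)) * 10
toy_pred = np.abs(toy_truth + toy_rng.normal(0, 1, toy_truth.shape))  # noisy, non-negative
toy_r = class_correlation(toy_truth, toy_pred, toy_class_names, "L2_3IT", region_mask=np.arange(100) < 20)
print(toy_r)
```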
To check the correlations between all classes at once, we can plot a heatmap of the correlations between ground truth and predictions.
crested.pl.heatmap.correlations_predictions(
adata,
split="test",
title="Correlations between Groundtruths and Predictions",
x_label_rotation=90,
width=5,
height=5,
log_transform=True,
vmax=1,
vmin=-0.15,
)
2024-12-12T10:21:23.381532+0100 INFO Plotting heatmap correlations for split: test, models: ['biccn_model']
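The heatmap entries can be read as a cross-correlation matrix between ground-truth classes and predicted classes. One way to compute such a matrix is sketched below with np.corrcoef, assuming dense (n_classes, n_regions) matrices and a log1p transform. The helper name is hypothetical, and the row/column orientation of CREsted's plot may differ from this sketch.

```python
import numpy as np


def prediction_correlation_matrix(truth, pred, log_transform=True):
    """C[i, j] = Pearson r between the ground truth of class i and the prediction of class j."""
    truth = np.asarray(truth, dtype=float)
    pred = np.asarray(pred, dtype=float)
    if log_transform:
        truth, pred = np.log1p(truth), np.log1p(pred)
    n_classes = truth.shape[0]
    # np.corrcoef treats each row as a variable; keep only the truth-vs-prediction block
    full = np.corrcoef(np.vstack([truth, pred]))
    return full[:n_classes, n_classes:]


toy_truth = np.random.default_rng(3).random((4, 30))
toy_corr = prediction_correlation_matrix(toy_truth, toy_truth)  # perfect predictions
print(np.round(np.diag(toy_corr), 3))
```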
It is also recommended to compare this heatmap to the self-correlation heatmap of the peaks themselves. If peak heights are correlated between cell types, then predictions for one class are also expected to correlate highly with the ground truth of correlated, non-matching classes.
crested.pl.heatmap.correlations_self(
adata,
title="Self Correlation Heatmap",
x_label_rotation=90,
width=5,
height=5,
vmax=1,
vmin=-0.15,
)
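To interpret the off-diagonal values in the prediction heatmap, it helps to know which classes have correlated peak heights in the first place. The hypothetical helper below ranks class pairs by the self-correlation of the ground truth (log1p-transformed, which is an assumption). Pairs at the top of this list are the ones for which high off-diagonal prediction correlations are expected.

```python
import numpy as np


def most_correlated_pairs(truth, class_names, top_k=5, log_transform=True):
    """Class pairs ranked by Pearson r of their ground-truth peak heights."""
    values = np.asarray(truth, dtype=float)
    if log_transform:
        values = np.log1p(values)  # assumption: log1p transform
    corr = np.corrcoef(values)  # rows are classes
    rows, cols = np.triu_indices(len(class_names), k=1)  # each unordered pair once
    order = np.argsort(corr[rows, cols])[::-1][:top_k]
    return [(class_names[rows[i]], class_names[cols[i]], float(corr[rows[i], cols[i]])) for i in order]


toy_rng = np.random.default_rng(2)
toy_shared = toy_rng.random(40)
toy_truth = np.vstack([toy_shared, toy_shared, toy_rng.random(40)])  # classes A and B identical
toy_pairs = most_correlated_pairs(toy_truth, ["A", "B", "C"], top_k=1)
print(toy_pairs)
```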
Sequence contribution scores#
We can calculate the contribution scores for a sequence of interest using the calculate_contribution_scores_sequence() method.
You always need to ensure that the sequence or region you provide has the same length as the model input (2114 bp in our case).
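The contribution score methods also return one_hot_encoded_sequences. For intuition, here is a minimal one-hot encoder plus a length check against the 2114 bp input. The A, C, G, T column order is an assumption (a common convention) and may not match CREsted's internal encoding; ambiguous bases such as N are encoded as all zeros in this sketch.

```python
import numpy as np

MODEL_INPUT_LEN = 2114  # model input length in this tutorial
BASE_TO_INDEX = {"A": 0, "C": 1, "G": 2, "T": 3}  # assumed column order


def one_hot_encode(sequence: str) -> np.ndarray:
    """Encode a DNA string as a (length, 4) float array; unknown bases stay all-zero."""
    encoded = np.zeros((len(sequence), 4), dtype=np.float32)
    for position, base in enumerate(sequence.upper()):
        index = BASE_TO_INDEX.get(base)
        if index is not None:
            encoded[position, index] = 1.0
    return encoded


def check_input_length(sequence: str, input_len: int = MODEL_INPUT_LEN) -> None:
    """Raise if the sequence does not match the model input length."""
    if len(sequence) != input_len:
        raise ValueError(f"Expected {input_len} bp, got {len(sequence)} bp")


toy_sequence = "A" * 2114
check_input_length(toy_sequence)
toy_encoded = one_hot_encode(toy_sequence)
print(toy_encoded.shape)
```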
Contribution scores on manually defined sequences#
# dummy sequence (2114 x "A") matching the model input length, as an example
sequence = "A" * 2114
scores, one_hot_encoded_sequences = evaluator.calculate_contribution_scores_sequence(
sequence, class_names=["Astro", "Endo"]
) # focus on two cell types of interest
2024-12-12T10:21:24.012136+0100 INFO Calculating contribution scores for 2 class(es) and 1 region(s).
Contribution scores on manually defined genomic regions#
Alternatively, you can calculate contribution scores for regions of interest using the calculate_contribution_scores_regions() method.
These regions don’t have to be in your original dataset, as long as they exist in the genome file that you provided to the AnnDataModule and they have the same length as the model input.
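Region identifiers in this tutorial use the chrom:start-end format (for example chr18:61107770-61109884). Below is a hypothetical helper that parses such strings and flags regions whose length does not match the model input, which can be useful before calling calculate_contribution_scores_regions().

```python
import re

MODEL_INPUT_LEN = 2114  # model input length in this tutorial
REGION_PATTERN = re.compile(r"^(?P<chrom>[^:]+):(?P<start>\d+)-(?P<end>\d+)$")


def parse_region(region: str) -> tuple[str, int, int]:
    """Split 'chrom:start-end' into (chrom, start, end)."""
    match = REGION_PATTERN.match(region)
    if match is None:
        raise ValueError(f"Not a chrom:start-end region: {region!r}")
    chrom, start, end = match["chrom"], int(match["start"]), int(match["end"])
    if end <= start:
        raise ValueError(f"End must be greater than start: {region!r}")
    return chrom, start, end


def wrong_length_regions(regions: list[str], input_len: int = MODEL_INPUT_LEN) -> list[str]:
    """Return the regions whose length differs from the model input length."""
    bad = []
    for region in regions:
        _, start, end = parse_region(region)
        if end - start != input_len:
            bad.append(region)
    return bad


print(parse_region("chr18:61107770-61109884"))
print(wrong_length_regions(["chr18:61107770-61109884", "chr3:72535878-72536378"]))
```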
regions_of_interest = [
    "chr18:61107770-61109884"
] # FIRE enhancer region, should only have motifs in Micro_PVM
# focus on three cell types of interest
classes_of_interest = ["Astro", "Micro_PVM", "L5ET"]
scores, one_hot_encoded_sequences = evaluator.calculate_contribution_scores_regions(
region_idx=regions_of_interest, class_names=classes_of_interest
)
2024-12-12T10:21:28.374958+0100 INFO Calculating contribution scores for 3 class(es) and 1 region(s).
Contribution scores for regions can be plotted using the crested.pl.patterns.contribution_scores() function.
This will generate a plot per class per region.
%matplotlib inline
crested.pl.patterns.contribution_scores(
scores,
one_hot_encoded_sequences,
sequence_labels=regions_of_interest,
class_labels=classes_of_interest,
zoom_n_bases=500,
title="FIRE Enhancer Region",
) # zoom in on the center 500bp
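Setting zoom_n_bases=500 restricts the plot to the central 500 bp. The sketch below shows the equivalent center crop on a score array. It assumes the position axis is the second-to-last one, as in a (..., sequence_length, 4) layout, which may not match the exact shape that CREsted returns; adjust axis accordingly.

```python
import numpy as np


def center_crop(values, n_bases: int, axis: int = -2) -> np.ndarray:
    """Keep the central n_bases positions of `values` along `axis`."""
    values = np.asarray(values)
    length = values.shape[axis]
    if n_bases > length:
        raise ValueError(f"Cannot keep {n_bases} of {length} positions")
    start = (length - n_bases) // 2
    return np.take(values, np.arange(start, start + n_bases), axis=axis)


# toy scores: 1 region x 2 classes x 2114 positions x 4 bases
toy_scores = np.arange(1 * 2 * 2114 * 4).reshape(1, 2, 2114, 4)
toy_zoomed = center_crop(toy_scores, 500)
print(toy_zoomed.shape)
```

For a 2114 bp input, the central 500 bp start at position 807, the same flank size used to pad regions earlier in this tutorial.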
Contribution scores on random test set regions#
# calculate contribution scores for an example region in the test set, selected by index
idx = 21
region = test_df.index[idx]
classes_of_interest = ["L2_3IT", "Pvalb"]
scores, one_hot_encoded_sequences = evaluator.calculate_contribution_scores_regions(
region_idx=region, class_names=classes_of_interest
)
2024-12-12T10:21:39.277930+0100 INFO Calculating contribution scores for 2 class(es) and 1 region(s).
%matplotlib inline
crested.pl.patterns.contribution_scores(
scores,
one_hot_encoded_sequences,
sequence_labels=[region],
class_labels=classes_of_interest,
zoom_n_bases=500,
title="Test set region",
) # zoom in on the center 500bp
Enhancer design [more updates soon]#
Load data and model#
adata = ad.read_h5ad("mouse_cortex_filtered.h5ad")
datamodule = crested.tl.data.AnnDataModule(
adata,
)
# load an existing model
evaluator = crested.tl.Crested(data=datamodule)
evaluator.load_model(
"mouse_biccn/finetuned_model/checkpoints/02.keras", # Load your model
compile=True,
)
Sequence evolution#
We can create synthetic enhancers for a specified class using in silico evolution with the enhancer_design_in_silico_evolution() method.
designed_sequences = evaluator.enhancer_design_in_silico_evolution(
target_class="L5ET", n_sequences=5, n_mutations=10, target_len=500
)
2024-12-12T10:21:46.923792+0100 INFO Loading sequences into memory...
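In silico evolution can be thought of as a greedy loop: at each step, score every possible single-base substitution and keep the best one. The toy sketch below illustrates that idea. score_fn is a hypothetical stand-in for the model's prediction for the target class. This is not CREsted's implementation, which scores candidates with the trained model and may differ in details. The flanks argument mirrors the no_mutation_flanks option used later in this tutorial.

```python
import numpy as np

BASES = "ACGT"


def single_mutants(sequence: str, flanks: tuple[int, int] = (0, 0)) -> list[str]:
    """All sequences differing from `sequence` by one substitution outside the flanks."""
    left, right = flanks
    mutants = []
    for position in range(left, len(sequence) - right):
        for base in BASES:
            if base != sequence[position]:
                mutants.append(sequence[:position] + base + sequence[position + 1 :])
    return mutants


def greedy_evolution(sequence, score_fn, n_mutations, flanks=(0, 0)):
    """Apply n_mutations rounds of 'keep the best-scoring single substitution'."""
    for _ in range(n_mutations):
        mutants = single_mutants(sequence, flanks)
        if not mutants:
            raise ValueError("The flanks leave no position to mutate")
        mutant_scores = np.array([score_fn(m) for m in mutants])
        # ties are broken by taking the first best candidate
        sequence = mutants[int(np.argmax(mutant_scores))]
    return sequence


# toy objective standing in for the model: the number of G bases
toy_evolved = greedy_evolution("A" * 10, lambda s: s.count("G"), n_mutations=3, flanks=(2, 2))
print(toy_evolved)  # -> AAGGGAAAAA
```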
prediction = evaluator.predict_sequence(designed_sequences[1])
crested.pl.bar.prediction(prediction, classes=list(adata.obs_names))
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 267ms/step
scores, one_hot_encoded_sequences = evaluator.calculate_contribution_scores_sequence(
designed_sequences[0], class_names=["L5ET"]
) # focus on the target cell type
2024-12-12T10:22:52.478705+0100 INFO Calculating contribution scores for 1 class(es) and 1 region(s).
crested.pl.patterns.contribution_scores(
scores,
one_hot_encoded_sequences,
sequence_labels="",
class_labels=["L5ET"],
zoom_n_bases=500,
title="synth L5ET",
) # zoom in on the center 500bp
Motif embedding#
designed_sequences = evaluator.enhancer_design_motif_implementation(
patterns={
"SOX10": "AACAATGGCCCCATTGT",
"CREB5": "ATGACATCA",
},
target_class="Oligo",
n_sequences=5,
target_len=500,
)
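Motif implementation instead inserts known motifs into a sequence. The sketch below only illustrates the insertion step, placing each motif at a random non-overlapping position in a random background. The helper implant_motifs is hypothetical; CREsted's method may choose positions differently (for example, guided by the model's predictions), so do not read this as its algorithm.

```python
import random


def implant_motifs(background: str, motifs: dict[str, str], seed=None) -> str:
    """Place each motif once at a random position, without overlapping earlier motifs."""
    rng = random.Random(seed)
    sequence = list(background)
    occupied = [False] * len(sequence)
    for name, motif in motifs.items():
        candidates = [
            p
            for p in range(len(sequence) - len(motif) + 1)
            if not any(occupied[p : p + len(motif)])
        ]
        if not candidates:
            raise ValueError(f"No free position left for motif {name}")
        position = rng.choice(candidates)
        sequence[position : position + len(motif)] = list(motif)
        occupied[position : position + len(motif)] = [True] * len(motif)
    return "".join(sequence)


toy_rng = random.Random(0)
toy_background = "".join(toy_rng.choice("ACGT") for _ in range(500))
toy_enhancer = implant_motifs(
    toy_background,
    {"SOX10": "AACAATGGCCCCATTGT", "CREB5": "ATGACATCA"},  # motifs from the cell above
    seed=0,
)
print(len(toy_enhancer))
```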
prediction = evaluator.predict_sequence(designed_sequences[0])
crested.pl.bar.prediction(prediction, classes=list(adata.obs_names))
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step
scores, one_hot_encoded_sequences = evaluator.calculate_contribution_scores_sequence(
designed_sequences[0], class_names=["Oligo"]
) # focus on the target cell type
2024-12-12T10:23:01.302410+0100 INFO Calculating contribution scores for 1 class(es) and 1 region(s).
%matplotlib inline
crested.pl.patterns.contribution_scores(
scores,
one_hot_encoded_sequences,
sequence_labels="",
class_labels=["Oligo"],
zoom_n_bases=500,
title="synth Oligo",
)
L2 optimizer#
def L2_distance(
    mutated_predictions: np.ndarray,
    original_prediction: np.ndarray,
    target: np.ndarray,
    classes_of_interest: list[int],
):
    """Return the index of the mutation that moves the prediction closest to the target."""
    # make the original prediction 2D: (1, n_classes)
    if len(original_prediction.shape) == 1:
        original_prediction = original_prediction[None]
    # distance of every mutated sequence to the target (classes of interest only)
    L2_sat_mut = pairwise.euclidean_distances(
        mutated_predictions[:, classes_of_interest],
        target[classes_of_interest].reshape(1, -1),
    )
    # distance of the unmutated sequence to the target
    L2_baseline = pairwise.euclidean_distances(
        original_prediction[:, classes_of_interest],
        target[classes_of_interest].reshape(1, -1),
    )
    # pick the mutation with the largest reduction in distance
    return np.argmax((L2_baseline - L2_sat_mut).squeeze())
L2_optimizer = crested.utils.EnhancerOptimizer(optimize_func=L2_distance)
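Because L2_baseline is the same for every candidate mutation, maximizing L2_baseline - L2_sat_mut in L2_distance is equivalent to picking the candidate whose predictions are closest to the target. A NumPy-only restatement with toy numbers makes this concrete; the helper closest_to_target and the toy arrays are illustrative, not part of CREsted.

```python
import numpy as np


def closest_to_target(mutated_predictions, target, classes_of_interest):
    """Index of the candidate whose predictions (classes of interest only) are nearest the target."""
    diffs = mutated_predictions[:, classes_of_interest] - target[classes_of_interest]
    return int(np.argmin(np.linalg.norm(diffs, axis=1)))


# toy example: 3 classes, target of 20 for class 0, only classes 0 and 1 are considered
toy_target = np.array([20.0, 0.0, 0.0])
toy_classes = [0, 1]
toy_mutated = np.array(
    [
        [5.0, 5.0, 9.0],    # distance sqrt(15**2 + 5**2), about 15.8
        [15.0, 1.0, 0.0],   # distance sqrt(5**2 + 1**2), about 5.1
        [20.0, 10.0, 0.0],  # distance 10.0
    ]
)
toy_original = np.array([0.0, 0.0, 0.0])  # baseline distance 20.0
toy_best = closest_to_target(toy_mutated, toy_target, toy_classes)
# the "largest reduction from the baseline" rule picks the same candidate
toy_baseline = np.linalg.norm(toy_original[toy_classes] - toy_target[toy_classes])
toy_reductions = toy_baseline - np.linalg.norm(
    toy_mutated[:, toy_classes] - toy_target[toy_classes], axis=1
)
print(toy_best, int(np.argmax(toy_reductions)))  # -> 1 1
```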
target_cell_type = "L2_3IT"
classes_of_interest = [
i
for i, ct in enumerate(adata.obs_names)
if ct in ["L2_3IT", "L5ET", "L5IT", "L6IT"]
]
target = np.array([20 if x == target_cell_type else 0 for x in adata.obs_names])
intermediate, designed_sequences = evaluator.enhancer_design_in_silico_evolution(
target_class=None,
n_sequences=5,
n_mutations=30,
enhancer_optimizer=L2_optimizer,
target=target,
return_intermediate=True,
no_mutation_flanks=(807, 807),
classes_of_interest=classes_of_interest,
)
idx = 0
prediction = evaluator.predict_sequence(designed_sequences[idx])
crested.pl.bar.prediction(prediction, classes=list(adata.obs_names))
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 50ms/step
scores, one_hot_encoded_sequences = evaluator.calculate_contribution_scores_sequence(
designed_sequences[idx], class_names=["L2_3IT", "L5ET", "L5IT", "L6IT"]
) # focus on the four classes of interest
2024-12-12T10:24:07.129049+0100 INFO Calculating contribution scores for 4 class(es) and 1 region(s).
crested.pl.patterns.contribution_scores(
scores,
one_hot_encoded_sequences,
sequence_labels="",
class_labels=["L2_3IT", "L5ET", "L5IT", "L6IT"],
zoom_n_bases=500,
title="",
)