Introduction to CREsted with Peak Regression#

In this introductory notebook, we will train a peak regression model on the mouse BICCN data and inspect the results to get a feel for the capabilities of the CREsted package.

Import Data#

For this tutorial, we will use the mouse BICCN dataset, which is available via the get_dataset() function.
To train a CREsted peak regression model on your data, you need:

  1. A consensus regions BED file containing all the regions of interest across cell types.

  2. A folder containing the bigwig files per cell type. Each file should be named according to the cell type: {cell type name}.bw.

  3. A genome fasta file and optionally a chromosome sizes file.

You could use a tool like SnapATAC2 to generate the consensus regions and bigwig files from your own data.
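Before importing, it can help to sanity-check that your files follow this layout. Below is a minimal sketch using only the standard library; the paths are placeholders for your own data.

from pathlib import Path

bw_dir = Path("path/to/bigwigs")  # placeholder: your bigwig folder
bed_path = Path("path/to/consensus_peaks.bed")  # placeholder: your regions file

# each bigwig should be named {cell type name}.bw
cell_types = sorted(p.stem for p in bw_dir.glob("*.bw"))
print(f"Found {len(cell_types)} cell types:", cell_types[:5])

# the regions file should be a BED file with at least chrom, start, end columns
with open(bed_path) as f:
    chrom, start, end = f.readline().split()[:3]
print(f"First region: {chrom}:{start}-{end}")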

import anndata as ad
import crested
import numpy as np
import matplotlib

# Set the font type to ensure text is saved as whole words
matplotlib.rcParams["pdf.fonttype"] = 42  # Use TrueType fonts instead of Type 3 fonts
matplotlib.rcParams["ps.fonttype"] = 42  # For PostScript as well, if needed

Download the tutorial data.
For this tutorial, we will train on the ‘cut sites’ data, but ‘coverage’ data are also available (an older version of this tutorial trained on the coverage). Data are originally from Zemke et al., Nature, 2023.

bigwigs_folder, regions_file = crested.get_dataset("mouse_cortex_bigwig_cut_sites")

Or, if you have the data already available, you can specify the paths directly (here we created relative symlinks to our data).

bigwigs_folder = "data/mouse_biccn/bigwigs_cut_sites.tar.gz.untar"
regions_file = "data/mouse_biccn/consensus_peaks_biccn.bed"

By loading our genome in the crested.Genome class and setting it with register_genome(), the genome is automatically used in all functions throughout CREsted. If you don’t provide the chromosome sizes, they will be automatically calculated from the fasta.

Note

Any function or class that expects a genome object can still accept a genome object as explicit input even if one was already registered. In that case, the input will be used instead of the registered genome.

# Set the genome
# genome = crested.Genome(
#     "data/genomes/mm10/mm10.fa", "data/genomes/mm10/mm10.chrom.sizes"
# )

genome = crested.Genome(
    "../../../../mouse/biccn/mm10.fa", "../../../../mouse/biccn/mm10.chrom.sizes"
)
crested.register_genome(
    genome
)  # Register the genome so that it can be used by the package

print(genome.fetch("chr1", 10000000, 10000010))
2025-03-21T12:41:04.859727+0100 INFO Genome mm10 registered.
TTTTCAATGC

We can use the import_bigwigs() function to import bigwigs per cell type and a consensus regions BED file into an anndata.AnnData object, with the imported cell types as the AnnData.obs and the consensus peak regions as the AnnData.var.

adata = crested.import_bigwigs(
    bigwigs_folder=bigwigs_folder,
    regions_file=regions_file,
    target_region_width=1000,  # optionally, use a different width than the consensus regions file (500bp) for the .X values calculation
    target="count",  # or "max", "mean", "logcount" --> what we will be predicting
)
adata
2025-03-21T12:41:14.984164+0100 INFO Extracting values from 19 bigWig files...
AnnData object with n_obs × n_vars = 19 × 546993
    obs: 'file_path'
    var: 'chr', 'start', 'end'

To train a model, we always need to add a split column to our dataset, which we can do using crested.pp.train_val_test_split().
This will add a column to the AnnData.var with the split type for each region (train, val, or test).

# Choose the chromosomes for the validation and test sets
crested.pp.train_val_test_split(
    adata, strategy="chr", val_chroms=["chr8", "chr10"], test_chroms=["chr9", "chr18"]
)

# Alternatively, we can split randomly on the regions
# crested.pp.train_val_test_split(
#     adata, strategy="region", val_size=0.1, test_size=0.1, random_state=42
# )

print(adata.var["split"].value_counts())
adata.var.head(3)
split
train    440993
val       56064
test      49936
Name: count, dtype: int64
chr start end split
region
chr1:3094805-3095305 chr1 3094805 3095305 train
chr1:3095470-3095970 chr1 3095470 3095970 train
chr1:3112174-3112674 chr1 3112174 3112674 train

Preprocessing#

Region Width#

For this example we’re interested in training on wider regions than our consensus regions file (500bp) to also include some sequence information from the tails of our peaks.

We change it to 2114 bp regions, since that is the input width the ChromBPNet architecture was originally trained on and that’s what we’ll be using. This is not fixed and can be adapted to your preference, as long as it is compatible with the model architecture.

Wider regions include sequence information beyond the center of the peaks, which can effectively increase your dataset size if the peak tails carry meaningful signal, but which can also introduce noise if the tails are not informative.
Wider regions also increase the computational cost of training the model.

crested.pp.change_regions_width(
    adata,
    2114,
)  # change the adata width of the regions to 2114bp

Peak Normalization#

Additionally, we can normalize our peak values based on the variability of the top peak heights per cell type using the crested.pp.normalize_peaks() function.

This function applies a normalization scalar to each cell type, obtained by comparing, per cell type, the distribution of peak heights across the maximally accessible regions that are not specific to any cell type.
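Conceptually, the weighting works roughly like the sketch below. This is a simplified illustration, not CREsted’s exact implementation: CREsted selects the comparison regions via Gini scores, whereas this stand-in simply takes the most accessible regions overall.

import numpy as np

X = adata.to_df().to_numpy()  # (n_cell_types, n_regions), densified for illustration
top_k = int(0.01 * X.shape[1])
# stand-in for "maximally accessible, non-specific" regions
top_regions = np.argsort(X.mean(axis=0))[-top_k:]
per_type_level = X[:, top_regions].mean(axis=1)  # typical top-peak height per cell type
weights = per_type_level.mean() / per_type_level  # normalization scalar per cell type
X_normalized = X * weights[:, None]  # up/down-weight each cell type's peak values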

crested.pp.normalize_peaks(
    adata, top_k_percent=0.03
)  # The top_k_percent parameter can be tuned based on potential bias towards cell types. If some weights overcompensate too much, consider increasing top_k_percent. Default is 0.01.
2025-03-21T12:42:16.417618+0100 INFO Filtering on top k Gini scores...
2025-03-21T12:42:19.969232+0100 INFO Added normalization weights to adata.obsm['weights']...
chr start end split
region
chr5:76656624-76658738 chr5 76656624 76658738 train
chr13:30900787-30902901 chr13 30900787 30902901 train
chr9:65586049-65588163 chr9 65586049 65588163 test
chr9:65586556-65588670 chr9 65586556 65588670 test
chr9:65587095-65589209 chr9 65587095 65589209 test
... ... ... ... ...
chr9:65459289-65461403 chr9 65459289 65461403 test
chr9:65459852-65461966 chr9 65459852 65461966 test
chr5:76587680-76589794 chr5 76587680 76589794 train
chr9:65523082-65525196 chr9 65523082 65525196 test
chr19:18337654-18339768 chr19 18337654 18339768 train

48308 rows × 4 columns

We can visualize the normalization factor for each cell type using the crested.pl.bar.normalization_weights() function to inspect which cell types’ peaks were up- or down-weighted.

%matplotlib inline
crested.pl.bar.normalization_weights(
    adata, title="Normalization Weights per Cell Type", x_label_rotation=90
)
../_images/7f4cf2858c36e760d033e8a06fa6a3a63dbe5cb54c025d95c61d4c390c2347a0.png

There is no single best way to preprocess your data, so we recommend experimenting with different preprocessing steps to see what works best for your data.
Likewise, there is no single best training approach, so we recommend experimenting with different training strategies.

# Save the final preprocessing results
adata.write_h5ad("data/mouse_cortex.h5ad")

Model Training#

The entire CREsted training workflow is built around the crested.tl.Crested() class. Everything that requires both a model and a dataloader (training, evaluation) is done through this class.
This class has a couple of required arguments:

  • data: the crested.tl.data.AnnDataModule object containing all the data (anndata, genome) and dataloaders that specify how to load the data.

  • model: the keras.Model object containing the model architecture.

  • config: the crested.tl.TaskConfig object containing the optimizer, loss function, and metrics to use in training.

Generally you wouldn’t run these steps in a notebook, but rather in a script or a Python file, so you can run it on a cluster or in the background.
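For reference, the steps covered one by one below could be condensed into a script along these lines (a sketch reusing the exact calls from this tutorial; the genome path is a placeholder):

# train.py -- sketch of the full training workflow shown step by step below
import anndata as ad
import crested

genome = crested.Genome("path/to/mm10.fa")  # placeholder path
crested.register_genome(genome)

adata = ad.read_h5ad("data/mouse_cortex.h5ad")
datamodule = crested.tl.data.AnnDataModule(adata, batch_size=256)
model_architecture = crested.tl.zoo.dilated_cnn(
    seq_len=2114, num_classes=len(adata.obs_names)
)
config = crested.tl.default_configs("peak_regression")

trainer = crested.tl.Crested(
    data=datamodule,
    model=model_architecture,
    config=config,
    project_name="mouse_biccn",
    run_name="basemodel",
)
trainer.fit(epochs=60)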

Data#

We’ll start by initializing the crested.tl.data.AnnDataModule object with our data.
This will tell our model how to load the data and what data to load during fitting/evaluation. The main arguments to supply are the adata object, the genome object (if you didn’t register one), and the batch_size.
Other optional arguments are related to the training data loading (e.g. shuffling, whether to load the sequences into memory, …).

# read in your preprocessed data
adata = ad.read_h5ad("data/mouse_cortex.h5ad")
datamodule = crested.tl.data.AnnDataModule(
    adata,
    batch_size=256,  # lower this if you encounter OOM errors
    max_stochastic_shift=3,  # optional data augmentation to slightly reduce overfitting
    always_reverse_complement=True,  # default True. Will double the effective size of the training dataset.
)

Model definition#

Next, we’ll define the model architecture. This is a standard Keras model definition, so you can provide your own model definition if you like.
Alternatively, there are a couple of ready-to-use models available in the crested.tl.zoo module.
Each of them requires the width of the input sequences and the number of output classes (your AnnData.obs) as arguments.

# Load a ChromBPNet-like architecture for a dataset with 2114bp regions and 19 cell types
model_architecture = crested.tl.zoo.dilated_cnn(
    seq_len=2114, num_classes=len(list(adata.obs_names))
)

TaskConfig#

The TaskConfig object specifies the optimizer, loss function, and metrics to use in training (we call this our ‘task’).
Some default configurations are available for some common tasks such as ‘topic_classification’ and ‘peak_regression’, which you can load using the crested.tl.default_configs() function.

# Load the default configuration for training a peak regression model
config = crested.tl.default_configs(
    "peak_regression"
)  # or "topic_classification" for topic classification
print(config)
TaskConfig(optimizer=<keras.src.optimizers.adam.Adam object at 0x7fc1c6ed43d0>, loss=<crested.tl.losses._cosinemse_log.CosineMSELogLoss object at 0x7fc1c5844f50>, metrics=[<MeanAbsoluteError name=mean_absolute_error>, <MeanSquaredError name=mean_squared_error>, <CosineSimilarity name=cosine_similarity>, <PearsonCorrelation name=pearson_correlation>, <ConcordanceCorrelationCoefficient name=concordance_correlation_coefficient>, <PearsonCorrelationLog name=pearson_correlation_log>, <ZeroPenaltyMetric name=zero_penalty_metric>])

Alternatively, you can create your own TaskConfig object and specify the optimizer, loss function, and metrics yourself if you want to do something completely custom.

import keras

# Create your own configuration
# We recommend trying this weighted cosine MSE log loss for peak regression
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
loss = crested.tl.losses.CosineMSELogLoss(max_weight=100, multiplier=1)
metrics = [
    keras.metrics.MeanAbsoluteError(),
    keras.metrics.MeanSquaredError(),
    keras.metrics.CosineSimilarity(axis=1),
    crested.tl.metrics.PearsonCorrelation(),
    crested.tl.metrics.ConcordanceCorrelationCoefficient(),
    crested.tl.metrics.PearsonCorrelationLog(),
    crested.tl.metrics.ZeroPenaltyMetric(),
]

alternative_config = crested.tl.TaskConfig(optimizer, loss, metrics)
print(alternative_config)
TaskConfig(optimizer=<keras.src.optimizers.adam.Adam object at 0x7fc1c82cec90>, loss=<crested.tl.losses._cosinemse_log.CosineMSELogLoss object at 0x7fc1c5854f50>, metrics=[<MeanAbsoluteError name=mean_absolute_error>, <MeanSquaredError name=mean_squared_error>, <CosineSimilarity name=cosine_similarity>, <PearsonCorrelation name=pearson_correlation>, <ConcordanceCorrelationCoefficient name=concordance_correlation_coefficient>, <PearsonCorrelationLog name=pearson_correlation_log>, <ZeroPenaltyMetric name=zero_penalty_metric>])

Training#

Now we’re ready to train our model. We’ll create a Crested object with the data, model, and config objects we just created.
Then, we can call the fit() method to train the model.
Read the documentation for more information on all available arguments to customize your training (e.g. augmentations, early stopping, checkpointing, …).

By default:

  1. The model trains until the validation loss stops decreasing for 10 epochs, with a maximum of 100 epochs.

  2. The best model so far is saved whenever the validation loss improves.

  3. The learning rate is reduced by a factor of 0.25 if the validation loss stops decreasing for 5 epochs.

Note

If you specify the same project_name and run_name as a previous run, then CREsted will assume that you want to continue training and will load the last available model checkpoint from the {project_name}/{run_name} folder and continue from that epoch.

# setup the trainer
trainer = crested.tl.Crested(
    data=datamodule,
    model=model_architecture,
    config=alternative_config,
    project_name="mouse_biccn",  # change to your liking
    run_name="basemodel",  # change to your liking
    logger="wandb",  # or None, 'dvc', 'tensorboard'
    seed=7,  # For reproducibility
)
# train the model
trainer.fit(
    epochs=60,
    learning_rate_reduce_patience=3,
    early_stopping_patience=6,
)
2025-01-30T10:27:00.875863+0100 WARNING Model does not have an optimizer. Please compile the model before training.
None
2025-01-30T10:27:01.153422+0100 INFO Loading sequences into memory...
2025-01-30T10:27:08.883584+0100 INFO Loading sequences into memory...
Epoch 1/60
Model: "functional"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)         Output Shape          Param #  Connected to      ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ sequence            │ (None, 2114, 4)   │          0 │ -                 │
│ (InputLayer)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ conv1d (Conv1D)     │ (None, 2114, 512) │     10,240 │ sequence[0][0]    │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ batch_normalization │ (None, 2114, 512) │      2,048 │ conv1d[0][0]      │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ activation          │ (None, 2114, 512) │          0 │ batch_normalizat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dropout (Dropout)   │ (None, 2114, 512) │          0 │ activation[0][0]  │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_1conv         │ (None, 2110, 512) │    786,432 │ dropout[0][0]     │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_1bn           │ (None, 2110, 512) │      2,048 │ bpnet_1conv[0][0] │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_1activation   │ (None, 2110, 512) │          0 │ bpnet_1bn[0][0]   │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_1crop         │ (None, 2110, 512) │          0 │ dropout[0][0]     │
│ (Cropping1D)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add (Add)           │ (None, 2110, 512) │          0 │ bpnet_1activatio… │
│                     │                   │            │ bpnet_1crop[0][0] │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_1dropout      │ (None, 2110, 512) │          0 │ add[0][0]         │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_2conv         │ (None, 2102, 512) │    786,432 │ bpnet_1dropout[0… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_2bn           │ (None, 2102, 512) │      2,048 │ bpnet_2conv[0][0] │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_2activation   │ (None, 2102, 512) │          0 │ bpnet_2bn[0][0]   │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_2crop         │ (None, 2102, 512) │          0 │ bpnet_1dropout[0… │
│ (Cropping1D)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_1 (Add)         │ (None, 2102, 512) │          0 │ bpnet_2activatio… │
│                     │                   │            │ bpnet_2crop[0][0] │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_2dropout      │ (None, 2102, 512) │          0 │ add_1[0][0]       │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_3conv         │ (None, 2086, 512) │    786,432 │ bpnet_2dropout[0… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_3bn           │ (None, 2086, 512) │      2,048 │ bpnet_3conv[0][0] │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_3activation   │ (None, 2086, 512) │          0 │ bpnet_3bn[0][0]   │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_3crop         │ (None, 2086, 512) │          0 │ bpnet_2dropout[0… │
│ (Cropping1D)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_2 (Add)         │ (None, 2086, 512) │          0 │ bpnet_3activatio… │
│                     │                   │            │ bpnet_3crop[0][0] │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_3dropout      │ (None, 2086, 512) │          0 │ add_2[0][0]       │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_4conv         │ (None, 2054, 512) │    786,432 │ bpnet_3dropout[0… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_4bn           │ (None, 2054, 512) │      2,048 │ bpnet_4conv[0][0] │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_4activation   │ (None, 2054, 512) │          0 │ bpnet_4bn[0][0]   │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_4crop         │ (None, 2054, 512) │          0 │ bpnet_3dropout[0… │
│ (Cropping1D)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_3 (Add)         │ (None, 2054, 512) │          0 │ bpnet_4activatio… │
│                     │                   │            │ bpnet_4crop[0][0] │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_4dropout      │ (None, 2054, 512) │          0 │ add_3[0][0]       │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_5conv         │ (None, 1990, 512) │    786,432 │ bpnet_4dropout[0… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_5bn           │ (None, 1990, 512) │      2,048 │ bpnet_5conv[0][0] │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_5activation   │ (None, 1990, 512) │          0 │ bpnet_5bn[0][0]   │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_5crop         │ (None, 1990, 512) │          0 │ bpnet_4dropout[0… │
│ (Cropping1D)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_4 (Add)         │ (None, 1990, 512) │          0 │ bpnet_5activatio… │
│                     │                   │            │ bpnet_5crop[0][0] │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_5dropout      │ (None, 1990, 512) │          0 │ add_4[0][0]       │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_6conv         │ (None, 1862, 512) │    786,432 │ bpnet_5dropout[0… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_6bn           │ (None, 1862, 512) │      2,048 │ bpnet_6conv[0][0] │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_6activation   │ (None, 1862, 512) │          0 │ bpnet_6bn[0][0]   │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_6crop         │ (None, 1862, 512) │          0 │ bpnet_5dropout[0… │
│ (Cropping1D)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_5 (Add)         │ (None, 1862, 512) │          0 │ bpnet_6activatio… │
│                     │                   │            │ bpnet_6crop[0][0] │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_6dropout      │ (None, 1862, 512) │          0 │ add_5[0][0]       │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_7conv         │ (None, 1606, 512) │    786,432 │ bpnet_6dropout[0… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_7bn           │ (None, 1606, 512) │      2,048 │ bpnet_7conv[0][0] │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_7activation   │ (None, 1606, 512) │          0 │ bpnet_7bn[0][0]   │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_7crop         │ (None, 1606, 512) │          0 │ bpnet_6dropout[0… │
│ (Cropping1D)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_6 (Add)         │ (None, 1606, 512) │          0 │ bpnet_7activatio… │
│                     │                   │            │ bpnet_7crop[0][0] │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_7dropout      │ (None, 1606, 512) │          0 │ add_6[0][0]       │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_8conv         │ (None, 1094, 512) │    786,432 │ bpnet_7dropout[0… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_8bn           │ (None, 1094, 512) │      2,048 │ bpnet_8conv[0][0] │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_8activation   │ (None, 1094, 512) │          0 │ bpnet_8bn[0][0]   │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_8crop         │ (None, 1094, 512) │          0 │ bpnet_7dropout[0… │
│ (Cropping1D)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_7 (Add)         │ (None, 1094, 512) │          0 │ bpnet_8activatio… │
│                     │                   │            │ bpnet_8crop[0][0] │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_8dropout      │ (None, 1094, 512) │          0 │ add_7[0][0]       │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ global_average_poo… │ (None, 512)       │          0 │ bpnet_8dropout[0… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_out (Dense)   │ (None, 19)        │      9,747 │ global_average_p… │
└─────────────────────┴───────────────────┴────────────┴───────────────────┘
 Total params: 6,329,875 (24.15 MB)
 Trainable params: 6,320,659 (24.11 MB)
 Non-trainable params: 9,216 (36.00 KB)

Finetuning on cell type-specific regions#

Subsetting the consensus peak set#

For peak regression models, we recommend continuing to train the model (trained on all consensus peaks) on a subset of cell type-specific regions. Since we are interested in understanding the enhancer code that uniquely identifies the cell types in the dataset, finetuning on specific regions helps the model focus on exactly that. We define specific regions as regions with a high Gini index, indicating that their peak distribution across cell types is skewed towards one or a few cell types.

Read the documentation of the crested.pp.filter_regions_on_specificity() function for more information on how the filtering is done.

crested.pp.filter_regions_on_specificity(
    adata, gini_std_threshold=1.0
)  # All regions with a Gini index 1 std above the mean across all regions will be kept
adata
2025-03-21T12:42:57.421730+0100 INFO After specificity filtering, kept 91475 out of 546993 regions.
AnnData object with n_obs × n_vars = 19 × 91475
    obs: 'file_path'
    var: 'chr', 'start', 'end', 'split'
    obsm: 'weights'
adata.write_h5ad("data/mouse_cortex_filtered.h5ad")

Loading the model pretrained on all consensus peaks and finetuning with a lower learning rate#

datamodule = crested.tl.data.AnnDataModule(
    adata,
    batch_size=64,  # Recommended to go for a smaller batch size than in the basemodel
    max_stochastic_shift=3,
    always_reverse_complement=True,
)
# First, load the model pretrained on all peaks.
# Choose the basemodel checkpoint with the best validation loss/performance metrics.
model_architecture = keras.models.load_model(
    "mouse_biccn/basemodel/checkpoints/17.keras",
    compile=False,
)
# Use the same config as for the pretrained model, EXCEPT the learning rate: make sure it is lower than it was at the epoch you select the model from.
optimizer = keras.optimizers.Adam(learning_rate=1e-4)  # Lower LR!
loss = crested.tl.losses.CosineMSELogLoss(max_weight=100, multiplier=1)
metrics = [
    keras.metrics.MeanAbsoluteError(),
    keras.metrics.MeanSquaredError(),
    keras.metrics.CosineSimilarity(axis=1),
    crested.tl.metrics.PearsonCorrelation(),
    crested.tl.metrics.ConcordanceCorrelationCoefficient(),
    crested.tl.metrics.PearsonCorrelationLog(),
    crested.tl.metrics.ZeroPenaltyMetric(),
]

alternative_config = crested.tl.TaskConfig(optimizer, loss, metrics)
print(alternative_config)
TaskConfig(optimizer=<keras.src.optimizers.adam.Adam object at 0x7f4a1665c410>, loss=<crested.tl.losses._cosinemse_log.CosineMSELogLoss object at 0x7f4a1adafc90>, metrics=[<MeanAbsoluteError name=mean_absolute_error>, <MeanSquaredError name=mean_squared_error>, <CosineSimilarity name=cosine_similarity>, <PearsonCorrelation name=pearson_correlation>, <ConcordanceCorrelationCoefficient name=concordance_correlation_coefficient>, <PearsonCorrelationLog name=pearson_correlation_log>, <ZeroPenaltyMetric name=zero_penalty_metric>])
# setup the trainer
trainer = crested.tl.Crested(
    data=datamodule,
    model=model_architecture,
    config=alternative_config,
    project_name="mouse_biccn",  # change to your liking
    run_name="finetuned_model",  # change to your liking
    logger="wandb",  # or 'wandb', 'tensorboard'
)
2024-12-12T10:19:16.196631+0100 WARNING Output directory mouse_biccn/finetuned_model/checkpoints already exists. Will continue training from epoch 9.
trainer.fit(
    epochs=60,
    learning_rate_reduce_patience=3,
    early_stopping_patience=6,
)

Evaluate the model#

After training, we can evaluate the model on the test set using the test() method.
If we’re still in the same session, we can simply continue using the same object.
If not, we can load the model from disk using the load_model() method. This means that we have to create a new Crested object first.
However, this time, since the TaskConfig and architecture are saved in the .keras file, we only have to provide our datamodule.

adata = ad.read_h5ad("data/mouse_cortex_filtered.h5ad")

datamodule = crested.tl.data.AnnDataModule(
    adata,
    batch_size=256,  # lower this if you encounter OOM errors
)
# load an existing model
evaluator = crested.tl.Crested(data=datamodule)
model_path = "mouse_biccn/finetuned_model/checkpoints/02.keras"

evaluator.load_model(
    model_path,
    compile=True,
)

If you experimented with many different hyperparameters for your model, chances are that you will start overfitting on your validation dataset.
It’s therefore always a good idea to evaluate your model on the test set after getting good results on your validation data to see how well it generalizes to unseen data.

# evaluate the model on the test set
evaluator.test()
 3/33 ━━━━━━━━━━━━━━━━━━━ 1s 60ms/step - concordance_correlation_coefficient: 0.6685 - cosine_similarity: 0.8724 - loss: 0.5137 - mean_absolute_error: 1.6471 - mean_squared_error: 13.7273 - pearson_correlation: 0.7443 - pearson_correlation_log: 0.6123 - zero_penalty_metric: 2443.7920
33/33 ━━━━━━━━━━━━━━━━━━━━ 14s 119ms/step - concordance_correlation_coefficient: 0.7091 - cosine_similarity: 0.8744 - loss: 0.5054 - mean_absolute_error: 1.6253 - mean_squared_error: 13.6106 - pearson_correlation: 0.7706 - pearson_correlation_log: 0.6340 - zero_penalty_metric: 2470.7397
2025-03-21T12:47:02.457872+0100 INFO Test concordance_correlation_coefficient: 0.7221
2025-03-21T12:47:02.458273+0100 INFO Test cosine_similarity: 0.8707
2025-03-21T12:47:02.458478+0100 INFO Test loss: 0.5097
2025-03-21T12:47:02.458680+0100 INFO Test mean_absolute_error: 1.6882
2025-03-21T12:47:02.458863+0100 INFO Test mean_squared_error: 14.4946
2025-03-21T12:47:02.459042+0100 INFO Test pearson_correlation: 0.7813
2025-03-21T12:47:02.459442+0100 INFO Test pearson_correlation_log: 0.6345
2025-03-21T12:47:02.459599+0100 INFO Test zero_penalty_metric: 2302.9001

Predict#

Now that we have a trained model, we can use the crested.tl toolkit to run inference and explain our results. All the functionality shown below only expects a trained .keras model, meaning that you can use these functions with any model trained outside of the CREsted framework too.

Warning

An older version of CREsted handled all this functionality inside the Crested class. For ease of use, we refactored these methods to a functional form as shown below. You can still use them in the old object-method manner, but they are considered deprecated and will only be updated in their functional form.

The core function that you will be using is the predict() function. This expects as input something you want to predict over as well as a trained model. You can even provide a list of models, as long as they expect the same input and output shapes; in that case, the predictions will be averaged, which can make your predictions more robust.

Since CREsted is built around making predictions over genomic sequences, this can accept as input:

  • (lists of) sequence(s)

  • (lists of) genomic region name(s)

  • one-hot encoded sequences of shape (N, L, 4)

  • AnnData objects with regions as their .var index

CREsted will convert these inputs to the format required by the model.
If your input is a region name or an AnnData object, you should also provide a genome if you did not register one.

First, we need to load a model. If you followed the tutorial, you can load that one. If not, CREsted has a model repository with commonly used models, which you can download with get_model(). You can find all example models here.

# load a trained model
import keras

model = keras.models.load_model(model_path, compile=False)  # change to your model path
# store predictions for all our regions in the anndata object for later inspection.
predictions = crested.tl.predict(adata, model)
adata.layers["biccn_model"] = predictions.T  # adata expects (C, N) instead of (N, C)
91475/91475 ━━━━━━━━━━━━━━━━━━━━ 62s 666us/step
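The same call accepts the other input types listed above. A few sketches, reusing the model loaded above (the ensemble line assumes a hypothetical second model, model_2):

# a raw sequence; must match the model input length (2114 bp here)
preds_seq = crested.tl.predict("A" * 2114, model)

# a (list of) region name(s); requires a registered or explicitly provided genome
preds_region = crested.tl.predict(["chr18:61107770-61109884"], model)

# a list of models: predictions are averaged across models
# preds_ensemble = crested.tl.predict(adata, [model, model_2])  # model_2 is hypothetical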

Many of the plotting functions in the crested.pl module can be used to visualize these model predictions.

Example predictions on test set regions#

It is always interesting to see how the model performs on unseen test set regions. We recommend always looking at a few examples to spot potential biases or trends that you do not expect.

# Define a dataframe with test set regions
test_df = adata.var[adata.var["split"] == "test"]
test_df
chr start end split
region
chr18:3269690-3271804 chr18 3269690 3271804 test
chr18:3350307-3352421 chr18 3350307 3352421 test
chr18:3451398-3453512 chr18 3451398 3453512 test
chr18:3463977-3466091 chr18 3463977 3466091 test
chr18:3488308-3490422 chr18 3488308 3490422 test
... ... ... ... ...
chr9:124125533-124127647 chr9 124125533 124127647 test
chr9:124140961-124143075 chr9 124140961 124143075 test
chr9:124142793-124144907 chr9 124142793 124144907 test
chr9:124477280-124479394 chr9 124477280 124479394 test
chr9:124479548-124481662 chr9 124479548 124481662 test

8198 rows × 4 columns

%matplotlib inline
# plot predictions vs ground truth for a random region in the test set defined by index
idx = 21
region = test_df.index[idx]
print(region)
crested.pl.bar.region_predictions(adata, region, title="Predictions vs Ground Truth")
chr18:3892771-3894885
2025-03-21T12:49:28.718221+0100 INFO Plotting bar plots for region: chr18:3892771-3894885, models: ['biccn_model']
../_images/94cc413dd1b18a22ae4e333e810e9e1a6e4b5d0021866a866262abd3d1dcc4cb.png
Example predictions on manually defined regions#
chrom = "chr3"  #'chr18'
start = 72535878 - 807  # 61107770
end = 72536378 + 807  # 61109884

sequence = genome.fetch(chrom, start, end)

prediction = crested.tl.predict(sequence, model)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 32ms/step
crested.pl.bar.prediction(prediction, classes=list(adata.obs_names))
../_images/205e31a22d27028f934660a20adcf1ca1f2af9eefd5a34f299c6c8fc0fa0c3f4.png

Prediction on gene locus#

We can also score a gene locus by using a sliding window over a predefined genomic range. We can then compare those predictions to the bigWig track we predicted for, to see whether the measured profile matches the CREsted predictions.

chrom = "chr18" # Unseen chromosome
start= 59175401
end= 59410446

cell_type = "L6CT"
class_idx = list(adata.obs_names).index(cell_type)

upstream=50000
downstream=25000

strand= '+'

scores, coordinates, min_loc, max_loc, tss_position = crested.tl.score_gene_locus(
    chr_name=chrom,
    gene_start=start,
    gene_end=end,
    target_idx=class_idx,
    model=model,
    strand=strand,
    upstream=upstream,
    downstream=downstream,
    step_size=100,
)
3080/3080 ━━━━━━━━━━━━━━━━━━━━ 2s 659us/step
bigwig = bigwigs_folder + "/" + cell_type + ".bw"

values = (
    crested.utils.read_bigwig_region(
        bigwig, (chrom, start - upstream, end + downstream)
    )
    if strand == "+"
    else crested.utils.read_bigwig_region(
        bigwig, (chrom, start - downstream, end + upstream)
    )
)
bw_values = values[0]
midpoints = values[1]

Note that here we compared pooled predictions over 1 kb regions against cut sites at 1 bp resolution from a cut-sites bigWig. To get a better comparison between the predicted and measured scATAC track, we recommend comparing against a coverage bigWig.
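If you only have a cut-sites track, one option is to smooth it to a comparable resolution before plotting. Below is a simple moving-average sketch using numpy; the 1 kb window is an assumption chosen to match the pooled prediction width, and you could pass the smoothed values as bigwig_values below instead.

# optional: smooth the 1 bp cut-site signal with a ~1 kb moving average
# so it is visually comparable to the pooled 1 kb predictions
window = 1000
kernel = np.ones(window) / window
bw_values_smoothed = np.convolve(bw_values, kernel, mode="same")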

%matplotlib inline
crested.pl.hist.locus_scoring(
    scores,
    (min_loc, max_loc),
    gene_start=start,
    gene_end=end,
    title="CREsted prediction around Chsy3 gene for " + cell_type,
    bigwig_values=bw_values,
    bigwig_midpoints=midpoints,
    grid=False,
    figsize=(20, 5),
    marker_size=2,
    line_width=1,
)
../_images/557eab0d432da4797d96a105d99b64850389856c6cc03302507dd30c848f8a0f.png

Model performance on the entire test set#

After looking at specific instances, we can now look at the model performance on a larger scale.

First, we can check per cell type/class the correlation of predictions and peak heights over the peaks in the test set.

These plotting functions that show predictions over your entire dataset expect them to be saved in your AnnData.layers object.

classn = "L2_3IT"
crested.pl.scatter.class_density(
    adata,
    class_name=classn,
    model_names=["biccn_model"],
    split="test",
    log_transform=True,
    width=5,
    height=5,
)
2025-03-21T12:53:19.434699+0100 INFO Plotting density scatter for class: L2_3IT, models: ['biccn_model'], split: test
../_images/40f282310f9522280013e257e5bf40faa8cfa50f5e21d754a57c507a4f5f2dc1.png

To now check the correlations between all classes, we can plot a heatmap to assess the model performance.

crested.pl.heatmap.correlations_predictions(
    adata,
    split="test",
    title="Correlations between Groundtruths and Predictions",
    x_label_rotation=90,
    width=5,
    height=5,
    log_transform=True,
    vmax=1,
    vmin=-0.15,
)
2025-03-21T12:53:20.885486+0100 INFO Plotting heatmap correlations for split: test, models: ['biccn_model']
../_images/94e6811f58fad96f766b0d20d9baf5d020cb138f85c99bac0a044452ab5a15ab.png

It is also recommended to compare this heatmap to the self-correlation plot of the peaks themselves. If peaks between cell types are correlated, then predictions from non-matching classes for those correlated cell types are also expected to be high.

crested.pl.heatmap.correlations_self(
    adata,
    title="Self Correlation Heatmap",
    x_label_rotation=90,
    width=5,
    height=5,
    vmax=1,
    vmin=-0.15,
)
../_images/b706b008cfefed5870d1f04ae147d199f81ec9b79acc3671e9186c19996a2b4d.png

Sequence contribution scores#

We can calculate the contribution scores for a sequence of interest using the contribution_scores() function.

This will give us information on which nucleotides the model is looking at to make its prediction with respect to a specific output class.

You always need to ensure that the sequence or region you provide is the same length as the model input (2114bp in our case).

Here, similar to the predict function, you need some input (like a sequence or region name) and a (list of) model(s). If multiple models are provided, the contribution scores will be averaged.

Contribution scores on manually defined sequences#

# random sequence of length 2114bp as an example
sequence = "A" * 2114

# find the indices of the cell types of interest (Astro and Endo)
class_idx = list(adata.obs_names.get_indexer(["Astro", "Endo"]))

scores, one_hot_encoded_sequences = crested.tl.contribution_scores(
    sequence,
    target_idx=class_idx,  # None (=all classes), list of target indices, or empty list (='combined' class)
    model=model,
    method="expected_integrated_grad",  # default. Other options: "integrated_grad", "mutagenesis"
)
2025-02-04T10:56:07.119930+0100 INFO Calculating contribution scores for 2 class(es) and 1 region(s).
print(scores.shape, one_hot_encoded_sequences.shape)
(1, 2, 2114, 4) (1, 2114, 4)

Contribution scores on manually defined genomic regions#

# similar example but with region names as input
regions_of_interest = [
    "chr18:61107770-61109884"
]  # FIRE enhancer region (Microglia enhancer)
classes_of_interest = ["Astro", "Micro_PVM"]
class_idx = list(adata.obs_names.get_indexer(classes_of_interest))

scores, one_hot_encoded_sequences = crested.tl.contribution_scores(
    regions_of_interest,
    target_idx=class_idx,
    model=model,
)
2025-02-04T10:56:44.549304+0100 INFO Calculating contribution scores for 2 class(es) and 1 region(s).

Contribution scores for regions can be plotted using the crested.pl.patterns.contribution_scores() function.
This will generate a plot per class per region.

%matplotlib inline
crested.pl.patterns.contribution_scores(
    scores,
    one_hot_encoded_sequences,
    sequence_labels=regions_of_interest,
    class_labels=classes_of_interest,
    zoom_n_bases=500,
    title="FIRE Enhancer Region",
)  # zoom in on the center 500bp
../_images/2bc7dbf7d7df288fa6e082068841fd4038a136c1b7f7a44516d80d0cea865bee.png

Contribution scores on random test set regions#

# plot predictions vs ground truth for a random region in the test set defined by index
idx = 21
region = test_df.index[idx]
classes_of_interest = ["L2_3IT", "Pvalb"]
class_idx = list(adata.obs_names.get_indexer(classes_of_interest))
scores, one_hot_encoded_sequences = crested.tl.contribution_scores(
    region,
    target_idx=class_idx,
    model=model,
)
2025-02-04T10:56:51.650879+0100 INFO Calculating contribution scores for 2 class(es) and 1 region(s).
crested.pl.patterns.contribution_scores(
    scores,
    one_hot_encoded_sequences,
    sequence_labels=[region],
    class_labels=classes_of_interest,
    zoom_n_bases=500,
    title="Test set region",
)  # zoom in on the center 500bp
../_images/34833a114752fb71c0bb16db3d47adc436e21aebe3573e75d95ebe8c4caabf8d.png

Enhancer design#

Enhancer design is an important concept in understanding a cell type’s cis-regulatory code.
By designing sequences to be specifically accessible for a cell type and inspecting those designed sequences’ contribution score plots, we can get an understanding of which motifs are most important for that cell type’s enhancer code. Moreover, by inspecting intermediate results throughout the optimization process, we can see which motifs and which motif positions have a comparatively higher priority.

We follow the enhancer design process described in this paper (Taskiran et al., Nature, 2024). We start from random sequences and, at each step, select the nucleotide mutation or motif implantation that leads to the largest change in specific accessibility for a chosen cell type.
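Conceptually, the in silico evolution loop looks like the sketch below. This is a pseudocode-style illustration of the greedy procedure, not CREsted’s actual implementation (which, among other things, batches all candidate mutations through the model at once).

# conceptual sketch of greedy in silico evolution (not the CREsted implementation)
def greedy_evolution_sketch(sequence: str, score_fn, n_mutations: int) -> str:
    """At each step, try every single-nucleotide mutation and keep the best one."""
    for _ in range(n_mutations):
        best_seq, best_score = sequence, score_fn(sequence)
        for pos in range(len(sequence)):
            for nuc in "ACGT":
                if nuc == sequence[pos]:
                    continue
                mutated = sequence[:pos] + nuc + sequence[pos + 1 :]
                score = score_fn(mutated)
                if score > best_score:
                    best_seq, best_score = mutated, score
        sequence = best_seq  # commit the single best mutation of this round
    return sequence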

Sequence evolution#

The standard way of designing enhancers (by making single nucleotide mutations in randomly generated regions) can be carried out using enhancer_design_in_silico_evolution().

Before we start designing, we will calculate the nucleotide distribution of our consensus regions which will be used for creating random starting sequences (if you don’t do this the design function will assume a uniform distribution).

acgt_distribution = crested.utils.calculate_nucleotide_distribution(
    adata,  # accepts any sequence input, same as before
    per_position=True,  # return a distribution per position in the sequence
)
acgt_distribution.shape
(2114, 4)
# we will design an enhancer for the L5ET cell type
class_idx = list(adata.obs_names).index("L5ET")

designed_sequences = crested.tl.enhancer_design_in_silico_evolution(
    model=model,
    target=class_idx,  # the default optimization function expects a target class index
    n_sequences=1,  # n enhancers to design
    n_mutations=10,  # n single nucleotide mutations to make per sequence
    target_len=500,  # only make mutations in the center 500bp
    acgt_distribution=acgt_distribution,  # if None, uniform distribution will be used
)
%matplotlib inline
# ensure that our designed sequence scores high on our target class
prediction = crested.tl.predict(designed_sequences[0], model=model)
crested.pl.bar.prediction(prediction, classes=list(adata.obs_names))
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 25ms/step
../_images/178ef1672b9df328dd2126ae44a7c45713a661a2b268219c3f925232b41220de.png
scores, one_hot_encoded_sequences = crested.tl.contribution_scores(
    designed_sequences,
    model=model,
    target_idx=class_idx,
)
2025-02-04T10:59:08.807413+0100 INFO Calculating contribution scores for 1 class(es) and 1 region(s).
crested.pl.patterns.contribution_scores(
    scores,
    one_hot_encoded_sequences,
    sequence_labels="",
    class_labels=["L5ET"],
    zoom_n_bases=500,
    title="synth L5ET",
    height=3,
)  # zoom in on the center 500bp
../_images/b2b1ca02b8669ba5824176a2908ae2080312f66eb8404b1418deda259a389842.png

Keep in mind that, if you start from random sequences, enhancer design is non-deterministic, so you won’t get the exact same results twice.

Inspecting the enhancer design process#

The enhancer design functions have a handy return_intermediate parameter which we can use to inspect which mutations are made at each point in the process. We can use some of CREsted’s plotting functions to visualize this process.

# same example as before, but now we will inspect the intermediate results
class_idx = list(adata.obs_names).index("L5ET")

(
    intermediate_results,
    designed_sequences,
) = crested.tl.enhancer_design_in_silico_evolution(
    model=model,
    target=class_idx,
    return_intermediate=True,  # set this to True now
    n_sequences=1,
    n_mutations=10,
    target_len=500,
    acgt_distribution=acgt_distribution,
)
print(intermediate_results[0].keys())
dict_keys(['inital_sequence', 'changes', 'predictions', 'designed_sequence'])
seq_idx = 0

# ensure that our designed sequence scores high on our target class
prediction = crested.tl.predict(designed_sequences[seq_idx], model=model)
crested.pl.bar.prediction(prediction, classes=list(adata.obs_names))
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step
../_images/d61603a6fe9fc5f4aefefe38db5028fc9193ef2f243465185652d9a3105fa180.png
# plot the evolution of predictions
%matplotlib inline
crested.pl.patterns.enhancer_design_steps_predictions(
    intermediate_results,
    target_classes=["L5ET"],
    obs_names=list(adata.obs_names),
    seperate=True,
    title="Synthetic L5ET evolution",
    width=5,
    height=5,
)
../_images/3d3a5fc838b91b61f914c32f03fe39bfa3d5273885817d2c87a5e38017f4b73b.png
# there's a utility function to extract the intermediate sequences from the dict
intermed_seqs = crested.utils.derive_intermediate_sequences(intermediate_results)

# calculate contribution scores as usual
scores, one_hot_encoded_sequences = crested.tl.contribution_scores(
    intermed_seqs[seq_idx],
    target_idx=class_idx,
    model=model,
)
2025-02-04T11:15:19.754794+0100 INFO Calculating contribution scores for 1 class(es) and 11 region(s).
# you can use standard plotting to visualize the scores
crested.pl.patterns.contribution_scores(
    scores,
    one_hot_encoded_sequences,
    sequence_labels=[f"Step {i}" for i in range(len(intermed_seqs[0]))],
    zoom_n_bases=500,
    ylim=(-0.25, 2),  # best to keep a constant scale for comparison
)
../_images/353ed29d20ba8934c7ac1382409cb0904c04e94a609d40d2b76294987d074f68.png
scores, one_hot_encoded_sequences = crested.tl.contribution_scores(
    designed_sequences[seq_idx],
    target_idx=class_idx,
    model=model,
)
2025-02-04T11:16:28.996021+0100 INFO Calculating contribution scores for 1 class(es) and 1 region(s).

Motif insertion#

Another way of designing enhancers is by embedding known motifs into our sequences.
This way, you can investigate how specific motif combinations influence a sequence’s accessibility profile. For this, you can use the crested.tl.enhancer_design_motif_insertion() function. We can use the intermediate results to highlight the inserted motifs.

class_idx = list(adata.obs_names).index("Oligo")

intermediate_results, designed_sequences = crested.tl.enhancer_design_motif_insertion(
    patterns={
        "SOX10": "AACAATGGCCCCATTGT",
        "CREB5": "ATGACATCA",
    },
    target=class_idx,
    model=model,
    acgt_distribution=acgt_distribution,
    n_sequences=2,
    target_len=500,
    return_intermediate=True,
)
seq_idx = 0

# check whether our implanted motifs have the expected effect
prediction = crested.tl.predict(designed_sequences, model=model)
crested.pl.bar.prediction(
    prediction[seq_idx],
    classes=list(adata.obs_names),
    title="Synthetic Oligo prediction",
)
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step
../_images/b081803d524d954ba908a471c686f7521a175ebe6610f97d937f8f77f1df997d.png
intermediate_results[0]["changes"]
[(-1, 'N'), (874, 'AACAATGGCCCCATTGT'), (905, 'ATGACATCA')]
scores, one_hot_encoded_sequences = crested.tl.contribution_scores(
    designed_sequences[seq_idx],
    target_idx=class_idx,
    model=model,
)
motif_positions = []
for motif in intermediate_results[seq_idx]["changes"]:
    motif_start = motif[0]
    motif_end = motif_start + len(motif[1])
    if motif_start != -1:
        motif_positions.append((motif_start, motif_end))
2025-02-04T11:07:46.382160+0100 INFO Calculating contribution scores for 1 class(es) and 1 region(s).
# see whether the model is actually using the implanted motifs
crested.pl.patterns.contribution_scores(
    scores,
    one_hot_encoded_sequences,
    sequence_labels="",
    class_labels=["Oligo"],
    zoom_n_bases=500,
    title="synth Oligo",
    height=3,
    highlight_positions=motif_positions,
)
../_images/8b64aeb656889b0fefa042861c5804cf2ec1fa184edc1fb0f6a78e76811f518b.png

Using custom optimizers in enhancer design#

The default optimization function that is used in enhancer design is a weighted difference function that maximizes the increase in accessibility for a target cell type while penalizing an increase in accessibility for other cell types.

This is just one option, though; many use cases exist where you might want to optimize for something different. For example, you could write an optimization function that maximizes the cosine similarity between a given accessibility vector and the designed sequence’s predicted accessibility vector.
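As an illustration, such a cosine objective could look like the sketch below (not a CREsted built-in; it follows the same calling convention as the L2 example that follows):

from sklearn.metrics import pairwise


def cosine_similarity_objective(
    mutated_predictions: np.ndarray,
    original_prediction: np.ndarray,  # unused here, kept for the expected signature
    target: np.ndarray,
) -> int:
    """Return the index of the mutated sequence whose predicted accessibility is most cosine-similar to the target vector."""
    similarities = pairwise.cosine_similarity(
        mutated_predictions, target.reshape(1, -1)
    )
    return int(np.argmax(similarities.squeeze()))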

Below we give an example of how to write such a custom optimization function, wherein we will try to reach a specific accessibility value for some target cell type relative to other related cell types using the L2 distance.

By default, the EnhancerOptimizer expects an optimization function as input that has the arguments mutated_predictions, original_predictions, and target, and returns the index of the best mutated sequence. See its documentation for more information.

from sklearn.metrics import pairwise


def L2_distance(
    mutated_predictions: np.ndarray,
    original_prediction: np.ndarray,
    target: np.ndarray,
    classes_of_interest: list[int],
):
    """Calculate the L2 distance between the mutated predictions and the target class"""
    if len(original_prediction.shape) == 1:
        original_prediction = original_prediction[None]
    L2_sat_mut = pairwise.euclidean_distances(
        mutated_predictions[:, classes_of_interest],
        target[classes_of_interest].reshape(1, -1),
    )
    L2_baseline = pairwise.euclidean_distances(
        original_prediction[:, classes_of_interest],
        target[classes_of_interest].reshape(1, -1),
    )
    return np.argmax((L2_baseline - L2_sat_mut).squeeze())


L2_optimizer = crested.utils.EnhancerOptimizer(optimize_func=L2_distance)

target_cell_type = "L2_3IT"

classes_of_interest = [
    i
    for i, ct in enumerate(adata.obs_names)
    if ct in ["L2_3IT", "L5ET", "L5IT", "L6IT"]
]
target = np.array([20 if x == target_cell_type else 0 for x in adata.obs_names])
intermediate, designed_sequences = crested.tl.enhancer_design_in_silico_evolution(
    model=model,
    target=target,  # our optimization function now expects a target vector instead of a class index
    n_sequences=1,
    n_mutations=30,
    enhancer_optimizer=L2_optimizer,
    return_intermediate=True,
    no_mutation_flanks=(807, 807),
    acgt_distribution=acgt_distribution,
    classes_of_interest=classes_of_interest,  # additional kwargs will be passed to the optimizer
)
idx = 0
prediction = crested.tl.predict(designed_sequences[idx], model=model)
crested.pl.bar.prediction(prediction, classes=list(adata.obs_names))
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step
../_images/35512ba4d742b81c2e338db5fc1be404b314ded88d9fd721a3eb1b6d3a6a7f09.png
class_idx = list(adata.obs_names.get_indexer(["L2_3IT", "L5ET", "L5IT", "L6IT"]))

scores, one_hot_encoded_sequences = crested.tl.contribution_scores(
    designed_sequences[idx],
    target_idx=class_idx,
    model=model,
)
2025-02-04T11:14:03.098502+0100 INFO Calculating contribution scores for 4 class(es) and 1 region(s).
crested.pl.patterns.contribution_scores(
    scores,
    one_hot_encoded_sequences,
    sequence_labels="",
    class_labels=["L2_3IT", "L5ET", "L5IT", "L6IT"],
    zoom_n_bases=500,
    title="",
)
../_images/195dcb6f794bc5c3641a29f9ddf31f7836a7659bc0fc1ddb44bbb1e637f50e5e.png