Introduction to CREsted with Peak Regression#

In this introductory notebook, we will train a peak regression model on the mouse BICCN data and inspect the results to get a feel for the capabilities of the CREsted package.

Import Data#

For this tutorial, we will use the mouse BICCN dataset, which is available via the get_dataset() function.
To train a CREsted peak regression model on your data, you need:

  1. A consensus regions BED file containing all the regions of interest across cell types.

  2. A folder containing the bigwig files per cell type. Each file should be named according to the cell type: {cell type name}.bw.

  3. A genome fasta file and optionally a chromosome sizes file.

You could use a tool like SnapATAC2 to generate the consensus regions and bigwig files from your own data.
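As a quick sanity check of these inputs, something like the following sketch can verify the folder layout (the paths are hypothetical; only the {cell type name}.bw naming convention comes from CREsted):

from pathlib import Path

bigwigs_folder = Path("bigwigs/")  # hypothetical folder with one bigwig per cell type
assert all(p.suffix == ".bw" for p in bigwigs_folder.glob("*")), "expected only .bw files"
cell_types = [p.stem for p in bigwigs_folder.glob("*.bw")]  # e.g. ['Astro', 'Endo', ...]
print(cell_types)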

from pathlib import Path
import numpy as np

%matplotlib inline
import matplotlib
from sklearn.metrics import pairwise
import anndata as ad
import keras
import crested
# Use TrueType (Type 42) fonts instead of Type 3 so text in saved figures stays editable
matplotlib.rcParams["pdf.fonttype"] = 42
matplotlib.rcParams["ps.fonttype"] = 42  # same for PostScript output, if needed

Download the tutorial data.
For this tutorial we will train on the ‘cut sites’ data, but the ‘coverage’ data are also available (an older version of this tutorial trained on the coverage).

bigwigs_folder, regions_file = crested.get_dataset("mouse_cortex_bigwig_cut_sites")

Or if you have the data already available:

bigwigs_folder = "../../../../mouse/biccn/bigwigs/fragments_bws/"
regions_file = "../../../../mouse/biccn/consensus_peaks_inputs.bed"
chromsizes_file = "../../../../mouse/biccn/mm.chrom.sizes"

By loading our genome into the crested.Genome class and setting it with register_genome(), the genome is automatically used in all functions throughout CREsted. If you don’t provide the chromosome sizes, they will be automatically calculated from the fasta.

Note

Any function or class that expects a genome can still accept one as explicit input, even if a genome was already registered. In that case, the explicit input is used instead of the registered genome.
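For example (a minimal sketch), a Genome can be constructed from the fasta alone; the chromosome sizes are then derived from the fasta as described above:

genome_no_sizes = crested.Genome("mm10.fa")  # chromosome sizes computed from the fasta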

# Set the genome
genome_dir = Path("../../../../mouse/biccn/")
genome = crested.Genome(genome_dir / "mm10.fa", genome_dir / "mm10.chrom.sizes")
crested.register_genome(genome)
2024-12-12T10:18:24.298297+0100 INFO Genome mm10 registered.

We can use the import_bigwigs() function to import bigwigs per cell type and a consensus regions BED file into an anndata.AnnData object, with the imported cell types as the AnnData.obs and the consensus peak regions as the AnnData.var.

Optionally, provide a chromsizes file to filter out regions that are not within the chromsizes.

adata = crested.import_bigwigs(
    bigwigs_folder=bigwigs_folder,
    regions_file=regions_file,
    target_region_width=1000,  # optionally, use a different width than the consensus regions file (500bp) for the .X values calculation
    target="count",  # or "max", "mean", "logcount" --> what we will be predicting
)
adata
2024-12-12T10:18:30.706855+0100 INFO Extracting values from 19 bigWig files...
AnnData object with n_obs × n_vars = 19 × 546993
    obs: 'file_path'
    var: 'chr', 'start', 'end'
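As a quick check (sketch), the imported cell types and regions can be inspected directly on the AnnData object:

print(adata.obs_names.tolist())  # cell type names, taken from the bigwig filenames
print(adata.var.head())  # consensus regions with chr/start/end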

To train a model, we always need to add a split column to our dataset, which we can do using crested.pp.train_val_test_split().
This will add a ‘split’ column to the AnnData.var with the split type for each region (train, val, or test).

# Choose the chromosomes for the validation and test sets
crested.pp.train_val_test_split(
    adata, strategy="chr", val_chroms=["chr8", "chr10"], test_chroms=["chr9", "chr18"]
)

# Alternatively, we can split randomly on the regions
# crested.pp.train_val_test_split(
#     adata, strategy="region", val_size=0.1, test_size=0.1, random_state=42
# )

print(adata.var["split"].value_counts())
adata.var
split
train    440993
val       56064
test      49936
Name: count, dtype: int64
chr start end split
region
chr1:3093998-3096112 chr1 3093998 3096112 train
chr1:3094663-3096777 chr1 3094663 3096777 train
chr1:3111367-3113481 chr1 3111367 3113481 train
chr1:3112727-3114841 chr1 3112727 3114841 train
chr1:3118939-3121053 chr1 3118939 3121053 train
... ... ... ... ...
chrX:169878506-169880620 chrX 169878506 169880620 train
chrX:169879374-169881488 chrX 169879374 169881488 train
chrX:169924670-169926784 chrX 169924670 169926784 train
chrX:169947743-169949857 chrX 169947743 169949857 train
chrX:169950171-169952285 chrX 169950171 169952285 train

546993 rows × 4 columns

Preprocessing#

Region Width#

For this example we’re interested in training on wider regions than our consensus regions file (500bp) to also include some sequence information from the tails of our peaks.

We change it to 2114 bp regions, since that is what the ChromBPNet architecture we will be using was originally designed for. This is not fixed and can be adapted to your preference, as long as it is compatible with the model architecture.

Wider regions mean that you include sequence information beyond the centers of the peaks, which can effectively increase your dataset size if the peak tails carry meaningful information, but can also introduce noise if they do not.
Wider regions will also increase the computational cost of training the model.
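To make the widening arithmetic concrete, here is an illustrative sketch with a hypothetical 500 bp region, assuming regions are re-centered and extended symmetrically:

start, end = 3094748, 3095248  # hypothetical 500 bp consensus region
center = (start + end) // 2
new_start, new_end = center - 2114 // 2, center + 2114 // 2
print(new_end - new_start)  # 2114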

crested.pp.change_regions_width(
    adata,
    2114,
)  # change the adata width of the regions to 2114bp
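A quick sanity check (sketch) that the widths were updated:

widths = adata.var["end"].astype(int) - adata.var["start"].astype(int)
assert (widths == 2114).all()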

Peak Normalization#

Additionally, we can normalize our peak values based on the variability of the top peak heights per cell type using the crested.pp.normalize_peaks() function.

This function applies a normalization scalar to each cell type, obtained by comparing, per cell type, the distribution of peak heights for maximally accessible regions that are not specific to any cell type.

crested.pp.normalize_peaks(
    adata, top_k_percent=0.03
)  # The top_k_percent parameter can be tuned based on potential bias towards cell types. If some weights overcompensate too much, consider increasing top_k_percent. Default is 0.01.
2024-12-12T10:19:01.511397+0100 INFO Filtering on top k Gini scores...
2024-12-12T10:19:05.058074+0100 INFO Added normalization weights to adata.obsm['weights']...
chr start end split
region
chr5:76656624-76658738 chr5 76656624 76658738 train
chr13:30900787-30902901 chr13 30900787 30902901 train
chr9:65586049-65588163 chr9 65586049 65588163 test
chr9:65586556-65588670 chr9 65586556 65588670 test
chr9:65587095-65589209 chr9 65587095 65589209 test
... ... ... ... ...
chr9:65459289-65461403 chr9 65459289 65461403 test
chr9:65459852-65461966 chr9 65459852 65461966 test
chr5:76587680-76589794 chr5 76587680 76589794 train
chr9:65523082-65525196 chr9 65523082 65525196 test
chr19:18337654-18339768 chr19 18337654 18339768 train

48308 rows × 4 columns
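As the log above notes, the per-cell-type scaling factors are stored in adata.obsm['weights'] and can be inspected directly (sketch):

print(adata.obsm["weights"])  # one normalization scalar per cell type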

We can visualize the normalization factor for each cell type using the crested.pl.bar.normalization_weights() function to inspect which cell types’ peaks were up- or down-weighted.

crested.pl.bar.normalization_weights(
    adata, title="Normalization Weights per Cell Type", x_label_rotation=90
)
../_images/7f4cf2858c36e760d033e8a06fa6a3a63dbe5cb54c025d95c61d4c390c2347a0.png

There is no single best way to preprocess your data, so we recommend experimenting with different preprocessing steps to see what works best for your data.
Likewise, there is no single best training approach, so we recommend experimenting with different training strategies.

# Save the final preprocessing results
adata.write_h5ad("mouse_cortex.h5ad")

Model Training#

The entire CREsted workflow is built around the crested.tl.Crested() class. Everything that requires a model (training, evaluation, prediction) is done through this class.
This class has a couple of required arguments:

  • data: the crested.tl.data.AnnDataModule object containing all the data (anndata, genome) and dataloaders that specify how to load the data.

  • model: the keras.Model object containing the model architecture.

  • config: the crested.tl.TaskConfig object containing the optimizer, loss function, and metrics to use in training.

Generally you wouldn’t run these steps in a notebook, but rather in a Python script, so you can run them on a cluster or in the background.

Data#

We’ll start by initializing the crested.tl.data.AnnDataModule object with our data.
This will tell our model how to load the data and which data to load during fitting/evaluation. The main arguments to supply are the adata object, the genome (picked up automatically from the registered genome if not passed explicitly), and the batch_size.
Other optional arguments are related to the training data loading (e.g. shuffling, whether to load the sequences into memory, …).

You need to provide the genome file yourself, as it is not included in the CREsted package.

# read in your preprocessed data
adata = ad.read_h5ad("mouse_cortex.h5ad")
datamodule = crested.tl.data.AnnDataModule(
    adata,
    batch_size=256,  # lower this if you encounter OOM errors
    max_stochastic_shift=3,  # optional augmentation
    always_reverse_complement=True,  # default True. Will double the effective size of the training dataset.
)

Model definition#

Next, we’ll define the model architecture. This is a standard Keras model definition, so you can provide your own model definition if you like.
Alternatively, there are a couple of ready-to-use models available in the crested.tl.zoo module.
Each of them requires the width of the input sequences and the number of output classes (the number of cell types in your AnnData.obs) as arguments.

# Load chrombpnet architecture for a dataset with 2114bp regions and 19 cell types
model_architecture = crested.tl.zoo.chrombpnet(
    seq_len=2114, num_classes=len(list(adata.obs_names))
)
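Since this is a plain keras.Model, the usual Keras introspection applies, for example:

model_architecture.summary()  # prints the layer-by-layer architecture and parameter counts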

TaskConfig#

The TaskConfig object specifies the optimizer, loss function, and metrics to use in training (we call this our ‘task’).
Default configurations are available for common tasks such as ‘topic_classification’ and ‘peak_regression’, which you can load using the crested.tl.default_configs() function.

# Load the default configuration for training a peak regression model
config = crested.tl.default_configs(
    "peak_regression"
)  # or "topic_classification" for topic classification
print(config)

# If you want to change some small parameters to an existing config, you can do it like this
# For example, the default learning rate is 0.001, but you can change it to 0.0001
# config.optimizer.learning_rate = 0.0001
TaskConfig(optimizer=<keras.src.optimizers.adam.Adam object at 0x7f4a1ace2510>, loss=<crested.tl.losses._cosinemse_log.CosineMSELogLoss object at 0x7f4a1ae592d0>, metrics=[<MeanAbsoluteError name=mean_absolute_error>, <MeanSquaredError name=mean_squared_error>, <CosineSimilarity name=cosine_similarity>, <PearsonCorrelation name=pearson_correlation>, <ConcordanceCorrelationCoefficient name=concordance_correlation_coefficient>, <PearsonCorrelationLog name=pearson_correlation_log>, <ZeroPenaltyMetric name=zero_penalty_metric>])

Alternatively, you can create your own TaskConfig object and specify the optimizer, loss function, and metrics yourself if you want to do something completely custom.

# Create your own configuration
# I recommend trying this for peak regression with a weighted cosine mse log loss function
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
loss = crested.tl.losses.CosineMSELogLoss(max_weight=100, multiplier=1)
metrics = [
    keras.metrics.MeanAbsoluteError(),
    keras.metrics.MeanSquaredError(),
    keras.metrics.CosineSimilarity(axis=1),
    crested.tl.metrics.PearsonCorrelation(),
    crested.tl.metrics.ConcordanceCorrelationCoefficient(),
    crested.tl.metrics.PearsonCorrelationLog(),
    crested.tl.metrics.ZeroPenaltyMetric(),
]

alternative_config = crested.tl.TaskConfig(optimizer, loss, metrics)
print(alternative_config)
TaskConfig(optimizer=<keras.src.optimizers.adam.Adam object at 0x7f4a1ad14410>, loss=<crested.tl.losses._cosinemse_log.CosineMSELogLoss object at 0x7f4a1ad14a10>, metrics=[<MeanAbsoluteError name=mean_absolute_error>, <MeanSquaredError name=mean_squared_error>, <CosineSimilarity name=cosine_similarity>, <PearsonCorrelation name=pearson_correlation>, <ConcordanceCorrelationCoefficient name=concordance_correlation_coefficient>, <PearsonCorrelationLog name=pearson_correlation_log>, <ZeroPenaltyMetric name=zero_penalty_metric>])

Training#

Now we’re ready to train our model. We’ll create a Crested object with the data, model, and config objects we just created.
Then, we can call the fit() method to train the model.
Read the documentation for more information on all available arguments to customize your training (e.g. augmentations, early stopping, checkpointing, …).

By default:

  1. The model will keep training until the validation loss stops improving for 10 epochs, up to a maximum of 100 epochs.

  2. The best model so far is saved whenever the validation loss improves.

  3. The learning rate is reduced by a factor of 0.25 if the validation loss stops improving for 5 epochs.

Note

If you specify the same project_name and run_name as a previous run, then CREsted will assume that you want to continue training and will load the last available model checkpoint from the {project_name}/{run_name} folder and continue from that epoch.

# setup the trainer
trainer = crested.tl.Crested(
    data=datamodule,
    model=model_architecture,
    config=alternative_config,
    project_name="mouse_biccn",  # change to your liking
    run_name="basemodel",  # change to your liking
    logger="wandb",  # or 'wandb', 'tensorboard'
    seed=7,  # For reproducibility
)
2024-12-12T10:19:08.451626+0100 WARNING Output directory mouse_biccn/basemodel/checkpoints already exists. Will continue training from epoch 18.
# train the model
trainer.fit(
    epochs=60,
    learning_rate_reduce_patience=3,
    early_stopping_patience=6,
)
Finishing last run (ID:t1jx172q) before initializing another...
View run basemodel at: https://wandb.ai/kemp/mouse_biccn/runs/t1jx172q
View project at: https://wandb.ai/kemp/mouse_biccn
Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
Find logs at: ./wandb/run-20241209_201427-t1jx172q/logs
Successfully finished last run (ID:t1jx172q). Initializing new run:
Tracking run with wandb version 0.18.3
Run data is saved locally in /data/projects/c04/cbd-saerts/nkemp/software/CREsted/docs/tutorials/wandb/run-20241209_201453-jciwxmaz
2024-12-09T20:14:57.994147+0100 WARNING Model does not have an optimizer. Please compile the model before training.
None
2024-12-09T20:14:58.159088+0100 INFO Loading sequences into memory...
2024-12-09T20:15:05.095428+0100 INFO Loading sequences into memory...
Epoch 1/60
3445/3446 ━━━━━━━━━━━━━━━━━━━ 0s 223ms/step - concordance_correlation_coefficient: 0.5750 - cosine_similarity: 0.8166 - loss: -0.3449 - mean_absolute_error: 3.2040 - mean_squared_error: 45.8918 - pearson_correlation: 0.6825 - pearson_correlation_log: 0.5407 - zero_penalty_metric: 1136.6290
3446/3446 ━━━━━━━━━━━━━━━━━━━━ 833s 232ms/step - concordance_correlation_coefficient: 0.5750 - cosine_similarity: 0.8166 - loss: -0.3449 - mean_absolute_error: 3.2039 - mean_squared_error: 45.8899 - pearson_correlation: 0.6825 - pearson_correlation_log: 0.5407 - zero_penalty_metric: 1136.6272 - val_concordance_correlation_coefficient: 0.4512 - val_cosine_similarity: 0.6731 - val_loss: 0.5619 - val_mean_absolute_error: 9.3106 - val_mean_squared_error: 198.5345 - val_pearson_correlation: 0.5837 - val_pearson_correlation_log: 0.4806 - val_zero_penalty_metric: 908.3172 - learning_rate: 0.0010
Epoch 2/60
3446/3446 ━━━━━━━━━━━━━━━━━━━━ 783s 227ms/step - concordance_correlation_coefficient: 0.6789 - cosine_similarity: 0.8460 - loss: -0.4591 - mean_absolute_error: 2.8753 - mean_squared_error: 36.1717 - pearson_correlation: 0.7676 - pearson_correlation_log: 0.6078 - zero_penalty_metric: 1115.6266 - val_concordance_correlation_coefficient: 0.6851 - val_cosine_similarity: 0.8106 - val_loss: -0.3190 - val_mean_absolute_error: 3.2301 - val_mean_squared_error: 40.4674 - val_pearson_correlation: 0.7418 - val_pearson_correlation_log: 0.5651 - val_zero_penalty_metric: 1070.3921 - learning_rate: 0.0010
Epoch 3/60
3446/3446 ━━━━━━━━━━━━━━━━━━━━ 783s 227ms/step - concordance_correlation_coefficient: 0.7293 - cosine_similarity: 0.8592 - loss: -0.4970 - mean_absolute_error: 2.7146 - mean_squared_error: 31.4138 - pearson_correlation: 0.8026 - pearson_correlation_log: 0.6242 - zero_penalty_metric: 1111.0103 - val_concordance_correlation_coefficient: 0.7375 - val_cosine_similarity: 0.8421 - val_loss: -0.3265 - val_mean_absolute_error: 3.3975 - val_mean_squared_error: 35.2583 - val_pearson_correlation: 0.7845 - val_pearson_correlation_log: 0.5985 - val_zero_penalty_metric: 1198.2017 - learning_rate: 0.0010
Epoch 4/60
2191/3446 ━━━━━━━━━━━━━━━━━━━━ 4:40 223ms/step - concordance_correlation_coefficient: 0.7562 - cosine_similarity: 0.8663 - loss: -0.5173 - mean_absolute_error: 2.6118 - mean_squared_error: 28.6017 - pearson_correlation: 0.8187 - pearson_correlation_log: 0.6330 - zero_penalty_metric: 1106.6896
Model: "functional"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)        ┃ Output Shape      ┃    Param # ┃ Connected to      ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ sequence            │ (None, 2114, 4)   │          0 │ -                 │
│ (InputLayer)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ conv1d (Conv1D)     │ (None, 2114, 512) │     10,240 │ sequence[0][0]    │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ batch_normalization │ (None, 2114, 512) │      2,048 │ conv1d[0][0]      │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ activation          │ (None, 2114, 512) │          0 │ batch_normalizat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dropout (Dropout)   │ (None, 2114, 512) │          0 │ activation[0][0]  │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_1conv         │ (None, 2110, 512) │    786,432 │ dropout[0][0]     │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_1bn           │ (None, 2110, 512) │      2,048 │ bpnet_1conv[0][0] │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_1activation   │ (None, 2110, 512) │          0 │ bpnet_1bn[0][0]   │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_1crop         │ (None, 2110, 512) │          0 │ dropout[0][0]     │
│ (Cropping1D)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add (Add)           │ (None, 2110, 512) │          0 │ bpnet_1activatio… │
│                     │                   │            │ bpnet_1crop[0][0] │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_1dropout      │ (None, 2110, 512) │          0 │ add[0][0]         │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_2conv         │ (None, 2102, 512) │    786,432 │ bpnet_1dropout[0… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_2bn           │ (None, 2102, 512) │      2,048 │ bpnet_2conv[0][0] │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_2activation   │ (None, 2102, 512) │          0 │ bpnet_2bn[0][0]   │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_2crop         │ (None, 2102, 512) │          0 │ bpnet_1dropout[0… │
│ (Cropping1D)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_1 (Add)         │ (None, 2102, 512) │          0 │ bpnet_2activatio… │
│                     │                   │            │ bpnet_2crop[0][0] │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_2dropout      │ (None, 2102, 512) │          0 │ add_1[0][0]       │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_3conv         │ (None, 2086, 512) │    786,432 │ bpnet_2dropout[0… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_3bn           │ (None, 2086, 512) │      2,048 │ bpnet_3conv[0][0] │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_3activation   │ (None, 2086, 512) │          0 │ bpnet_3bn[0][0]   │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_3crop         │ (None, 2086, 512) │          0 │ bpnet_2dropout[0… │
│ (Cropping1D)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_2 (Add)         │ (None, 2086, 512) │          0 │ bpnet_3activatio… │
│                     │                   │            │ bpnet_3crop[0][0] │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_3dropout      │ (None, 2086, 512) │          0 │ add_2[0][0]       │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_4conv         │ (None, 2054, 512) │    786,432 │ bpnet_3dropout[0… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_4bn           │ (None, 2054, 512) │      2,048 │ bpnet_4conv[0][0] │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_4activation   │ (None, 2054, 512) │          0 │ bpnet_4bn[0][0]   │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_4crop         │ (None, 2054, 512) │          0 │ bpnet_3dropout[0… │
│ (Cropping1D)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_3 (Add)         │ (None, 2054, 512) │          0 │ bpnet_4activatio… │
│                     │                   │            │ bpnet_4crop[0][0] │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_4dropout      │ (None, 2054, 512) │          0 │ add_3[0][0]       │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_5conv         │ (None, 1990, 512) │    786,432 │ bpnet_4dropout[0… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_5bn           │ (None, 1990, 512) │      2,048 │ bpnet_5conv[0][0] │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_5activation   │ (None, 1990, 512) │          0 │ bpnet_5bn[0][0]   │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_5crop         │ (None, 1990, 512) │          0 │ bpnet_4dropout[0… │
│ (Cropping1D)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_4 (Add)         │ (None, 1990, 512) │          0 │ bpnet_5activatio… │
│                     │                   │            │ bpnet_5crop[0][0] │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_5dropout      │ (None, 1990, 512) │          0 │ add_4[0][0]       │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_6conv         │ (None, 1862, 512) │    786,432 │ bpnet_5dropout[0… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_6bn           │ (None, 1862, 512) │      2,048 │ bpnet_6conv[0][0] │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_6activation   │ (None, 1862, 512) │          0 │ bpnet_6bn[0][0]   │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_6crop         │ (None, 1862, 512) │          0 │ bpnet_5dropout[0… │
│ (Cropping1D)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_5 (Add)         │ (None, 1862, 512) │          0 │ bpnet_6activatio… │
│                     │                   │            │ bpnet_6crop[0][0] │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_6dropout      │ (None, 1862, 512) │          0 │ add_5[0][0]       │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_7conv         │ (None, 1606, 512) │    786,432 │ bpnet_6dropout[0… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_7bn           │ (None, 1606, 512) │      2,048 │ bpnet_7conv[0][0] │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_7activation   │ (None, 1606, 512) │          0 │ bpnet_7bn[0][0]   │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_7crop         │ (None, 1606, 512) │          0 │ bpnet_6dropout[0… │
│ (Cropping1D)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_6 (Add)         │ (None, 1606, 512) │          0 │ bpnet_7activatio… │
│                     │                   │            │ bpnet_7crop[0][0] │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_7dropout      │ (None, 1606, 512) │          0 │ add_6[0][0]       │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_8conv         │ (None, 1094, 512) │    786,432 │ bpnet_7dropout[0… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_8bn           │ (None, 1094, 512) │      2,048 │ bpnet_8conv[0][0] │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_8activation   │ (None, 1094, 512) │          0 │ bpnet_8bn[0][0]   │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_8crop         │ (None, 1094, 512) │          0 │ bpnet_7dropout[0… │
│ (Cropping1D)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_7 (Add)         │ (None, 1094, 512) │          0 │ bpnet_8activatio… │
│                     │                   │            │ bpnet_8crop[0][0] │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bpnet_8dropout      │ (None, 1094, 512) │          0 │ add_7[0][0]       │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ global_average_poo… │ (None, 512)       │          0 │ bpnet_8dropout[0… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_out (Dense)   │ (None, 19)        │      9,747 │ global_average_p… │
└─────────────────────┴───────────────────┴────────────┴───────────────────┘
 Total params: 6,329,875 (24.15 MB)
 Trainable params: 6,320,659 (24.11 MB)
 Non-trainable params: 9,216 (36.00 KB)

Finetuning on cell type-specific regions#

Subsetting the consensus peak set#

For peak regression models, we recommend continuing to train the model that was trained on all consensus peaks on a subset of cell type-specific regions. Since we are interested in understanding the enhancer code that uniquely identifies the cell types in the dataset, finetuning on specific regions brings us closer to that. We define specific regions as regions with a high Gini index, indicating that their peak height distribution across cell types is skewed towards one or a few cell types.

Read the documentation of the crested.pp.filter_regions_on_specificity() function for more information on how the filtering is done.
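To build intuition for the Gini index used here, a small illustrative implementation (not the package's own) applied to a vector of per-cell-type peak heights:

import numpy as np

def gini(x: np.ndarray) -> float:
    """Gini index of a non-negative vector: 0 for a uniform vector, approaching 1 when one entry dominates."""
    x = np.sort(x.astype(float))
    n = x.size
    cum = np.cumsum(x)
    return (n + 1 - 2 * np.sum(cum) / cum[-1]) / n

print(gini(np.ones(19)))  # 0.0 -> broadly accessible, not specific
print(gini(np.array([0.0] * 18 + [10.0])))  # ~0.95 -> specific to one cell type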

crested.pp.filter_regions_on_specificity(
    adata, gini_std_threshold=1.0
)  # All regions with a Gini index 1 std above the mean across all regions will be kept
adata
2024-12-12T10:19:15.431415+0100 INFO After specificity filtering, kept 91475 out of 546993 regions.
AnnData object with n_obs × n_vars = 19 × 91475
    obs: 'file_path'
    var: 'chr', 'start', 'end', 'split'
    obsm: 'weights'
adata.write_h5ad("mouse_cortex_filtered.h5ad")

Loading the model pretrained on all consensus peaks and finetuning with a lower learning rate#

datamodule = crested.tl.data.AnnDataModule(
    adata,
    batch_size=64,  # Recommended to go for a smaller batch size than in the basemodel
    max_stochastic_shift=3,
    always_reverse_complement=True,
)
# First load the model pretrained on all peaks; pick the basemodel checkpoint with the best validation loss/performance metrics
model_architecture = keras.models.load_model(
    "mouse_biccn/basemodel/checkpoints/17.keras",
    compile=False,
)
# Use the same config you used for the pretrained model, EXCEPT THE LEARNING RATE: make sure it is lower than it was at the epoch you selected the model from.
optimizer = keras.optimizers.Adam(learning_rate=1e-4)  # Lower LR!
loss = crested.tl.losses.CosineMSELogLoss(max_weight=100, multiplier=1)
metrics = [
    keras.metrics.MeanAbsoluteError(),
    keras.metrics.MeanSquaredError(),
    keras.metrics.CosineSimilarity(axis=1),
    crested.tl.metrics.PearsonCorrelation(),
    crested.tl.metrics.ConcordanceCorrelationCoefficient(),
    crested.tl.metrics.PearsonCorrelationLog(),
    crested.tl.metrics.ZeroPenaltyMetric(),
]

alternative_config = crested.tl.TaskConfig(optimizer, loss, metrics)
print(alternative_config)
TaskConfig(optimizer=<keras.src.optimizers.adam.Adam object at 0x7f4a1665c410>, loss=<crested.tl.losses._cosinemse_log.CosineMSELogLoss object at 0x7f4a1adafc90>, metrics=[<MeanAbsoluteError name=mean_absolute_error>, <MeanSquaredError name=mean_squared_error>, <CosineSimilarity name=cosine_similarity>, <PearsonCorrelation name=pearson_correlation>, <ConcordanceCorrelationCoefficient name=concordance_correlation_coefficient>, <PearsonCorrelationLog name=pearson_correlation_log>, <ZeroPenaltyMetric name=zero_penalty_metric>])
# setup the trainer
trainer = crested.tl.Crested(
    data=datamodule,
    model=model_architecture,
    config=alternative_config,
    project_name="mouse_biccn",  # change to your liking
    run_name="finetuned_model",  # change to your liking
    logger="wandb",  # or 'wandb', 'tensorboard'
)
2024-12-12T10:19:16.196631+0100 WARNING Output directory mouse_biccn/finetuned_model/checkpoints already exists. Will continue training from epoch 9.
trainer.fit(
    epochs=60,
    learning_rate_reduce_patience=3,
    early_stopping_patience=6,
)

Evaluate the model#

After training, we can evaluate the model on the test set using the test() method.
If we’re still in the same session, we can simply continue using the same object.
If not, we can load the model from disk using the load_model() method. This means that we first have to create a new Crested object.
However, since the TaskConfig and architecture are saved in the .keras file, this time we only have to provide our datamodule.

adata = ad.read_h5ad("mouse_cortex_filtered.h5ad")

datamodule = crested.tl.data.AnnDataModule(
    adata,
    batch_size=256,  # lower this if you encounter OOM errors
)
# load an existing model
evaluator = crested.tl.Crested(data=datamodule)

evaluator.load_model(
    "mouse_biccn/finetuned_model/checkpoints/02.keras",  # Load your model
    compile=True,
)

If you experimented with many different hyperparameters for your model, chances are that you will start overfitting on your validation dataset.
It’s therefore always a good idea to evaluate your model on the test set after getting good results on your validation data to see how well it generalizes to unseen data.

# evaluate the model on the test set
evaluator.test()
 3/33 ━━━━━━━━━━━━━━━━━━━ 1s 60ms/step - concordance_correlation_coefficient: 0.6685 - cosine_similarity: 0.8724 - loss: 0.5137 - mean_absolute_error: 1.6471 - mean_squared_error: 13.7273 - pearson_correlation: 0.7443 - pearson_correlation_log: 0.6123 - zero_penalty_metric: 2443.7920
33/33 ━━━━━━━━━━━━━━━━━━━━ 14s 116ms/step - concordance_correlation_coefficient: 0.7091 - cosine_similarity: 0.8744 - loss: 0.5054 - mean_absolute_error: 1.6253 - mean_squared_error: 13.6106 - pearson_correlation: 0.7706 - pearson_correlation_log: 0.6340 - zero_penalty_metric: 2470.7397
2024-12-12T10:20:51.416834+0100 INFO Test concordance_correlation_coefficient: 0.7221
2024-12-12T10:20:51.417376+0100 INFO Test cosine_similarity: 0.8707
2024-12-12T10:20:51.417635+0100 INFO Test loss: 0.5097
2024-12-12T10:20:51.417896+0100 INFO Test mean_absolute_error: 1.6882
2024-12-12T10:20:51.418280+0100 INFO Test mean_squared_error: 14.4946
2024-12-12T10:20:51.418600+0100 INFO Test pearson_correlation: 0.7813
2024-12-12T10:20:51.418892+0100 INFO Test pearson_correlation_log: 0.6345
2024-12-12T10:20:51.419171+0100 INFO Test zero_penalty_metric: 2302.9001

Predict#

After training, we can also use the predict() method to predict the labels for new data and add them as a layer to the AnnData object.
A common use case is to compare the predicted values to the ground truth for multiple trained models, to see how well they perform.

We can initialize a new Crested object (if you have different data) or use the existing one.
Here we continue with the existing one since we’ll use the same data as we trained on.

# add predictions for model checkpoint to the adata
evaluator.predict(
    adata, model_name="biccn_model"
)  # adds the predictions to the adata.layers["biccn_model"]
357/358 ━━━━━━━━━━━━━━━━━━━ 0s 54ms/step
358/358 ━━━━━━━━━━━━━━━━━━━━ 25s 66ms/step
2024-12-12T10:21:16.424896+0100 INFO Adding predictions to anndata.layers[biccn_model].
adata.layers
Layers with keys: biccn_model

If you don’t want to predict on the entire dataset, you can also predict on a given sequence or region using the predict_sequence() or predict_regions() methods.
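For example, a minimal sketch for a single region (assuming predict_regions accepts region names via a region_idx argument, mirroring the contribution-score methods used below):

region_preds = evaluator.predict_regions(
    region_idx=["chr18:61107770-61109884"]  # hypothetical example region
)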

Many of the plotting functions in the crested.pl module can be used to visualize these model predictions.

Example predictions on test set regions#

It is always interesting to see how the model performs on unseen test set regions. We recommend always looking at a few examples to spot potential biases or trends that you do not expect.

# Define a dataframe with test set regions
test_df = adata.var[adata.var["split"] == "test"]
test_df
chr start end split
region
chr18:3269690-3271804 chr18 3269690 3271804 test
chr18:3350307-3352421 chr18 3350307 3352421 test
chr18:3451398-3453512 chr18 3451398 3453512 test
chr18:3463977-3466091 chr18 3463977 3466091 test
chr18:3488308-3490422 chr18 3488308 3490422 test
... ... ... ... ...
chr9:124125533-124127647 chr9 124125533 124127647 test
chr9:124140961-124143075 chr9 124140961 124143075 test
chr9:124142793-124144907 chr9 124142793 124144907 test
chr9:124477280-124479394 chr9 124477280 124479394 test
chr9:124479548-124481662 chr9 124479548 124481662 test

8198 rows × 4 columns

# plot predictions vs ground truth for a random region in the test set defined by index
idx = 21
region = test_df.index[idx]
print(region)
crested.pl.bar.region_predictions(adata, region, title="Predictions vs Ground Truth")
chr18:3892771-3894885
2024-12-12T10:21:16.459431+0100 INFO Plotting bar plots for region: chr18:3892771-3894885, models: ['biccn_model']
../_images/94cc413dd1b18a22ae4e333e810e9e1a6e4b5d0021866a866262abd3d1dcc4cb.png
Example predictions on manually defined regions#
chrom = "chr3"  #'chr18'
start = 72535878 - 807  # 61107770
end = 72536378 + 807  # 61109884

sequence = genome.fetch(chrom, start, end).upper()

prediction = evaluator.predict_sequence(sequence)
1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 863ms/step
crested.pl.bar.prediction(prediction, classes=list(adata.obs_names))
../_images/205e31a22d27028f934660a20adcf1ca1f2af9eefd5a34f299c6c8fc0fa0c3f4.png

Prediction on gene locus#

We can also score a gene locus by sliding a window across a predefined genomic range. We can then compare those predictions to the bigwig track of the cell type we predicted for, to see if the measured profile matches the CREsted predictions.
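To get a feel for the scale of this scoring, some illustrative arithmetic based on the parameters below (not package code):

gene_len = 91374781 - 91209533  # 165,248 bp
scored_range = 50_000 + gene_len + 10_000  # upstream + gene body + downstream
print(scored_range // 100)  # ~2,252 windows at step_size=100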

chrom = "chr4"
start = 91209533
end = 91374781

scores, coordinates, min_loc, max_loc, tss_position = evaluator.score_gene_locus(
    chr_name=chrom,
    gene_start=start,
    gene_end=end,
    class_name="Sst",
    strand="-",
    upstream=50000,
    downstream=10000,
    step_size=100,
)
bigwig = "../../../../mouse/biccn/bigwigs/bws/Sst.bw"
bw_values, midpoints = crested.utils.extract_bigwig_values_per_bp(bigwig, coordinates)
2024-12-12T10:21:22.649225+0100 WARNING extract_bigwig_values_per_bp() is deprecated. Please use crested.utils.read_bigwig_region(bw_file, (chr, start, end)) instead.
%matplotlib inline
crested.pl.hist.locus_scoring(
    scores,
    (min_loc, max_loc),
    gene_start=start,
    gene_end=end,
    title="CREsted prediction around Elavl2 gene locus for Sst",
    bigwig_values=bw_values,
    bigwig_midpoints=midpoints,
)
../_images/8d5d8aa9d45d23e18992476bb02d43627ead2ffc1ab25781a6125c61002837f9.png

Model performance on the entire test set#

After looking at specific instances, we can now look at model performance on a larger scale.

First, we can check per cell type/class the correlation of predictions and peak heights over the peaks in the test set.

adata.layers
Layers with keys: biccn_model
classn = "L2_3IT"
crested.pl.scatter.class_density(
    adata,
    class_name=classn,
    model_names=["biccn_model"],
    split="test",
    log_transform=True,
    width=5,
    height=5,
)
2024-12-12T10:21:23.275696+0100 INFO Plotting density scatter for class: L2_3IT, models: ['biccn_model'], split: test
../_images/a1efc23c1377843015b8537a0502d05dcc3ad5a6379046c9104823a236c1cbbc.png

To check the correlations between all classes at once, we can plot a heatmap to assess the model performance.

crested.pl.heatmap.correlations_predictions(
    adata,
    split="test",
    title="Correlations between Groundtruths and Predictions",
    x_label_rotation=90,
    width=5,
    height=5,
    log_transform=True,
    vmax=1,
    vmin=-0.15,
)
2024-12-12T10:21:23.381532+0100 INFO Plotting heatmap correlations for split: test, models: ['biccn_model']
../_images/ebad3846e433e74ded166c0e3f4b28da7d91993ddb8a4dde11346d1fd77ed9f2.png

It is also recommended to compare this heatmap to the self-correlation plot of the peaks themselves. If peaks between cell types are correlated, then it is expected that predictions from non-matching classes for correlating cell types will also be high.

crested.pl.heatmap.correlations_self(
    adata,
    title="Self Correlation Heatmap",
    x_label_rotation=90,
    width=5,
    height=5,
    vmax=1,
    vmin=-0.15,
)
../_images/b706b008cfefed5870d1f04ae147d199f81ec9b79acc3671e9186c19996a2b4d.png

Sequence contribution scores#

We can calculate the contribution scores for a sequence of interest using the calculate_contribution_scores_sequence() method.
You always need to ensure that the sequence or region you provide is the same length as the model input (2114bp in our case).

Contribution scores on manually defined sequences#

# a simple poly-A sequence of length 2114bp as an example
sequence = "A" * 2114

scores, one_hot_encoded_sequences = evaluator.calculate_contribution_scores_sequence(
    sequence, class_names=["Astro", "Endo"]
)  # focus on two cell types of interest
2024-12-12T10:21:24.012136+0100 INFO Calculating contribution scores for 2 class(es) and 1 region(s).

Contribution scores on manually defined genomic regions#

Alternatively, you can calculate contribution scores for regions of interest using the calculate_contribution_scores_regions() method.
These regions don’t have to be in your original dataset, as long as they exist in the genome file that you provided to the AnnDataModule and they are the same length as the model input.

# focus on a few cell types of interest
regions_of_interest = [
    "chr18:61107770-61109884"
]  # FIRE enhancer region, should only have motifs in Micro_PVM
classes_of_interest = ["Astro", "Micro_PVM", "L5ET"]
scores, one_hot_encoded_sequences = evaluator.calculate_contribution_scores_regions(
    region_idx=regions_of_interest, class_names=classes_of_interest
)
2024-12-12T10:21:28.374958+0100 INFO Calculating contribution scores for 3 class(es) and 1 region(s).

Contribution scores for regions can be plotted using the crested.pl.patterns.contribution_scores() function.
This will generate a plot per class per region.

%matplotlib inline
crested.pl.patterns.contribution_scores(
    scores,
    one_hot_encoded_sequences,
    sequence_labels=regions_of_interest,
    class_labels=classes_of_interest,
    zoom_n_bases=500,
    title="FIRE Enhancer Region",
)  # zoom in on the center 500bp
../_images/a8f316fdc7ff20c5c56b74344a36a96db5172b47e87d8c8acd6289afcf91a85e.png

Contribution scores on random test set regions#

# plot predictions vs ground truth for a random region in the test set defined by index
idx = 21
region = test_df.index[idx]
classes_of_interest = ["L2_3IT", "Pvalb"]
scores, one_hot_encoded_sequences = evaluator.calculate_contribution_scores_regions(
    region_idx=region, class_names=classes_of_interest
)
2024-12-12T10:21:39.277930+0100 INFO Calculating contribution scores for 2 class(es) and 1 region(s).
%matplotlib inline
crested.pl.patterns.contribution_scores(
    scores,
    one_hot_encoded_sequences,
    sequence_labels=[region],
    class_labels=classes_of_interest,
    zoom_n_bases=500,
    title="Test set region",
)  # zoom in on the center 500bp
../_images/805141f58942e73ee5d82985d9827fe411e44245d3420e8da35f95927981b913.png

Enhancer design [more updates soon]#

Load data and model#

adata = ad.read_h5ad("mouse_cortex_filtered.h5ad")

datamodule = crested.tl.data.AnnDataModule(
    adata,
)
# load an existing model
evaluator = crested.tl.Crested(data=datamodule)

evaluator.load_model(
    "mouse_biccn/finetuned_model/checkpoints/02.keras",  # Load your model
    compile=True,
)

Sequence evolution#

We can create synthetic enhancers for a specified class using in silico evolution with the enhancer_design_in_silico_evolution() method.

designed_sequences = evaluator.enhancer_design_in_silico_evolution(
    target_class="L5ET", n_sequences=5, n_mutations=10, target_len=500
)
2024-12-12T10:21:46.923792+0100 INFO Loading sequences into memory...
prediction = evaluator.predict_sequence(designed_sequences[1])
crested.pl.bar.prediction(prediction, classes=list(adata.obs_names))
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 267ms/step
../_images/95ecb32ea56a9161550277336c427098b07c0b4adff1cdde1d2dd79b9b92492b.png
scores, one_hot_encoded_sequences = evaluator.calculate_contribution_scores_sequence(
    designed_sequences[0], class_names=["L5ET"]
)  # focus on the cell type of interest
2024-12-12T10:22:52.478705+0100 INFO Calculating contribution scores for 1 class(es) and 1 region(s).
crested.pl.patterns.contribution_scores(
    scores,
    one_hot_encoded_sequences,
    sequence_labels="",
    class_labels=["L5ET"],
    zoom_n_bases=500,
    title="synth L5ET",
)  # zoom in on the center 500bp
../_images/33942bcd5881c053a460bff319bc321dd1ed78813e171353693db5ef2bbcc98d.png

Motif embedding#

designed_sequences = evaluator.enhancer_design_motif_implementation(
    patterns={
        "SOX10": "AACAATGGCCCCATTGT",
        "CREB5": "ATGACATCA",
    },
    target_class="Oligo",
    n_sequences=5,
    target_len=500,
)
prediction = evaluator.predict_sequence(designed_sequences[0])
crested.pl.bar.prediction(prediction, classes=list(adata.obs_names))
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step
../_images/22d18d454ad05b14b91962b803e51a54d4cf387734da11f9725d6ba812668068.png
scores, one_hot_encoded_sequences = evaluator.calculate_contribution_scores_sequence(
    designed_sequences[0], class_names=["Oligo"]
)  # focus on the cell type of interest
2024-12-12T10:23:01.302410+0100 INFO Calculating contribution scores for 1 class(es) and 1 region(s).
%matplotlib inline
crested.pl.patterns.contribution_scores(
    scores,
    one_hot_encoded_sequences,
    sequence_labels="",
    class_labels=["Oligo"],
    zoom_n_bases=500,
    title="synth Oligo",
)
../_images/7746dc78b997eb24a2b5db1cdbe297607ff10048e59f21992b908d2cd242b3af.png

L2 optimizer#

def L2_distance(
    mutated_predictions: np.ndarray,
    original_prediction: np.ndarray,
    target: np.ndarray,
    classes_of_interest: list[int],
):
    """Calculate the L2 distance between the mutated predictions and the target, and return the index of the best mutation."""
    if len(original_prediction.shape) == 1:
        original_prediction = original_prediction[None]
    # distance of each mutated sequence's predictions to the target, restricted to the classes of interest
    L2_sat_mut = pairwise.euclidean_distances(
        mutated_predictions[:, classes_of_interest],
        target[classes_of_interest].reshape(1, -1),
    )
    # distance of the unmutated sequence's predictions to the target
    L2_baseline = pairwise.euclidean_distances(
        original_prediction[:, classes_of_interest],
        target[classes_of_interest].reshape(1, -1),
    )
    # pick the mutation that moves the prediction closest to the target
    return np.argmax((L2_baseline - L2_sat_mut).squeeze())


L2_optimizer = crested.utils.EnhancerOptimizer(optimize_func=L2_distance)

target_cell_type = "L2_3IT"

classes_of_interest = [
    i
    for i, ct in enumerate(adata.obs_names)
    if ct in ["L2_3IT", "L5ET", "L5IT", "L6IT"]
]
target = np.array([20 if x == target_cell_type else 0 for x in adata.obs_names])
intermediate, designed_sequences = evaluator.enhancer_design_in_silico_evolution(
    target_class=None,
    n_sequences=5,
    n_mutations=30,
    enhancer_optimizer=L2_optimizer,
    target=target,
    return_intermediate=True,
    no_mutation_flanks=(807, 807),
    classes_of_interest=classes_of_interest,
)
idx = 0
prediction = evaluator.predict_sequence(designed_sequences[idx])
crested.pl.bar.prediction(prediction, classes=list(adata.obs_names))
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 50ms/step
../_images/5519572422baf3661725c26fc0a9dcab6bd3f26b7172318a7e73d5e26cc23de8.png
scores, one_hot_encoded_sequences = evaluator.calculate_contribution_scores_sequence(
    designed_sequences[idx], class_names=["L2_3IT", "L5ET", "L5IT", "L6IT"]
)  # focus on the cell types of interest
2024-12-12T10:24:07.129049+0100 INFO Calculating contribution scores for 4 class(es) and 1 region(s).
crested.pl.patterns.contribution_scores(
    scores,
    one_hot_encoded_sequences,
    sequence_labels="",
    class_labels=["L2_3IT", "L5ET", "L5IT", "L6IT"],
    zoom_n_bases=500,
    title="",
)
../_images/fe63e087de592ae557e454669b3d9e5d04ebe5b11f53d1303f80d8e93b6a795b.png