Finetune Borzoi for scATAC peaks#

In this tutorial, we’ll show how to finetune the Borzoi model to do peak regression on scATAC data.

import os
import zipfile
import tempfile

import pandas as pd
import keras
import crested
resources_dir = "/staging/leuven/res_00001/genomes/mus_musculus/mm10_ucsc/fasta/"
genome_file = os.path.join(resources_dir, "mm10.fa")
chromsizes_file = os.path.join(resources_dir, "mm10.chrom.sizes")
folds_file = "consensus_peaks_borsplit.bed"
genome = crested.Genome(genome_file, chromsizes_file)
crested.register_genome(genome)
2025-02-14T13:07:13.655587+0100 INFO Genome mm10 registered.

Read in scATAC data#

We’ll use the same dataset as used in the default tutorial, the mouse BICCN dataset, derived from the brain cortex.

bigwigs_folder, regions_file = crested.get_dataset("mouse_cortex_bigwig_cut_sites")
adata = crested.import_bigwigs(
    bigwigs_folder=bigwigs_folder,
    regions_file=regions_file,
    target_region_width=1000,
    target="count",
)
adata
2025-02-14T13:08:31.442981+0100 INFO Extracting values from 19 bigWig files...
AnnData object with n_obs × n_vars = 19 × 546993
    obs: 'file_path'
    var: 'chr', 'start', 'end'

Add train/val/test split#

Generally, for finetuning, it’s recommended to use the train/test split from the original model, like Borzoi here.
This can be derived by intersecting your consensus peaks with sequences_mouse.bed from the Borzoi repository, like with BEDTools:

regions_file="consensus_peaks_biccn.bed" # regions_file from crested.get_dataset()
folds_file="sequences_mouse.bed" # From Borzoi repo
output_file="consensus_peaks_borsplit.bed"

grep fold3 ${folds_file} | sort -k1,1 -k2,2n | bedtools merge -i stdin -d 10 | bedtools intersect -a ${regions_file} -b stdin -wa -f 0.5 | sed $'s/$/\t'test/ > ${output_file}
grep fold4 ${folds_file} | sort -k1,1 -k2,2n | bedtools merge -i stdin -d 10 | bedtools intersect -a ${regions_file} -b stdin -wa -f 0.5 | sed $'s/$/\t'val/ >> ${output_file}
for i in 0 1 2 5 6 7; do
    grep fold${i} ${folds_file} | sort -k1,1 -k2,2n | bedtools merge -i stdin -d 10 | bedtools intersect -a ${regions_file} -b stdin -wa -f 0.5 | sed $'s/$/\t'train/ >> ${output_file}
done

folds = pd.read_csv(
    folds_file, sep="\t", names=["name", "split"], usecols=[3, 4]
).set_index("name")
print(
    f"% of regions found in folds file: {adata.var_names.isin(folds.index).sum()/adata.n_vars*100:.3f}%"
)
% of regions found in folds file: 99.425%
# Drop regions not in any folds
print(
    f"Dropping {(~adata.var_names.isin(folds.index)).sum()} regions because they are not in any fold."
)
adata = adata[:, adata.var_names.isin(folds.index)].copy()

# Add fold data to var
adata.var = adata.var.join(folds)

# Check result
adata.var["split"].value_counts(dropna=False)
Dropping 3146 regions because they are not in any fold.
split
train    412229
val       72744
test      58874
Name: count, dtype: int64

Alternatively, you could use the default train/test split function set a chromosome-based or random split:

# crested.pp.train_val_test_split(
#     adata, strategy="chr", val_chroms=["chr8", "chr10"], test_chroms=["chr9", "chr18"]
# )

Preprocessing#

For the preprocessing, we’ll again follow the default steps, except for the adjusted input size.

Region width#

In this example, we’ll use 2048bp inputs, to align with the 2114bp input size of the standard CNN peak regression models while staying within a multiple of 128. Therefore, we’ll need to resize our regions:

crested.pp.change_regions_width(
    adata,
    2048,
)

Peak normalization#

We can normalize our peak values based on the variability of the top peak heights per cell type using the crested.pp.normalize_peaks() function.

This function applies a normalization scalar to each cell type, obtained by comparing per cell type the distribution of peak heights for the maximally accessible regions which are not specific to any cell type.

crested.pp.normalize_peaks(
    adata, top_k_percent=0.03
)  # The top_k_percent parameters can be tuned based on potential bias towards cell types. If some weights are overcompensating too much, consider increasing the top_k_percent. Default is 0.01
2025-02-14T13:09:21.187809+0100 INFO Filtering on top k Gini scores...
2025-02-14T13:09:28.069781+0100 INFO Added normalization weights to adata.obsm['weights']...
chr start end split
region
chr9:76566142-76568190 chr9 76566142 76568190 train
chr5:98328510-98330558 chr5 98328510 98330558 train
chr5:98347819-98349867 chr5 98347819 98349867 train
chr13:34635167-34637215 chr13 34635167 34637215 train
chr13:34642109-34644157 chr13 34642109 34644157 train
... ... ... ... ...
chr13:34344270-34346318 chr13 34344270 34346318 train
chr5:98166140-98168188 chr5 98166140 98168188 train
chr5:98166667-98168715 chr5 98166667 98168715 train
chr13:34344974-34347022 chr13 34344974 34347022 train
chr5:98185712-98187760 chr5 98185712 98187760 train

48075 rows × 4 columns

We can visualize the normalization factor for each cell type using the crested.pl.bar.normalization_weights() function to inspect which cell type peaks were up/down weighted.

%matplotlib inline
crested.pl.bar.normalization_weights(adata, title="Normalization weights per cell type")
../_images/481f57e6f4eb91731cadd00954d32ed089a6cceae5cd4973b3c328417ac9229b.png

Load in model#

We load in the Borzoi model’s weights in its architecture, with one change - the input length. All of Borzoi’s layers are width-independent, so the length can be set to any value divisible by the internal bin size (128bp).
target_length is set to the total number of output bins (64 bins of 32bp makes 2048bp output), since no cropping is needed when predicting local features.
num_classes is set to the original size simply so that there are no weight shape mismatches when loading the initial weights; the head created based on num_classes will be replaced by a new head for the number of cell types we’d like to predict below.

# Create default Borzoi architecture, with shrunk input size and target_length
base_model_architecture = crested.tl.zoo.borzoi(
    seq_len=2048, target_length=2048 // 32, num_classes=2608
)

To load in the weights, we can’t directly load the model from the .keras file, since that fixes the input length at the previously set value (524288bp). However, we can extract the model.weights.h5 file containing only the weights and use that.

# Load pretrained Borzoi weights
model_file, _ = crested.get_model("Borzoi_mouse_rep0")
# Put weights into base architecture
with zipfile.ZipFile(
    model_file
) as model_archive, tempfile.TemporaryDirectory() as tmpdir:
    model_weights_path = model_archive.extract("model.weights.h5", tmpdir)
    base_model_architecture.load_weights(model_weights_path)

Now that we have the base model with the adjusted input shape, we need to adjust the final layers to return a value for each cell type per region, instead of per-bin values. Therefore, we drop the final head, add a flatten layer after the model’s final embedding, and add a new head predicting adata.n_obs values.

# Replace track head by flatten+dense to predict single vector of scalars per region
## Get last layer before head
current = base_model_architecture.get_layer("final_conv_activation").output
## Flatten and add new layer
current = keras.layers.Flatten()(current)
current = keras.layers.Dense(adata.n_obs, activation="softplus", name="dense_out")(
    current
)

# Turn into model
model_architecture = keras.Model(
    inputs=base_model_architecture.inputs, outputs=current, name="Borzoi_scalar"
)
print(model_architecture.summary())
Model: "Borzoi_scalar"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)         Output Shape          Param #  Connected to      ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ input (InputLayer)  │ (None, 2048, 4)   │          0 │ -                 │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ stem_conv (Conv1D)  │ (None, 2048, 512) │     31,232 │ input[0][0]       │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ stem_pool           │ (None, 1024, 512) │          0 │ stem_conv[0][0]   │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_1_batch… │ (None, 1024, 512) │      2,048 │ stem_pool[0][0]   │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_1_activ… │ (None, 1024, 512) │          0 │ tower_conv_1_bat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_1_conv   │ (None, 1024, 608) │  1,557,088 │ tower_conv_1_act… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_1_pool   │ (None, 512, 608)  │          0 │ tower_conv_1_con… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_2_batch… │ (None, 512, 608)  │      2,432 │ tower_conv_1_poo… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_2_activ… │ (None, 512, 608)  │          0 │ tower_conv_2_bat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_2_conv   │ (None, 512, 736)  │  2,238,176 │ tower_conv_2_act… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_2_pool   │ (None, 256, 736)  │          0 │ tower_conv_2_con… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_3_batch… │ (None, 256, 736)  │      2,944 │ tower_conv_2_poo… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_3_activ… │ (None, 256, 736)  │          0 │ tower_conv_3_bat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_3_conv   │ (None, 256, 896)  │  3,298,176 │ tower_conv_3_act… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_3_pool   │ (None, 128, 896)  │          0 │ tower_conv_3_con… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_4_batch… │ (None, 128, 896)  │      3,584 │ tower_conv_3_poo… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_4_activ… │ (None, 128, 896)  │          0 │ tower_conv_4_bat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_4_conv   │ (None, 128, 1056) │  4,731,936 │ tower_conv_4_act… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_4_pool   │ (None, 64, 1056)  │          0 │ tower_conv_4_con… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_5_batch… │ (None, 64, 1056)  │      4,224 │ tower_conv_4_poo… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_5_activ… │ (None, 64, 1056)  │          0 │ tower_conv_5_bat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_5_conv   │ (None, 64, 1280)  │  6,759,680 │ tower_conv_5_act… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_5_pool   │ (None, 32, 1280)  │          0 │ tower_conv_5_con… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_6_batch… │ (None, 32, 1280)  │      5,120 │ tower_conv_5_poo… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_6_activ… │ (None, 32, 1280)  │          0 │ tower_conv_6_bat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_6_conv   │ (None, 32, 1536)  │  9,831,936 │ tower_conv_6_act… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_6_pool   │ (None, 16, 1536)  │          0 │ tower_conv_6_con… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_1_… │ (None, 16, 1536)  │      3,072 │ tower_conv_6_poo… │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_1_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_1_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add (Add)           │ (None, 16, 1536)  │          0 │ tower_conv_6_poo… │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_1_l… │ (None, 16, 1536)  │      3,072 │ add[0][0]         │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_1_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_1… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_1_d… │ (None, 16, 3072)  │          0 │ transformer_ff_1… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_1_a… │ (None, 16, 3072)  │          0 │ transformer_ff_1… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_1_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_1… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_1_d… │ (None, 16, 1536)  │          0 │ transformer_ff_1… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_1 (Add)         │ (None, 16, 1536)  │          0 │ add[0][0],        │
│                     │                   │            │ transformer_ff_1… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_2_… │ (None, 16, 1536)  │      3,072 │ add_1[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_2_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_2_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_2 (Add)         │ (None, 16, 1536)  │          0 │ add_1[0][0],      │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_2_l… │ (None, 16, 1536)  │      3,072 │ add_2[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_2_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_2… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_2_d… │ (None, 16, 3072)  │          0 │ transformer_ff_2… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_2_a… │ (None, 16, 3072)  │          0 │ transformer_ff_2… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_2_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_2… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_2_d… │ (None, 16, 1536)  │          0 │ transformer_ff_2… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_3 (Add)         │ (None, 16, 1536)  │          0 │ add_2[0][0],      │
│                     │                   │            │ transformer_ff_2… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_3_… │ (None, 16, 1536)  │      3,072 │ add_3[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_3_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_3_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_4 (Add)         │ (None, 16, 1536)  │          0 │ add_3[0][0],      │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_3_l… │ (None, 16, 1536)  │      3,072 │ add_4[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_3_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_3… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_3_d… │ (None, 16, 3072)  │          0 │ transformer_ff_3… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_3_a… │ (None, 16, 3072)  │          0 │ transformer_ff_3… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_3_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_3… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_3_d… │ (None, 16, 1536)  │          0 │ transformer_ff_3… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_5 (Add)         │ (None, 16, 1536)  │          0 │ add_4[0][0],      │
│                     │                   │            │ transformer_ff_3… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_4_… │ (None, 16, 1536)  │      3,072 │ add_5[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_4_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_4_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_6 (Add)         │ (None, 16, 1536)  │          0 │ add_5[0][0],      │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_4_l… │ (None, 16, 1536)  │      3,072 │ add_6[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_4_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_4… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_4_d… │ (None, 16, 3072)  │          0 │ transformer_ff_4… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_4_a… │ (None, 16, 3072)  │          0 │ transformer_ff_4… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_4_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_4… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_4_d… │ (None, 16, 1536)  │          0 │ transformer_ff_4… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_7 (Add)         │ (None, 16, 1536)  │          0 │ add_6[0][0],      │
│                     │                   │            │ transformer_ff_4… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_5_… │ (None, 16, 1536)  │      3,072 │ add_7[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_5_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_5_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_8 (Add)         │ (None, 16, 1536)  │          0 │ add_7[0][0],      │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_5_l… │ (None, 16, 1536)  │      3,072 │ add_8[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_5_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_5… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_5_d… │ (None, 16, 3072)  │          0 │ transformer_ff_5… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_5_a… │ (None, 16, 3072)  │          0 │ transformer_ff_5… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_5_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_5… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_5_d… │ (None, 16, 1536)  │          0 │ transformer_ff_5… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_9 (Add)         │ (None, 16, 1536)  │          0 │ add_8[0][0],      │
│                     │                   │            │ transformer_ff_5… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_6_… │ (None, 16, 1536)  │      3,072 │ add_9[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_6_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_6_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_10 (Add)        │ (None, 16, 1536)  │          0 │ add_9[0][0],      │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_6_l… │ (None, 16, 1536)  │      3,072 │ add_10[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_6_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_6… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_6_d… │ (None, 16, 3072)  │          0 │ transformer_ff_6… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_6_a… │ (None, 16, 3072)  │          0 │ transformer_ff_6… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_6_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_6… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_6_d… │ (None, 16, 1536)  │          0 │ transformer_ff_6… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_11 (Add)        │ (None, 16, 1536)  │          0 │ add_10[0][0],     │
│                     │                   │            │ transformer_ff_6… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_7_… │ (None, 16, 1536)  │      3,072 │ add_11[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_7_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_7_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_12 (Add)        │ (None, 16, 1536)  │          0 │ add_11[0][0],     │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_7_l… │ (None, 16, 1536)  │      3,072 │ add_12[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_7_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_7… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_7_d… │ (None, 16, 3072)  │          0 │ transformer_ff_7… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_7_a… │ (None, 16, 3072)  │          0 │ transformer_ff_7… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_7_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_7… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_7_d… │ (None, 16, 1536)  │          0 │ transformer_ff_7… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_13 (Add)        │ (None, 16, 1536)  │          0 │ add_12[0][0],     │
│                     │                   │            │ transformer_ff_7… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_8_… │ (None, 16, 1536)  │      3,072 │ add_13[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_8_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_8_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_14 (Add)        │ (None, 16, 1536)  │          0 │ add_13[0][0],     │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_8_l… │ (None, 16, 1536)  │      3,072 │ add_14[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_8_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_8… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_8_d… │ (None, 16, 3072)  │          0 │ transformer_ff_8… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_8_a… │ (None, 16, 3072)  │          0 │ transformer_ff_8… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_8_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_8… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_8_d… │ (None, 16, 1536)  │          0 │ transformer_ff_8… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_15 (Add)        │ (None, 16, 1536)  │          0 │ add_14[0][0],     │
│                     │                   │            │ transformer_ff_8… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_conv_1_… │ (None, 16, 1536)  │      6,144 │ add_15[0][0]      │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_conv_1_… │ (None, 16, 1536)  │          0 │ upsampling_conv_… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ unet_skip_2_batchn… │ (None, 32, 1536)  │      6,144 │ tower_conv_6_con… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_conv_1_… │ (None, 16, 1536)  │  2,360,832 │ upsampling_conv_… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ unet_skip_2_activa… │ (None, 32, 1536)  │          0 │ unet_skip_2_batc… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ up_sampling1d       │ (None, 32, 1536)  │          0 │ upsampling_conv_… │
│ (UpSampling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ unet_skip_2_conv    │ (None, 32, 1536)  │  2,360,832 │ unet_skip_2_acti… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_16 (Add)        │ (None, 32, 1536)  │          0 │ up_sampling1d[0]… │
│                     │                   │            │ unet_skip_2_conv… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_separab… │ (None, 32, 1536)  │  2,365,440 │ add_16[0][0]      │
│ (SeparableConv1D)   │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_conv_2_… │ (None, 32, 1536)  │      6,144 │ upsampling_separ… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_conv_2_… │ (None, 32, 1536)  │          0 │ upsampling_conv_… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ unet_skip_1_batchn… │ (None, 64, 1280)  │      5,120 │ tower_conv_5_con… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_conv_2_… │ (None, 32, 1536)  │  2,360,832 │ upsampling_conv_… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ unet_skip_1_activa… │ (None, 64, 1280)  │          0 │ unet_skip_1_batc… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ up_sampling1d_1     │ (None, 64, 1536)  │          0 │ upsampling_conv_… │
│ (UpSampling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ unet_skip_1_conv    │ (None, 64, 1536)  │  1,967,616 │ unet_skip_1_acti… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_17 (Add)        │ (None, 64, 1536)  │          0 │ up_sampling1d_1[ │
│                     │                   │            │ unet_skip_1_conv… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_separab… │ (None, 64, 1536)  │  2,365,440 │ add_17[0][0]      │
│ (SeparableConv1D)   │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ final_conv_batchno… │ (None, 64, 1536)  │      6,144 │ upsampling_separ… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ final_conv_activat… │ (None, 64, 1536)  │          0 │ final_conv_batch… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ flatten (Flatten)   │ (None, 98304)     │          0 │ final_conv_activ… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_out (Dense)   │ (None, 19)        │  1,867,795 │ flatten[0][0]     │
└─────────────────────┴───────────────────┴────────────┴───────────────────┘
 Total params: 170,213,747 (649.31 MB)
 Trainable params: 170,188,723 (649.22 MB)
 Non-trainable params: 25,024 (97.75 KB)
None

Model Training#

Parameters#

The DataModule and TaskConfig let you set standard training parameters, like batch size and learning rate.
We use the same parameters as with peak regression in the default tutorial, except for a lower learning rate to match the fact that we are starting from a pre-trained model.

datamodule = crested.tl.data.AnnDataModule(
    adata,
    genome=genome,
    batch_size=32,  # lower this if you encounter OOM errors
    max_stochastic_shift=3,  # optional augmentation
    always_reverse_complement=True,  # default True. Will double the effective size of the training dataset.
)
optimizer = keras.optimizers.Adam(learning_rate=1e-5)
loss = crested.tl.losses.CosineMSELogLoss(max_weight=100, multiplier=1)
metrics = [
    keras.metrics.MeanAbsoluteError(),
    keras.metrics.MeanSquaredError(),
    keras.metrics.CosineSimilarity(axis=1),
    crested.tl.metrics.PearsonCorrelation(),
    crested.tl.metrics.ConcordanceCorrelationCoefficient(),
    crested.tl.metrics.PearsonCorrelationLog(),
    crested.tl.metrics.ZeroPenaltyMetric(),
]

config = crested.tl.TaskConfig(optimizer, loss, metrics)
print(config)
TaskConfig(optimizer=<keras.src.optimizers.adam.Adam object at 0x1511019ba0d0>, loss=<crested.tl.losses._cosinemse_log.CosineMSELogLoss object at 0x1511019bf990>, metrics=[<MeanAbsoluteError name=mean_absolute_error>, <MeanSquaredError name=mean_squared_error>, <CosineSimilarity name=cosine_similarity>, <PearsonCorrelation name=pearson_correlation>, <ConcordanceCorrelationCoefficient name=concordance_correlation_coefficient>, <PearsonCorrelationLog name=pearson_correlation_log>, <ZeroPenaltyMetric name=zero_penalty_metric>])

Finetune on full peak set#

By default:

  1. The model will continue training until the validation loss stops decreasing for 10 epochs with a maximum of 100 epochs.

  2. Every best model is saved based on the validation loss.

  3. The learning rate reduces by a factor of 0.25 if the validation loss stops decreasing for 5 epochs.

# setup the trainer
trainer = crested.tl.Crested(
    data=datamodule,
    model=model_architecture,
    config=config,
    project_name="biccn_borzoi_atac",
    run_name="testrun",
    logger="wandb",
)
# train the model
trainer.fit(epochs=10)
wandb version 0.19.6 is available! To upgrade, please run: $ pip install wandb --upgrade
Tracking run with wandb version 0.17.2
Run data is saved locally in /lustre1/project/stg_00002/lcb/cblaauw/rna_models/borzoi_atac/wandb/run-20250213_161634-veap7cez
Syncing run testrun to Weights & Biases (docs)
Model: "Borzoi_scalar"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)         Output Shape          Param #  Connected to      ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ input (InputLayer)  │ (None, 2048, 4)   │          0 │ -                 │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ stem_conv (Conv1D)  │ (None, 2048, 512) │     31,232 │ input[0][0]       │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ stem_pool           │ (None, 1024, 512) │          0 │ stem_conv[0][0]   │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_1_batch… │ (None, 1024, 512) │      2,048 │ stem_pool[0][0]   │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_1_activ… │ (None, 1024, 512) │          0 │ tower_conv_1_bat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_1_conv   │ (None, 1024, 608) │  1,557,088 │ tower_conv_1_act… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_1_pool   │ (None, 512, 608)  │          0 │ tower_conv_1_con… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_2_batch… │ (None, 512, 608)  │      2,432 │ tower_conv_1_poo… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_2_activ… │ (None, 512, 608)  │          0 │ tower_conv_2_bat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_2_conv   │ (None, 512, 736)  │  2,238,176 │ tower_conv_2_act… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_2_pool   │ (None, 256, 736)  │          0 │ tower_conv_2_con… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_3_batch… │ (None, 256, 736)  │      2,944 │ tower_conv_2_poo… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_3_activ… │ (None, 256, 736)  │          0 │ tower_conv_3_bat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_3_conv   │ (None, 256, 896)  │  3,298,176 │ tower_conv_3_act… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_3_pool   │ (None, 128, 896)  │          0 │ tower_conv_3_con… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_4_batch… │ (None, 128, 896)  │      3,584 │ tower_conv_3_poo… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_4_activ… │ (None, 128, 896)  │          0 │ tower_conv_4_bat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_4_conv   │ (None, 128, 1056) │  4,731,936 │ tower_conv_4_act… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_4_pool   │ (None, 64, 1056)  │          0 │ tower_conv_4_con… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_5_batch… │ (None, 64, 1056)  │      4,224 │ tower_conv_4_poo… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_5_activ… │ (None, 64, 1056)  │          0 │ tower_conv_5_bat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_5_conv   │ (None, 64, 1280)  │  6,759,680 │ tower_conv_5_act… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_5_pool   │ (None, 32, 1280)  │          0 │ tower_conv_5_con… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_6_batch… │ (None, 32, 1280)  │      5,120 │ tower_conv_5_poo… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_6_activ… │ (None, 32, 1280)  │          0 │ tower_conv_6_bat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_6_conv   │ (None, 32, 1536)  │  9,831,936 │ tower_conv_6_act… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_6_pool   │ (None, 16, 1536)  │          0 │ tower_conv_6_con… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_1_… │ (None, 16, 1536)  │      3,072 │ tower_conv_6_poo… │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_1_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_1_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_18 (Add)        │ (None, 16, 1536)  │          0 │ tower_conv_6_poo… │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_1_l… │ (None, 16, 1536)  │      3,072 │ add_18[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_1_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_1… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_1_d… │ (None, 16, 3072)  │          0 │ transformer_ff_1… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_1_a… │ (None, 16, 3072)  │          0 │ transformer_ff_1… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_1_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_1… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_1_d… │ (None, 16, 1536)  │          0 │ transformer_ff_1… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_19 (Add)        │ (None, 16, 1536)  │          0 │ add_18[0][0],     │
│                     │                   │            │ transformer_ff_1… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_2_… │ (None, 16, 1536)  │      3,072 │ add_19[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_2_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_2_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_20 (Add)        │ (None, 16, 1536)  │          0 │ add_19[0][0],     │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_2_l… │ (None, 16, 1536)  │      3,072 │ add_20[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_2_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_2… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_2_d… │ (None, 16, 3072)  │          0 │ transformer_ff_2… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_2_a… │ (None, 16, 3072)  │          0 │ transformer_ff_2… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_2_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_2… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_2_d… │ (None, 16, 1536)  │          0 │ transformer_ff_2… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_21 (Add)        │ (None, 16, 1536)  │          0 │ add_20[0][0],     │
│                     │                   │            │ transformer_ff_2… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_3_… │ (None, 16, 1536)  │      3,072 │ add_21[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_3_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_3_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_22 (Add)        │ (None, 16, 1536)  │          0 │ add_21[0][0],     │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_3_l… │ (None, 16, 1536)  │      3,072 │ add_22[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_3_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_3… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_3_d… │ (None, 16, 3072)  │          0 │ transformer_ff_3… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_3_a… │ (None, 16, 3072)  │          0 │ transformer_ff_3… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_3_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_3… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_3_d… │ (None, 16, 1536)  │          0 │ transformer_ff_3… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_23 (Add)        │ (None, 16, 1536)  │          0 │ add_22[0][0],     │
│                     │                   │            │ transformer_ff_3… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_4_… │ (None, 16, 1536)  │      3,072 │ add_23[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_4_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_4_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_24 (Add)        │ (None, 16, 1536)  │          0 │ add_23[0][0],     │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_4_l… │ (None, 16, 1536)  │      3,072 │ add_24[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_4_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_4… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_4_d… │ (None, 16, 3072)  │          0 │ transformer_ff_4… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_4_a… │ (None, 16, 3072)  │          0 │ transformer_ff_4… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_4_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_4… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_4_d… │ (None, 16, 1536)  │          0 │ transformer_ff_4… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_25 (Add)        │ (None, 16, 1536)  │          0 │ add_24[0][0],     │
│                     │                   │            │ transformer_ff_4… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_5_… │ (None, 16, 1536)  │      3,072 │ add_25[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_5_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_5_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_26 (Add)        │ (None, 16, 1536)  │          0 │ add_25[0][0],     │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_5_l… │ (None, 16, 1536)  │      3,072 │ add_26[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_5_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_5… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_5_d… │ (None, 16, 3072)  │          0 │ transformer_ff_5… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_5_a… │ (None, 16, 3072)  │          0 │ transformer_ff_5… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_5_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_5… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_5_d… │ (None, 16, 1536)  │          0 │ transformer_ff_5… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_27 (Add)        │ (None, 16, 1536)  │          0 │ add_26[0][0],     │
│                     │                   │            │ transformer_ff_5… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_6_… │ (None, 16, 1536)  │      3,072 │ add_27[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_6_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_6_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_28 (Add)        │ (None, 16, 1536)  │          0 │ add_27[0][0],     │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_6_l… │ (None, 16, 1536)  │      3,072 │ add_28[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_6_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_6… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_6_d… │ (None, 16, 3072)  │          0 │ transformer_ff_6… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_6_a… │ (None, 16, 3072)  │          0 │ transformer_ff_6… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_6_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_6… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_6_d… │ (None, 16, 1536)  │          0 │ transformer_ff_6… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_29 (Add)        │ (None, 16, 1536)  │          0 │ add_28[0][0],     │
│                     │                   │            │ transformer_ff_6… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_7_… │ (None, 16, 1536)  │      3,072 │ add_29[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_7_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_7_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_30 (Add)        │ (None, 16, 1536)  │          0 │ add_29[0][0],     │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_7_l… │ (None, 16, 1536)  │      3,072 │ add_30[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_7_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_7… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_7_d… │ (None, 16, 3072)  │          0 │ transformer_ff_7… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_7_a… │ (None, 16, 3072)  │          0 │ transformer_ff_7… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_7_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_7… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_7_d… │ (None, 16, 1536)  │          0 │ transformer_ff_7… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_31 (Add)        │ (None, 16, 1536)  │          0 │ add_30[0][0],     │
│                     │                   │            │ transformer_ff_7… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_8_… │ (None, 16, 1536)  │      3,072 │ add_31[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_8_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_8_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_32 (Add)        │ (None, 16, 1536)  │          0 │ add_31[0][0],     │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_8_l… │ (None, 16, 1536)  │      3,072 │ add_32[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_8_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_8… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_8_d… │ (None, 16, 3072)  │          0 │ transformer_ff_8… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_8_a… │ (None, 16, 3072)  │          0 │ transformer_ff_8… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_8_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_8… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_8_d… │ (None, 16, 1536)  │          0 │ transformer_ff_8… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_33 (Add)        │ (None, 16, 1536)  │          0 │ add_32[0][0],     │
│                     │                   │            │ transformer_ff_8… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_conv_1_… │ (None, 16, 1536)  │      6,144 │ add_33[0][0]      │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_conv_1_… │ (None, 16, 1536)  │          0 │ upsampling_conv_… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ unet_skip_2_batchn… │ (None, 32, 1536)  │      6,144 │ tower_conv_6_con… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_conv_1_… │ (None, 16, 1536)  │  2,360,832 │ upsampling_conv_… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ unet_skip_2_activa… │ (None, 32, 1536)  │          0 │ unet_skip_2_batc… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ up_sampling1d_2     │ (None, 32, 1536)  │          0 │ upsampling_conv_… │
│ (UpSampling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ unet_skip_2_conv    │ (None, 32, 1536)  │  2,360,832 │ unet_skip_2_acti… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_34 (Add)        │ (None, 32, 1536)  │          0 │ up_sampling1d_2[ │
│                     │                   │            │ unet_skip_2_conv… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_separab… │ (None, 32, 1536)  │  2,365,440 │ add_34[0][0]      │
│ (SeparableConv1D)   │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_conv_2_… │ (None, 32, 1536)  │      6,144 │ upsampling_separ… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_conv_2_… │ (None, 32, 1536)  │          0 │ upsampling_conv_… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ unet_skip_1_batchn… │ (None, 64, 1280)  │      5,120 │ tower_conv_5_con… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_conv_2_… │ (None, 32, 1536)  │  2,360,832 │ upsampling_conv_… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ unet_skip_1_activa… │ (None, 64, 1280)  │          0 │ unet_skip_1_batc… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ up_sampling1d_3     │ (None, 64, 1536)  │          0 │ upsampling_conv_… │
│ (UpSampling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ unet_skip_1_conv    │ (None, 64, 1536)  │  1,967,616 │ unet_skip_1_acti… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_35 (Add)        │ (None, 64, 1536)  │          0 │ up_sampling1d_3[ │
│                     │                   │            │ unet_skip_1_conv… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_separab… │ (None, 64, 1536)  │  2,365,440 │ add_35[0][0]      │
│ (SeparableConv1D)   │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ final_conv_batchno… │ (None, 64, 1536)  │      6,144 │ upsampling_separ… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ final_conv_activat… │ (None, 64, 1536)  │          0 │ final_conv_batch… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ flatten (Flatten)   │ (None, 98304)     │          0 │ final_conv_activ… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_out (Dense)   │ (None, 19)        │  1,867,795 │ flatten[0][0]     │
└─────────────────────┴───────────────────┴────────────┴───────────────────┘
 Total params: 170,213,747 (649.31 MB)
 Trainable params: 170,188,723 (649.22 MB)
 Non-trainable params: 25,024 (97.75 KB)
None
2025-02-13T16:17:01.947660+0100 INFO Loading sequences into memory...
2025-02-13T16:17:10.887463+0100 INFO Loading sequences into memory...
Epoch 1/10
25764/25765 ━━━━━━━━━━━━━━━━━━━ 0s 31ms/step - concordance_correlation_coefficient: 0.7620 - cosine_similarity: 0.8605 - loss: -0.5721 - mean_absolute_error: 2.5347 - mean_squared_error: 28.2751 - pearson_correlation: 0.8269 - pearson_correlation_log: 0.6319 - zero_penalty_metric: 135.2008
25765/25765 ━━━━━━━━━━━━━━━━━━━━ 0s 33ms/step - concordance_correlation_coefficient: 0.7620 - cosine_similarity: 0.8605 - loss: -0.5721 - mean_absolute_error: 2.5347 - mean_squared_error: 28.2748 - pearson_correlation: 0.8269 - pearson_correlation_log: 0.6319 - zero_penalty_metric: 135.2008
25765/25765 ━━━━━━━━━━━━━━━━━━━━ 969s 34ms/step - concordance_correlation_coefficient: 0.7620 - cosine_similarity: 0.8605 - loss: -0.5721 - mean_absolute_error: 2.5347 - mean_squared_error: 28.2745 - pearson_correlation: 0.8269 - pearson_correlation_log: 0.6319 - zero_penalty_metric: 135.2007 - val_concordance_correlation_coefficient: 0.8757 - val_cosine_similarity: 0.8763 - val_loss: -0.6293 - val_mean_absolute_error: 2.1752 - val_mean_squared_error: 18.0886 - val_pearson_correlation: 0.8776 - val_pearson_correlation_log: 0.6524 - val_zero_penalty_metric: 138.8675 - learning_rate: 1.0000e-05
Epoch 2/10
25765/25765 ━━━━━━━━━━━━━━━━━━━━ 832s 32ms/step - concordance_correlation_coefficient: 0.8893 - cosine_similarity: 0.8847 - loss: -0.6589 - mean_absolute_error: 2.0946 - mean_squared_error: 15.5250 - pearson_correlation: 0.9042 - pearson_correlation_log: 0.6722 - zero_penalty_metric: 133.3652 - val_concordance_correlation_coefficient: 0.8768 - val_cosine_similarity: 0.8804 - val_loss: -0.6392 - val_mean_absolute_error: 2.1034 - val_mean_squared_error: 16.3852 - val_pearson_correlation: 0.8845 - val_pearson_correlation_log: 0.6555 - val_zero_penalty_metric: 139.2296 - learning_rate: 1.0000e-05
Epoch 3/10
25765/25765 ━━━━━━━━━━━━━━━━━━━━ 825s 32ms/step - concordance_correlation_coefficient: 0.9066 - cosine_similarity: 0.8926 - loss: -0.6849 - mean_absolute_error: 1.9653 - mean_squared_error: 13.2411 - pearson_correlation: 0.9170 - pearson_correlation_log: 0.6835 - zero_penalty_metric: 132.5512 - val_concordance_correlation_coefficient: 0.8715 - val_cosine_similarity: 0.8828 - val_loss: -0.6440 - val_mean_absolute_error: 2.0995 - val_mean_squared_error: 16.5556 - val_pearson_correlation: 0.8848 - val_pearson_correlation_log: 0.6619 - val_zero_penalty_metric: 138.6013 - learning_rate: 1.0000e-05
Epoch 4/10
25765/25765 ━━━━━━━━━━━━━━━━━━━━ 818s 32ms/step - concordance_correlation_coefficient: 0.9192 - cosine_similarity: 0.8986 - loss: -0.7056 - mean_absolute_error: 1.8737 - mean_squared_error: 11.8145 - pearson_correlation: 0.9268 - pearson_correlation_log: 0.6924 - zero_penalty_metric: 131.7001 - val_concordance_correlation_coefficient: 0.8761 - val_cosine_similarity: 0.8827 - val_loss: -0.6397 - val_mean_absolute_error: 2.0982 - val_mean_squared_error: 16.2530 - val_pearson_correlation: 0.8864 - val_pearson_correlation_log: 0.6568 - val_zero_penalty_metric: 138.3267 - learning_rate: 1.0000e-05
Epoch 5/10
25765/25765 ━━━━━━━━━━━━━━━━━━━━ 826s 32ms/step - concordance_correlation_coefficient: 0.9276 - cosine_similarity: 0.9040 - loss: -0.7238 - mean_absolute_error: 1.7957 - mean_squared_error: 10.7533 - pearson_correlation: 0.9339 - pearson_correlation_log: 0.7007 - zero_penalty_metric: 130.6058 - val_concordance_correlation_coefficient: 0.8795 - val_cosine_similarity: 0.8823 - val_loss: -0.6346 - val_mean_absolute_error: 2.1381 - val_mean_squared_error: 16.7327 - val_pearson_correlation: 0.8824 - val_pearson_correlation_log: 0.6552 - val_zero_penalty_metric: 139.1519 - learning_rate: 1.0000e-05
Epoch 6/10
 6901/25765 ━━━━━━━━━━━━━━━━━━━━ 9:44 31ms/step - concordance_correlation_coefficient: 0.9324 - cosine_similarity: 0.9092 - loss: -0.7411 - mean_absolute_error: 1.7301 - mean_squared_error: 9.9792 - pearson_correlation: 0.9379 - pearson_correlation_log: 0.7097 - zero_penalty_metric: 128.9625

Further finetuning on specific regions#

We found that finetuning on the full peak set, then on the filtered peak set improved performance over training only on either set. Therefore, we’ll filter the peaks to keep only cell type-specific peaks and further finetune the model.

crested.pp.filter_regions_on_specificity(adata, gini_std_threshold=1.0)
2025-02-14T13:09:51.664914+0100 INFO After specificity filtering, kept 91002 out of 543847 regions.
datamodule = crested.tl.data.AnnDataModule(
    adata,
    genome=genome,
    batch_size=32,  # lower this if you encounter OOM errors
    max_stochastic_shift=3,  # optional augmentation
    always_reverse_complement=True,  # default True. Will double the effective size of the training dataset.
)
optimizer = keras.optimizers.Adam(learning_rate=5e-5)
loss = crested.tl.losses.CosineMSELogLoss(max_weight=100, multiplier=1)
metrics = [
    keras.metrics.MeanAbsoluteError(),
    keras.metrics.MeanSquaredError(),
    keras.metrics.CosineSimilarity(axis=1),
    crested.tl.metrics.PearsonCorrelation(),
    crested.tl.metrics.ConcordanceCorrelationCoefficient(),
    crested.tl.metrics.PearsonCorrelationLog(),
    crested.tl.metrics.ZeroPenaltyMetric(),
]

config = crested.tl.TaskConfig(optimizer, loss, metrics)
print(config)
TaskConfig(optimizer=<keras.src.optimizers.adam.Adam object at 0x151101c6d890>, loss=<crested.tl.losses._cosinemse_log.CosineMSELogLoss object at 0x151101c6e010>, metrics=[<MeanAbsoluteError name=mean_absolute_error>, <MeanSquaredError name=mean_squared_error>, <CosineSimilarity name=cosine_similarity>, <PearsonCorrelation name=pearson_correlation>, <ConcordanceCorrelationCoefficient name=concordance_correlation_coefficient>, <PearsonCorrelationLog name=pearson_correlation_log>, <ZeroPenaltyMetric name=zero_penalty_metric>])
model_architecture = keras.models.load_model(
    "biccn_borzoi_atac/testrun/checkpoints/03.keras", compile=False
)
# setup the trainer
trainer = crested.tl.Crested(
    data=datamodule,
    model=model_architecture,
    config=config,
    project_name="biccn_borzoi_atac",
    run_name="testrun_ft",
    logger="wandb",
)
# train the model
trainer.fit(epochs=5)
wandb version 0.19.6 is available! To upgrade, please run: $ pip install wandb --upgrade
Tracking run with wandb version 0.17.2
Run data is saved locally in /lustre1/project/stg_00002/lcb/cblaauw/rna_models/borzoi_atac/wandb/run-20250214_131001-sdx53oa8
Model: "Borzoi_scalar"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)         Output Shape          Param #  Connected to      ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ input (InputLayer)  │ (None, 2048, 4)   │          0 │ -                 │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ stem_conv (Conv1D)  │ (None, 2048, 512) │     31,232 │ input[0][0]       │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ stem_pool           │ (None, 1024, 512) │          0 │ stem_conv[0][0]   │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_1_batch… │ (None, 1024, 512) │      2,048 │ stem_pool[0][0]   │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_1_activ… │ (None, 1024, 512) │          0 │ tower_conv_1_bat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_1_conv   │ (None, 1024, 608) │  1,557,088 │ tower_conv_1_act… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_1_pool   │ (None, 512, 608)  │          0 │ tower_conv_1_con… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_2_batch… │ (None, 512, 608)  │      2,432 │ tower_conv_1_poo… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_2_activ… │ (None, 512, 608)  │          0 │ tower_conv_2_bat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_2_conv   │ (None, 512, 736)  │  2,238,176 │ tower_conv_2_act… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_2_pool   │ (None, 256, 736)  │          0 │ tower_conv_2_con… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_3_batch… │ (None, 256, 736)  │      2,944 │ tower_conv_2_poo… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_3_activ… │ (None, 256, 736)  │          0 │ tower_conv_3_bat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_3_conv   │ (None, 256, 896)  │  3,298,176 │ tower_conv_3_act… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_3_pool   │ (None, 128, 896)  │          0 │ tower_conv_3_con… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_4_batch… │ (None, 128, 896)  │      3,584 │ tower_conv_3_poo… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_4_activ… │ (None, 128, 896)  │          0 │ tower_conv_4_bat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_4_conv   │ (None, 128, 1056) │  4,731,936 │ tower_conv_4_act… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_4_pool   │ (None, 64, 1056)  │          0 │ tower_conv_4_con… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_5_batch… │ (None, 64, 1056)  │      4,224 │ tower_conv_4_poo… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_5_activ… │ (None, 64, 1056)  │          0 │ tower_conv_5_bat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_5_conv   │ (None, 64, 1280)  │  6,759,680 │ tower_conv_5_act… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_5_pool   │ (None, 32, 1280)  │          0 │ tower_conv_5_con… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_6_batch… │ (None, 32, 1280)  │      5,120 │ tower_conv_5_poo… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_6_activ… │ (None, 32, 1280)  │          0 │ tower_conv_6_bat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_6_conv   │ (None, 32, 1536)  │  9,831,936 │ tower_conv_6_act… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_6_pool   │ (None, 16, 1536)  │          0 │ tower_conv_6_con… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_1_… │ (None, 16, 1536)  │      3,072 │ tower_conv_6_poo… │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_1_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_1_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_18 (Add)        │ (None, 16, 1536)  │          0 │ tower_conv_6_poo… │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_1_l… │ (None, 16, 1536)  │      3,072 │ add_18[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_1_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_1… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_1_d… │ (None, 16, 3072)  │          0 │ transformer_ff_1… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_1_a… │ (None, 16, 3072)  │          0 │ transformer_ff_1… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_1_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_1… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_1_d… │ (None, 16, 1536)  │          0 │ transformer_ff_1… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_19 (Add)        │ (None, 16, 1536)  │          0 │ add_18[0][0],     │
│                     │                   │            │ transformer_ff_1… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_2_… │ (None, 16, 1536)  │      3,072 │ add_19[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_2_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_2_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_20 (Add)        │ (None, 16, 1536)  │          0 │ add_19[0][0],     │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_2_l… │ (None, 16, 1536)  │      3,072 │ add_20[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_2_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_2… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_2_d… │ (None, 16, 3072)  │          0 │ transformer_ff_2… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_2_a… │ (None, 16, 3072)  │          0 │ transformer_ff_2… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_2_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_2… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_2_d… │ (None, 16, 1536)  │          0 │ transformer_ff_2… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_21 (Add)        │ (None, 16, 1536)  │          0 │ add_20[0][0],     │
│                     │                   │            │ transformer_ff_2… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_3_… │ (None, 16, 1536)  │      3,072 │ add_21[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_3_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_3_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_22 (Add)        │ (None, 16, 1536)  │          0 │ add_21[0][0],     │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_3_l… │ (None, 16, 1536)  │      3,072 │ add_22[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_3_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_3… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_3_d… │ (None, 16, 3072)  │          0 │ transformer_ff_3… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_3_a… │ (None, 16, 3072)  │          0 │ transformer_ff_3… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_3_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_3… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_3_d… │ (None, 16, 1536)  │          0 │ transformer_ff_3… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_23 (Add)        │ (None, 16, 1536)  │          0 │ add_22[0][0],     │
│                     │                   │            │ transformer_ff_3… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_4_… │ (None, 16, 1536)  │      3,072 │ add_23[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_4_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_4_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_24 (Add)        │ (None, 16, 1536)  │          0 │ add_23[0][0],     │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_4_l… │ (None, 16, 1536)  │      3,072 │ add_24[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_4_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_4… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_4_d… │ (None, 16, 3072)  │          0 │ transformer_ff_4… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_4_a… │ (None, 16, 3072)  │          0 │ transformer_ff_4… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_4_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_4… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_4_d… │ (None, 16, 1536)  │          0 │ transformer_ff_4… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_25 (Add)        │ (None, 16, 1536)  │          0 │ add_24[0][0],     │
│                     │                   │            │ transformer_ff_4… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_5_… │ (None, 16, 1536)  │      3,072 │ add_25[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_5_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_5_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_26 (Add)        │ (None, 16, 1536)  │          0 │ add_25[0][0],     │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_5_l… │ (None, 16, 1536)  │      3,072 │ add_26[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_5_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_5… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_5_d… │ (None, 16, 3072)  │          0 │ transformer_ff_5… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_5_a… │ (None, 16, 3072)  │          0 │ transformer_ff_5… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_5_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_5… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_5_d… │ (None, 16, 1536)  │          0 │ transformer_ff_5… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_27 (Add)        │ (None, 16, 1536)  │          0 │ add_26[0][0],     │
│                     │                   │            │ transformer_ff_5… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_6_… │ (None, 16, 1536)  │      3,072 │ add_27[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_6_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_6_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_28 (Add)        │ (None, 16, 1536)  │          0 │ add_27[0][0],     │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_6_l… │ (None, 16, 1536)  │      3,072 │ add_28[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_6_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_6… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_6_d… │ (None, 16, 3072)  │          0 │ transformer_ff_6… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_6_a… │ (None, 16, 3072)  │          0 │ transformer_ff_6… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_6_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_6… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_6_d… │ (None, 16, 1536)  │          0 │ transformer_ff_6… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_29 (Add)        │ (None, 16, 1536)  │          0 │ add_28[0][0],     │
│                     │                   │            │ transformer_ff_6… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_7_… │ (None, 16, 1536)  │      3,072 │ add_29[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_7_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_7_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_30 (Add)        │ (None, 16, 1536)  │          0 │ add_29[0][0],     │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_7_l… │ (None, 16, 1536)  │      3,072 │ add_30[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_7_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_7… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_7_d… │ (None, 16, 3072)  │          0 │ transformer_ff_7… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_7_a… │ (None, 16, 3072)  │          0 │ transformer_ff_7… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_7_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_7… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_7_d… │ (None, 16, 1536)  │          0 │ transformer_ff_7… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_31 (Add)        │ (None, 16, 1536)  │          0 │ add_30[0][0],     │
│                     │                   │            │ transformer_ff_7… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_8_… │ (None, 16, 1536)  │      3,072 │ add_31[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_8_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_8_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_32 (Add)        │ (None, 16, 1536)  │          0 │ add_31[0][0],     │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_8_l… │ (None, 16, 1536)  │      3,072 │ add_32[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_8_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_8… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_8_d… │ (None, 16, 3072)  │          0 │ transformer_ff_8… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_8_a… │ (None, 16, 3072)  │          0 │ transformer_ff_8… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_8_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_8… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_8_d… │ (None, 16, 1536)  │          0 │ transformer_ff_8… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_33 (Add)        │ (None, 16, 1536)  │          0 │ add_32[0][0],     │
│                     │                   │            │ transformer_ff_8… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_conv_1_… │ (None, 16, 1536)  │      6,144 │ add_33[0][0]      │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_conv_1_… │ (None, 16, 1536)  │          0 │ upsampling_conv_… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ unet_skip_2_batchn… │ (None, 32, 1536)  │      6,144 │ tower_conv_6_con… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_conv_1_… │ (None, 16, 1536)  │  2,360,832 │ upsampling_conv_… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ unet_skip_2_activa… │ (None, 32, 1536)  │          0 │ unet_skip_2_batc… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ up_sampling1d_2     │ (None, 32, 1536)  │          0 │ upsampling_conv_… │
│ (UpSampling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ unet_skip_2_conv    │ (None, 32, 1536)  │  2,360,832 │ unet_skip_2_acti… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_34 (Add)        │ (None, 32, 1536)  │          0 │ up_sampling1d_2[ │
│                     │                   │            │ unet_skip_2_conv… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_separab… │ (None, 32, 1536)  │  2,365,440 │ add_34[0][0]      │
│ (SeparableConv1D)   │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_conv_2_… │ (None, 32, 1536)  │      6,144 │ upsampling_separ… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_conv_2_… │ (None, 32, 1536)  │          0 │ upsampling_conv_… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ unet_skip_1_batchn… │ (None, 64, 1280)  │      5,120 │ tower_conv_5_con… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_conv_2_… │ (None, 32, 1536)  │  2,360,832 │ upsampling_conv_… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ unet_skip_1_activa… │ (None, 64, 1280)  │          0 │ unet_skip_1_batc… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ up_sampling1d_3     │ (None, 64, 1536)  │          0 │ upsampling_conv_… │
│ (UpSampling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ unet_skip_1_conv    │ (None, 64, 1536)  │  1,967,616 │ unet_skip_1_acti… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_35 (Add)        │ (None, 64, 1536)  │          0 │ up_sampling1d_3[ │
│                     │                   │            │ unet_skip_1_conv… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_separab… │ (None, 64, 1536)  │  2,365,440 │ add_35[0][0]      │
│ (SeparableConv1D)   │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ final_conv_batchno… │ (None, 64, 1536)  │      6,144 │ upsampling_separ… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ final_conv_activat… │ (None, 64, 1536)  │          0 │ final_conv_batch… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ flatten (Flatten)   │ (None, 98304)     │          0 │ final_conv_activ… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_out (Dense)   │ (None, 19)        │  1,867,795 │ flatten[0][0]     │
└─────────────────────┴───────────────────┴────────────┴───────────────────┘
 Total params: 170,213,747 (649.31 MB)
 Trainable params: 170,188,723 (649.22 MB)
 Non-trainable params: 25,024 (97.75 KB)
None
2025-02-14T13:10:34.604198+0100 INFO Loading sequences into memory...
2025-02-14T13:10:45.359384+0100 INFO Loading sequences into memory...
Epoch 1/5
4197/4198 ━━━━━━━━━━━━━━━━━━━ 0s 31ms/step - concordance_correlation_coefficient: 0.7256 - cosine_similarity: 0.8709 - loss: -0.6439 - mean_absolute_error: 1.6038 - mean_squared_error: 12.4116 - pearson_correlation: 0.7944 - pearson_correlation_log: 0.6196 - zero_penalty_metric: 320.1437
4198/4198 ━━━━━━━━━━━━━━━━━━━━ 0s 44ms/step - concordance_correlation_coefficient: 0.7256 - cosine_similarity: 0.8709 - loss: -0.6439 - mean_absolute_error: 1.6038 - mean_squared_error: 12.4115 - pearson_correlation: 0.7944 - pearson_correlation_log: 0.6196 - zero_penalty_metric: 320.1441
4198/4198 ━━━━━━━━━━━━━━━━━━━━ 288s 48ms/step - concordance_correlation_coefficient: 0.7256 - cosine_similarity: 0.8709 - loss: -0.6439 - mean_absolute_error: 1.6038 - mean_squared_error: 12.4114 - pearson_correlation: 0.7944 - pearson_correlation_log: 0.6196 - zero_penalty_metric: 320.1445 - val_concordance_correlation_coefficient: 0.6791 - val_cosine_similarity: 0.8512 - val_loss: -0.5814 - val_mean_absolute_error: 1.7332 - val_mean_squared_error: 14.4574 - val_pearson_correlation: 0.7423 - val_pearson_correlation_log: 0.6029 - val_zero_penalty_metric: 312.9736 - learning_rate: 5.0000e-05
Epoch 2/5
4198/4198 ━━━━━━━━━━━━━━━━━━━━ 139s 33ms/step - concordance_correlation_coefficient: 0.8034 - cosine_similarity: 0.9018 - loss: -0.7173 - mean_absolute_error: 1.4257 - mean_squared_error: 9.5457 - pearson_correlation: 0.8455 - pearson_correlation_log: 0.6463 - zero_penalty_metric: 318.2404 - val_concordance_correlation_coefficient: 0.7234 - val_cosine_similarity: 0.8542 - val_loss: -0.5963 - val_mean_absolute_error: 1.7099 - val_mean_squared_error: 13.3615 - val_pearson_correlation: 0.7549 - val_pearson_correlation_log: 0.6074 - val_zero_penalty_metric: 307.6308 - learning_rate: 5.0000e-05
Epoch 3/5
4198/4198 ━━━━━━━━━━━━━━━━━━━━ 136s 32ms/step - concordance_correlation_coefficient: 0.8495 - cosine_similarity: 0.9229 - loss: -0.7697 - mean_absolute_error: 1.2824 - mean_squared_error: 7.5957 - pearson_correlation: 0.8779 - pearson_correlation_log: 0.6710 - zero_penalty_metric: 312.4368 - val_concordance_correlation_coefficient: 0.7091 - val_cosine_similarity: 0.8496 - val_loss: -0.5783 - val_mean_absolute_error: 1.7447 - val_mean_squared_error: 13.5709 - val_pearson_correlation: 0.7505 - val_pearson_correlation_log: 0.6042 - val_zero_penalty_metric: 309.9329 - learning_rate: 5.0000e-05
Epoch 4/5
4198/4198 ━━━━━━━━━━━━━━━━━━━━ 136s 32ms/step - concordance_correlation_coefficient: 0.8773 - cosine_similarity: 0.9389 - loss: -0.8093 - mean_absolute_error: 1.1699 - mean_squared_error: 6.3396 - pearson_correlation: 0.8985 - pearson_correlation_log: 0.6910 - zero_penalty_metric: 309.6932 - val_concordance_correlation_coefficient: 0.7358 - val_cosine_similarity: 0.8513 - val_loss: -0.5890 - val_mean_absolute_error: 1.7230 - val_mean_squared_error: 13.0675 - val_pearson_correlation: 0.7589 - val_pearson_correlation_log: 0.6259 - val_zero_penalty_metric: 314.7453 - learning_rate: 5.0000e-05
Epoch 5/5
4198/4198 ━━━━━━━━━━━━━━━━━━━━ 136s 32ms/step - concordance_correlation_coefficient: 0.8976 - cosine_similarity: 0.9504 - loss: -0.8389 - mean_absolute_error: 1.0856 - mean_squared_error: 5.4635 - pearson_correlation: 0.9137 - pearson_correlation_log: 0.7117 - zero_penalty_metric: 304.2629 - val_concordance_correlation_coefficient: 0.7264 - val_cosine_similarity: 0.8481 - val_loss: -0.5807 - val_mean_absolute_error: 1.7185 - val_mean_squared_error: 13.2077 - val_pearson_correlation: 0.7581 - val_pearson_correlation_log: 0.6187 - val_zero_penalty_metric: 309.3376 - learning_rate: 5.0000e-05

Run history:


batch/batch_step▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
batch/concordance_correlation_coefficient▁▁▂▂▂▂▂▂▄▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇████████
batch/cosine_similarity▁▁▁▁▁▁▁▁▄▄▄▄▄▄▄▄▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇████████
batch/loss████████▅▅▅▅▅▅▅▅▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁
batch/mean_absolute_error████████▆▅▅▅▅▅▅▅▄▃▄▄▄▄▄▄▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁
batch/mean_squared_error███▇▇▇▇▇▅▅▅▅▅▅▅▅▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁
batch/pearson_correlation▁▁▂▂▂▂▂▂▄▄▄▄▄▄▄▄▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇████████
batch/pearson_correlation_log▁▁▁▁▁▁▁▁▃▃▃▃▃▃▃▃▅▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆████████
batch/zero_penalty_metric█▇▇▇████▇█▇▇▇▇▇▇▄▅▅▅▅▅▅▆▄▄▄▄▄▄▄▄▁▂▃▂▃▃▂▂
epoch/concordance_correlation_coefficient▁▁▄▄▆▆▇▇██
epoch/cosine_similarity▁▁▄▄▆▆▇▇██
epoch/epoch▁▁▃▃▅▅▆▆██
epoch/learning_rate▁▁▁▁▁▁▁▁▁▁
epoch/loss██▅▅▄▄▂▂▁▁
epoch/mean_absolute_error██▆▆▄▄▂▂▁▁
epoch/mean_squared_error██▅▅▃▃▂▂▁▁
epoch/pearson_correlation▁▁▄▄▆▆▇▇██
epoch/pearson_correlation_log▁▁▃▃▅▅▆▆██
epoch/val_concordance_correlation_coefficient▁▁▆▆▅▅██▇▇
epoch/val_cosine_similarity▄▄██▃▃▅▅▁▁
epoch/val_loss▇▇▁▁██▄▄▇▇
epoch/val_mean_absolute_error▆▆▁▁██▄▄▃▃
epoch/val_mean_squared_error██▂▂▄▄▁▁▂▂
epoch/val_pearson_correlation▁▁▆▆▄▄████
epoch/val_pearson_correlation_log▁▁▂▂▁▁██▆▆
epoch/val_zero_penalty_metric▆▆▁▁▃▃██▃▃
epoch/zero_penalty_metric██▇▇▅▅▃▃▁▁

Run summary:


batch/batch_step20990
batch/concordance_correlation_coefficient0.89607
batch/cosine_similarity0.9493
batch/loss-0.83633
batch/mean_absolute_error1.08951
batch/mean_squared_error5.49214
batch/pearson_correlation0.91258
batch/pearson_correlation_log0.71105
batch/zero_penalty_metric304.45786
epoch/concordance_correlation_coefficient0.89606
epoch/cosine_similarity0.9493
epoch/epoch4
epoch/learning_rate5e-05
epoch/loss-0.83633
epoch/mean_absolute_error1.08949
epoch/mean_squared_error5.49124
epoch/pearson_correlation0.91258
epoch/pearson_correlation_log0.71104
epoch/val_concordance_correlation_coefficient0.72639
epoch/val_cosine_similarity0.84814
epoch/val_loss-0.58074
epoch/val_mean_absolute_error1.71851
epoch/val_mean_squared_error13.20767
epoch/val_pearson_correlation0.7581
epoch/val_pearson_correlation_log0.61873
epoch/val_zero_penalty_metric309.33755
epoch/zero_penalty_metric304.3938

View run testrun_ft at: https://wandb.ai/cas-blaauw/biccn_borzoi_atac/runs/sdx53oa8
View project at: https://wandb.ai/cas-blaauw/biccn_borzoi_atac
Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 1 other file(s)
Find logs at: ./wandb/run-20250214_131001-sdx53oa8/logs
The new W&B backend becomes opt-out in version 0.18.0; try it out with `wandb.require("core")`! See https://wandb.me/wandb-core for more information.

Evaluate model#

We’ll evaluate both the finetuned and further finetuned models.

model = keras.models.load_model("biccn_borzoi_atac/testrun/checkpoints/03.keras")
model_ft = keras.models.load_model("biccn_borzoi_atac/testrun_ft/checkpoints/02.keras")
# add predictions for model checkpoint to the adata
adata.layers["model"] = crested.tl.predict(adata, model_ft).T
adata.layers["model_ft"] = crested.tl.predict(adata, model).T
91002/91002 ━━━━━━━━━━━━━━━━━━━━ 318s 3ms/step
91002/91002 ━━━━━━━━━━━━━━━━━━━━ 318s 3ms/step

If you don’t want to predict on the entire dataset, you can also predict on a given sequence or region by passing coordinates, a DNA sequence, or a one-hot encoded sequence to predict().

Many of the plotting functions in the crested.pl module can be used to visualize these model predictions.

# Define a dataframe with test set regions
test_df = adata.var[adata.var["split"] == "test"]
test_df
chr start end split
region
chr1:3094031-3096079 chr1 3094031 3096079 test
chr1:3094696-3096744 chr1 3094696 3096744 test
chr1:3112760-3114808 chr1 3112760 3114808 test
chr1:3133812-3135860 chr1 3133812 3135860 test
chr1:3164934-3166982 chr1 3164934 3166982 test
... ... ... ... ...
chrX:20751180-20753228 chrX 20751180 20753228 test
chrX:21388522-21390570 chrX 21388522 21390570 test
chrX:21392726-21394774 chrX 21392726 21394774 test
chrX:21427413-21429461 chrX 21427413 21429461 test
chrX:21433814-21435862 chrX 21433814 21435862 test

10881 rows × 4 columns

# plot predictions vs ground truth for a random region in the test set defined by index
%matplotlib inline
idx = 22
region = test_df.index[idx]
print(region)
crested.pl.bar.region_predictions(
    adata, region, title=f"Predictions vs ground truth ({region})"
)
chr1:3899526-3901574
2025-02-14T13:37:41.075121+0100 INFO Plotting bar plots for region: chr1:3899526-3901574, models: ['model', 'model_ft']
../_images/ad26b30bf76125284b56b25620d86f3a3521e73b09f63c17c789f0c2798d992f.png
crested.pl.heatmap.correlations_self(
    adata,
    title="Self-correlation heatmap",
    x_label_rotation=90,
    width=10,
    height=10,
    vmin=-1,
    vmax=1,
)
../_images/279c0e708fe8146ce2cdda56cbe10dda2733d7c52a4fa15591cff7d65bf65b9f.png
crested.pl.heatmap.correlations_predictions(
    adata,
    split="test",
    title="Correlations between ground truths and predictions",
    x_label_rotation=90,
    width=20,
    height=10,
    log_transform=False,
    # vmin = -1,
    # vmax = 1,
)
2025-02-14T13:38:10.984708+0100 INFO Plotting heatmap correlations for split: test, models: ['model', 'model_ft']
../_images/57a07f6e59dd2bcdb13ddffc74f565686cac902dc76fb4c527c14fec0264c50b.png
crested.pl.scatter.class_density(
    adata,
    split="test",
    log_transform=True,
    width=10,
    height=5,
)
2025-02-14T13:39:20.714235+0100 INFO Plotting density scatter for all targets and predictions, models: ['model', 'model_ft'], split: test
../_images/39b45cd5e18d37059a0a68916b4b97c1a11a16d5e17ccb0bfbd872adb636c926.png

Besides looking at prediction scores, we can also use these models to explain the features in the sequence that contributed to predicted accessibility in a certain cell type.
Here, we’ll look at three regions, expected to be active in microglia (Micro_PVM), Sst/Chodl GABAergic neurons (SstChodl), or in layer 6b glutamatergic neurons (L6b) respectively.

regions_of_interest = [
    "chr18:61107803-61109851",
    "chr13:92952218-92954266",
    "chr9:56036511-56038559",
]
classes_of_interest = ["Micro_PVM", "SstChodl", "L6b"]
class_idx = list(adata.obs_names.get_indexer(classes_of_interest))

scores, one_hot_encoded_sequences = crested.tl.contribution_scores(
    regions_of_interest,
    target_idx=class_idx,
    model=model_ft,
)
2025-02-14T13:57:00.326922+0100 INFO Calculating contribution scores for 3 class(es) and 3 region(s).
# Plot attribution scores
crested.pl.patterns.contribution_scores(
    scores,
    one_hot_encoded_sequences,
    sequence_labels=regions_of_interest,
    class_labels=classes_of_interest,
    zoom_n_bases=500,
)
../_images/ade8f9cde3bd1e8de07d4750492da0dce543b5856a358ee45105715fa0bea0f1.png