Finetune Borzoi for scATAC peaks#
In this tutorial, we’ll show how to finetune the Borzoi model to do peak regression on scATAC data.
import os
import zipfile
import tempfile
import pandas as pd
import keras
import crested
resources_dir = "/staging/leuven/res_00001/genomes/mus_musculus/mm10_ucsc/fasta/"
genome_file = os.path.join(resources_dir, "mm10.fa")
chromsizes_file = os.path.join(resources_dir, "mm10.chrom.sizes")
folds_file = "consensus_peaks_borsplit.bed"
genome = crested.Genome(genome_file, chromsizes_file)
crested.register_genome(genome)
2025-02-14T13:07:13.655587+0100 INFO Genome mm10 registered.
Read in scATAC data#
We’ll use the same dataset as used in the default tutorial, the mouse BICCN dataset, derived from the brain cortex.
bigwigs_folder, regions_file = crested.get_dataset("mouse_cortex_bigwig_cut_sites")
adata = crested.import_bigwigs(
bigwigs_folder=bigwigs_folder,
regions_file=regions_file,
target_region_width=1000,
target="count",
)
adata
2025-02-14T13:08:31.442981+0100 INFO Extracting values from 19 bigWig files...
AnnData object with n_obs × n_vars = 19 × 546993
obs: 'file_path'
var: 'chr', 'start', 'end'
Add train/val/test split#
Generally, for finetuning, it’s recommended to use the train/test split from the original model, like Borzoi here.
This can be derived by intersecting your consensus peaks with sequences_mouse.bed
from the Borzoi repository, like with BEDTools:
regions_file="consensus_peaks_biccn.bed" # regions_file from crested.get_dataset()
folds_file="sequences_mouse.bed" # From Borzoi repo
output_file="consensus_peaks_borsplit.bed"
grep fold3 ${folds_file} | sort -k1,1 -k2,2n | bedtools merge -i stdin -d 10 | bedtools intersect -a ${regions_file} -b stdin -wa -f 0.5 | sed $'s/$/\t'test/ > ${output_file}
grep fold4 ${folds_file} | sort -k1,1 -k2,2n | bedtools merge -i stdin -d 10 | bedtools intersect -a ${regions_file} -b stdin -wa -f 0.5 | sed $'s/$/\t'val/ >> ${output_file}
for i in 0 1 2 5 6 7; do
grep fold${i} ${folds_file} | sort -k1,1 -k2,2n | bedtools merge -i stdin -d 10 | bedtools intersect -a ${regions_file} -b stdin -wa -f 0.5 | sed $'s/$/\t'train/ >> ${output_file}
done
folds = pd.read_csv(
folds_file, sep="\t", names=["name", "split"], usecols=[3, 4]
).set_index("name")
print(
f"% of regions found in folds file: {adata.var_names.isin(folds.index).sum()/adata.n_vars*100:.3f}%"
)
% of regions found in folds file: 99.425%
# Drop regions not in any folds
print(
f"Dropping {(~adata.var_names.isin(folds.index)).sum()} regions because they are not in any fold."
)
adata = adata[:, adata.var_names.isin(folds.index)].copy()
# Add fold data to var
adata.var = adata.var.join(folds)
# Check result
adata.var["split"].value_counts(dropna=False)
Dropping 3146 regions because they are not in any fold.
split
train 412229
val 72744
test 58874
Name: count, dtype: int64
Alternatively, you could use the default train/test split function set a chromosome-based or random split:
# crested.pp.train_val_test_split(
# adata, strategy="chr", val_chroms=["chr8", "chr10"], test_chroms=["chr9", "chr18"]
# )
Preprocessing#
For the preprocessing, we’ll again follow the default steps, except for the adjusted input size.
Region width#
In this example, we’ll use 2048bp inputs, to align with the 2114bp input size of the standard CNN peak regression models while staying within a multiple of 128. Therefore, we’ll need to resize our regions:
crested.pp.change_regions_width(
adata,
2048,
)
Peak normalization#
We can normalize our peak values based on the variability of the top peak heights per cell type using the crested.pp.normalize_peaks()
function.
This function applies a normalization scalar to each cell type, obtained by comparing per cell type the distribution of peak heights for the maximally accessible regions which are not specific to any cell type.
crested.pp.normalize_peaks(
adata, top_k_percent=0.03
) # The top_k_percent parameters can be tuned based on potential bias towards cell types. If some weights are overcompensating too much, consider increasing the top_k_percent. Default is 0.01
2025-02-14T13:09:21.187809+0100 INFO Filtering on top k Gini scores...
2025-02-14T13:09:28.069781+0100 INFO Added normalization weights to adata.obsm['weights']...
chr | start | end | split | |
---|---|---|---|---|
region | ||||
chr9:76566142-76568190 | chr9 | 76566142 | 76568190 | train |
chr5:98328510-98330558 | chr5 | 98328510 | 98330558 | train |
chr5:98347819-98349867 | chr5 | 98347819 | 98349867 | train |
chr13:34635167-34637215 | chr13 | 34635167 | 34637215 | train |
chr13:34642109-34644157 | chr13 | 34642109 | 34644157 | train |
... | ... | ... | ... | ... |
chr13:34344270-34346318 | chr13 | 34344270 | 34346318 | train |
chr5:98166140-98168188 | chr5 | 98166140 | 98168188 | train |
chr5:98166667-98168715 | chr5 | 98166667 | 98168715 | train |
chr13:34344974-34347022 | chr13 | 34344974 | 34347022 | train |
chr5:98185712-98187760 | chr5 | 98185712 | 98187760 | train |
48075 rows × 4 columns
We can visualize the normalization factor for each cell type using the crested.pl.bar.normalization_weights()
function to inspect which cell type peaks were up/down weighted.
%matplotlib inline
crested.pl.bar.normalization_weights(adata, title="Normalization weights per cell type")

Load in model#
We load in the Borzoi model’s weights in its architecture, with one change - the input length. All of Borzoi’s layers are width-independent, so the length can be set to any value divisible by the internal bin size (128bp).
target_length
is set to the total number of output bins (64 bins of 32bp makes 2048bp output), since no cropping is needed when predicting local features.
num_classes
is set to the original size simply so that there are no weight shape mismatches when loading the initial weights; the head created based on num_classes
will be replaced by a new head for the number of cell types we’d like to predict below.
# Create default Borzoi architecture, with shrunk input size and target_length
base_model_architecture = crested.tl.zoo.borzoi(
seq_len=2048, target_length=2048 // 32, num_classes=2608
)
To load in the weights, we can’t directly load the model from the .keras
file, since that fixes the input length at the previously set value (524288bp). However, we can extract the model.weights.h5
file containing only the weights and use that.
# Load pretrained Borzoi weights
model_file, _ = crested.get_model("Borzoi_mouse_rep0")
# Put weights into base architecture
with zipfile.ZipFile(
model_file
) as model_archive, tempfile.TemporaryDirectory() as tmpdir:
model_weights_path = model_archive.extract("model.weights.h5", tmpdir)
base_model_architecture.load_weights(model_weights_path)
Now that we have the base model with the adjusted input shape, we need to adjust the final layers to return a value for each cell type per region, instead of per-bin values. Therefore, we drop the final head, add a flatten layer after the model’s final embedding, and add a new head predicting adata.n_obs
values.
# Replace track head by flatten+dense to predict single vector of scalars per region
## Get last layer before head
current = base_model_architecture.get_layer("final_conv_activation").output
## Flatten and add new layer
current = keras.layers.Flatten()(current)
current = keras.layers.Dense(adata.n_obs, activation="softplus", name="dense_out")(
current
)
# Turn into model
model_architecture = keras.Model(
inputs=base_model_architecture.inputs, outputs=current, name="Borzoi_scalar"
)
print(model_architecture.summary())
Model: "Borzoi_scalar"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩ │ input (InputLayer) │ (None, 2048, 4) │ 0 │ - │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ stem_conv (Conv1D) │ (None, 2048, 512) │ 31,232 │ input[0][0] │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ stem_pool │ (None, 1024, 512) │ 0 │ stem_conv[0][0] │ │ (MaxPooling1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_1_batch… │ (None, 1024, 512) │ 2,048 │ stem_pool[0][0] │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_1_activ… │ (None, 1024, 512) │ 0 │ tower_conv_1_bat… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_1_conv │ (None, 1024, 608) │ 1,557,088 │ tower_conv_1_act… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_1_pool │ (None, 512, 608) │ 0 │ tower_conv_1_con… │ │ (MaxPooling1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_2_batch… │ (None, 512, 608) │ 2,432 │ tower_conv_1_poo… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_2_activ… │ (None, 512, 608) │ 0 │ tower_conv_2_bat… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_2_conv │ (None, 512, 736) │ 2,238,176 │ tower_conv_2_act… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_2_pool │ (None, 256, 736) │ 0 │ tower_conv_2_con… │ │ (MaxPooling1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_3_batch… │ (None, 256, 736) │ 2,944 │ tower_conv_2_poo… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_3_activ… │ (None, 256, 736) │ 0 │ tower_conv_3_bat… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_3_conv │ (None, 256, 896) │ 3,298,176 │ tower_conv_3_act… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_3_pool │ (None, 128, 896) │ 0 │ tower_conv_3_con… │ │ (MaxPooling1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_4_batch… │ (None, 128, 896) │ 3,584 │ tower_conv_3_poo… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_4_activ… │ (None, 128, 896) │ 0 │ tower_conv_4_bat… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_4_conv │ (None, 128, 1056) │ 4,731,936 │ tower_conv_4_act… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_4_pool │ (None, 64, 1056) │ 0 │ tower_conv_4_con… │ │ (MaxPooling1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_5_batch… │ (None, 64, 1056) │ 4,224 │ tower_conv_4_poo… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_5_activ… │ (None, 64, 1056) │ 0 │ tower_conv_5_bat… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_5_conv │ (None, 64, 1280) │ 6,759,680 │ tower_conv_5_act… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_5_pool │ (None, 32, 1280) │ 0 │ tower_conv_5_con… │ │ (MaxPooling1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_6_batch… │ (None, 32, 1280) │ 5,120 │ tower_conv_5_poo… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_6_activ… │ (None, 32, 1280) │ 0 │ tower_conv_6_bat… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_6_conv │ (None, 32, 1536) │ 9,831,936 │ tower_conv_6_act… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_6_pool │ (None, 16, 1536) │ 0 │ tower_conv_6_con… │ │ (MaxPooling1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_1_… │ (None, 16, 1536) │ 3,072 │ tower_conv_6_poo… │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_1_… │ (None, 16, 1536) │ 6,310,400 │ transformer_mha_… │ │ (MultiheadAttentio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_1_… │ (None, 16, 1536) │ 0 │ transformer_mha_… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add (Add) │ (None, 16, 1536) │ 0 │ tower_conv_6_poo… │ │ │ │ │ transformer_mha_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_1_l… │ (None, 16, 1536) │ 3,072 │ add[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_1_p… │ (None, 16, 3072) │ 4,721,664 │ transformer_ff_1… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_1_d… │ (None, 16, 3072) │ 0 │ transformer_ff_1… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_1_a… │ (None, 16, 3072) │ 0 │ transformer_ff_1… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_1_p… │ (None, 16, 1536) │ 4,720,128 │ transformer_ff_1… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_1_d… │ (None, 16, 1536) │ 0 │ transformer_ff_1… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_1 (Add) │ (None, 16, 1536) │ 0 │ add[0][0], │ │ │ │ │ transformer_ff_1… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_2_… │ (None, 16, 1536) │ 3,072 │ add_1[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_2_… │ (None, 16, 1536) │ 6,310,400 │ transformer_mha_… │ │ (MultiheadAttentio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_2_… │ (None, 16, 1536) │ 0 │ transformer_mha_… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_2 (Add) │ (None, 16, 1536) │ 0 │ add_1[0][0], │ │ │ │ │ transformer_mha_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_2_l… │ (None, 16, 1536) │ 3,072 │ add_2[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_2_p… │ (None, 16, 3072) │ 4,721,664 │ transformer_ff_2… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_2_d… │ (None, 16, 3072) │ 0 │ transformer_ff_2… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_2_a… │ (None, 16, 3072) │ 0 │ transformer_ff_2… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_2_p… │ (None, 16, 1536) │ 4,720,128 │ transformer_ff_2… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_2_d… │ (None, 16, 1536) │ 0 │ transformer_ff_2… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_3 (Add) │ (None, 16, 1536) │ 0 │ add_2[0][0], │ │ │ │ │ transformer_ff_2… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_3_… │ (None, 16, 1536) │ 3,072 │ add_3[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_3_… │ (None, 16, 1536) │ 6,310,400 │ transformer_mha_… │ │ (MultiheadAttentio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_3_… │ (None, 16, 1536) │ 0 │ transformer_mha_… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_4 (Add) │ (None, 16, 1536) │ 0 │ add_3[0][0], │ │ │ │ │ transformer_mha_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_3_l… │ (None, 16, 1536) │ 3,072 │ add_4[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_3_p… │ (None, 16, 3072) │ 4,721,664 │ transformer_ff_3… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_3_d… │ (None, 16, 3072) │ 0 │ transformer_ff_3… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_3_a… │ (None, 16, 3072) │ 0 │ transformer_ff_3… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_3_p… │ (None, 16, 1536) │ 4,720,128 │ transformer_ff_3… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_3_d… │ (None, 16, 1536) │ 0 │ transformer_ff_3… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_5 (Add) │ (None, 16, 1536) │ 0 │ add_4[0][0], │ │ │ │ │ transformer_ff_3… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_4_… │ (None, 16, 1536) │ 3,072 │ add_5[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_4_… │ (None, 16, 1536) │ 6,310,400 │ transformer_mha_… │ │ (MultiheadAttentio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_4_… │ (None, 16, 1536) │ 0 │ transformer_mha_… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_6 (Add) │ (None, 16, 1536) │ 0 │ add_5[0][0], │ │ │ │ │ transformer_mha_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_4_l… │ (None, 16, 1536) │ 3,072 │ add_6[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_4_p… │ (None, 16, 3072) │ 4,721,664 │ transformer_ff_4… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_4_d… │ (None, 16, 3072) │ 0 │ transformer_ff_4… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_4_a… │ (None, 16, 3072) │ 0 │ transformer_ff_4… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_4_p… │ (None, 16, 1536) │ 4,720,128 │ transformer_ff_4… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_4_d… │ (None, 16, 1536) │ 0 │ transformer_ff_4… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_7 (Add) │ (None, 16, 1536) │ 0 │ add_6[0][0], │ │ │ │ │ transformer_ff_4… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_5_… │ (None, 16, 1536) │ 3,072 │ add_7[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_5_… │ (None, 16, 1536) │ 6,310,400 │ transformer_mha_… │ │ (MultiheadAttentio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_5_… │ (None, 16, 1536) │ 0 │ transformer_mha_… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_8 (Add) │ (None, 16, 1536) │ 0 │ add_7[0][0], │ │ │ │ │ transformer_mha_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_5_l… │ (None, 16, 1536) │ 3,072 │ add_8[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_5_p… │ (None, 16, 3072) │ 4,721,664 │ transformer_ff_5… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_5_d… │ (None, 16, 3072) │ 0 │ transformer_ff_5… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_5_a… │ (None, 16, 3072) │ 0 │ transformer_ff_5… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_5_p… │ (None, 16, 1536) │ 4,720,128 │ transformer_ff_5… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_5_d… │ (None, 16, 1536) │ 0 │ transformer_ff_5… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_9 (Add) │ (None, 16, 1536) │ 0 │ add_8[0][0], │ │ │ │ │ transformer_ff_5… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_6_… │ (None, 16, 1536) │ 3,072 │ add_9[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_6_… │ (None, 16, 1536) │ 6,310,400 │ transformer_mha_… │ │ (MultiheadAttentio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_6_… │ (None, 16, 1536) │ 0 │ transformer_mha_… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_10 (Add) │ (None, 16, 1536) │ 0 │ add_9[0][0], │ │ │ │ │ transformer_mha_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_6_l… │ (None, 16, 1536) │ 3,072 │ add_10[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_6_p… │ (None, 16, 3072) │ 4,721,664 │ transformer_ff_6… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_6_d… │ (None, 16, 3072) │ 0 │ transformer_ff_6… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_6_a… │ (None, 16, 3072) │ 0 │ transformer_ff_6… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_6_p… │ (None, 16, 1536) │ 4,720,128 │ transformer_ff_6… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_6_d… │ (None, 16, 1536) │ 0 │ transformer_ff_6… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_11 (Add) │ (None, 16, 1536) │ 0 │ add_10[0][0], │ │ │ │ │ transformer_ff_6… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_7_… │ (None, 16, 1536) │ 3,072 │ add_11[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_7_… │ (None, 16, 1536) │ 6,310,400 │ transformer_mha_… │ │ (MultiheadAttentio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_7_… │ (None, 16, 1536) │ 0 │ transformer_mha_… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_12 (Add) │ (None, 16, 1536) │ 0 │ add_11[0][0], │ │ │ │ │ transformer_mha_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_7_l… │ (None, 16, 1536) │ 3,072 │ add_12[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_7_p… │ (None, 16, 3072) │ 4,721,664 │ transformer_ff_7… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_7_d… │ (None, 16, 3072) │ 0 │ transformer_ff_7… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_7_a… │ (None, 16, 3072) │ 0 │ transformer_ff_7… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_7_p… │ (None, 16, 1536) │ 4,720,128 │ transformer_ff_7… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_7_d… │ (None, 16, 1536) │ 0 │ transformer_ff_7… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_13 (Add) │ (None, 16, 1536) │ 0 │ add_12[0][0], │ │ │ │ │ transformer_ff_7… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_8_… │ (None, 16, 1536) │ 3,072 │ add_13[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_8_… │ (None, 16, 1536) │ 6,310,400 │ transformer_mha_… │ │ (MultiheadAttentio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_8_… │ (None, 16, 1536) │ 0 │ transformer_mha_… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_14 (Add) │ (None, 16, 1536) │ 0 │ add_13[0][0], │ │ │ │ │ transformer_mha_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_8_l… │ (None, 16, 1536) │ 3,072 │ add_14[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_8_p… │ (None, 16, 3072) │ 4,721,664 │ transformer_ff_8… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_8_d… │ (None, 16, 3072) │ 0 │ transformer_ff_8… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_8_a… │ (None, 16, 3072) │ 0 │ transformer_ff_8… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_8_p… │ (None, 16, 1536) │ 4,720,128 │ transformer_ff_8… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_8_d… │ (None, 16, 1536) │ 0 │ transformer_ff_8… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_15 (Add) │ (None, 16, 1536) │ 0 │ add_14[0][0], │ │ │ │ │ transformer_ff_8… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ upsampling_conv_1_… │ (None, 16, 1536) │ 6,144 │ add_15[0][0] │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ upsampling_conv_1_… │ (None, 16, 1536) │ 0 │ upsampling_conv_… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ unet_skip_2_batchn… │ (None, 32, 1536) │ 6,144 │ tower_conv_6_con… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ upsampling_conv_1_… │ (None, 16, 1536) │ 2,360,832 │ upsampling_conv_… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ unet_skip_2_activa… │ (None, 32, 1536) │ 0 │ unet_skip_2_batc… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ up_sampling1d │ (None, 32, 1536) │ 0 │ upsampling_conv_… │ │ (UpSampling1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ unet_skip_2_conv │ (None, 32, 1536) │ 2,360,832 │ unet_skip_2_acti… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_16 (Add) │ (None, 32, 1536) │ 0 │ up_sampling1d[0]… │ │ │ │ │ unet_skip_2_conv… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ upsampling_separab… │ (None, 32, 1536) │ 2,365,440 │ add_16[0][0] │ │ (SeparableConv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ upsampling_conv_2_… │ (None, 32, 1536) │ 6,144 │ upsampling_separ… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ upsampling_conv_2_… │ (None, 32, 1536) │ 0 │ upsampling_conv_… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ unet_skip_1_batchn… │ (None, 64, 1280) │ 5,120 │ tower_conv_5_con… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ upsampling_conv_2_… │ (None, 32, 1536) │ 2,360,832 │ upsampling_conv_… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ unet_skip_1_activa… │ (None, 64, 1280) │ 0 │ unet_skip_1_batc… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ up_sampling1d_1 │ (None, 64, 1536) │ 0 │ upsampling_conv_… │ │ (UpSampling1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ unet_skip_1_conv │ (None, 64, 1536) │ 1,967,616 │ unet_skip_1_acti… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_17 (Add) │ (None, 64, 1536) │ 0 │ up_sampling1d_1[… │ │ │ │ │ unet_skip_1_conv… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ upsampling_separab… │ (None, 64, 1536) │ 2,365,440 │ add_17[0][0] │ │ (SeparableConv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ final_conv_batchno… │ (None, 64, 1536) │ 6,144 │ upsampling_separ… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ final_conv_activat… │ (None, 64, 1536) │ 0 │ final_conv_batch… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ flatten (Flatten) │ (None, 98304) │ 0 │ final_conv_activ… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ dense_out (Dense) │ (None, 19) │ 1,867,795 │ flatten[0][0] │ └─────────────────────┴───────────────────┴────────────┴───────────────────┘
Total params: 170,213,747 (649.31 MB)
Trainable params: 170,188,723 (649.22 MB)
Non-trainable params: 25,024 (97.75 KB)
None
Model Training#
Parameters#
The DataModule and TaskConfig let you set standard training parameters, like batch size and learning rate.
We use the same parameters as with peak regression in the default tutorial, except for a lower learning rate to match the fact that we are starting from a pre-trained model.
datamodule = crested.tl.data.AnnDataModule(
adata,
genome=genome,
batch_size=32, # lower this if you encounter OOM errors
max_stochastic_shift=3, # optional augmentation
always_reverse_complement=True, # default True. Will double the effective size of the training dataset.
)
optimizer = keras.optimizers.Adam(learning_rate=1e-5)
loss = crested.tl.losses.CosineMSELogLoss(max_weight=100, multiplier=1)
metrics = [
keras.metrics.MeanAbsoluteError(),
keras.metrics.MeanSquaredError(),
keras.metrics.CosineSimilarity(axis=1),
crested.tl.metrics.PearsonCorrelation(),
crested.tl.metrics.ConcordanceCorrelationCoefficient(),
crested.tl.metrics.PearsonCorrelationLog(),
crested.tl.metrics.ZeroPenaltyMetric(),
]
config = crested.tl.TaskConfig(optimizer, loss, metrics)
print(config)
TaskConfig(optimizer=<keras.src.optimizers.adam.Adam object at 0x1511019ba0d0>, loss=<crested.tl.losses._cosinemse_log.CosineMSELogLoss object at 0x1511019bf990>, metrics=[<MeanAbsoluteError name=mean_absolute_error>, <MeanSquaredError name=mean_squared_error>, <CosineSimilarity name=cosine_similarity>, <PearsonCorrelation name=pearson_correlation>, <ConcordanceCorrelationCoefficient name=concordance_correlation_coefficient>, <PearsonCorrelationLog name=pearson_correlation_log>, <ZeroPenaltyMetric name=zero_penalty_metric>])
Finetune on full peak set#
By default:
The model will continue training until the validation loss stops decreasing for 10 epochs with a maximum of 100 epochs.
Every best model is saved based on the validation loss.
The learning rate reduces by a factor of 0.25 if the validation loss stops decreasing for 5 epochs.
# setup the trainer
trainer = crested.tl.Crested(
data=datamodule,
model=model_architecture,
config=config,
project_name="biccn_borzoi_atac",
run_name="testrun",
logger="wandb",
)
# train the model
trainer.fit(epochs=10)
/lustre1/project/stg_00002/lcb/cblaauw/rna_models/borzoi_atac/wandb/run-20250213_161634-veap7cez
Model: "Borzoi_scalar"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩ │ input (InputLayer) │ (None, 2048, 4) │ 0 │ - │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ stem_conv (Conv1D) │ (None, 2048, 512) │ 31,232 │ input[0][0] │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ stem_pool │ (None, 1024, 512) │ 0 │ stem_conv[0][0] │ │ (MaxPooling1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_1_batch… │ (None, 1024, 512) │ 2,048 │ stem_pool[0][0] │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_1_activ… │ (None, 1024, 512) │ 0 │ tower_conv_1_bat… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_1_conv │ (None, 1024, 608) │ 1,557,088 │ tower_conv_1_act… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_1_pool │ (None, 512, 608) │ 0 │ tower_conv_1_con… │ │ (MaxPooling1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_2_batch… │ (None, 512, 608) │ 2,432 │ tower_conv_1_poo… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_2_activ… │ (None, 512, 608) │ 0 │ tower_conv_2_bat… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_2_conv │ (None, 512, 736) │ 2,238,176 │ tower_conv_2_act… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_2_pool │ (None, 256, 736) │ 0 │ tower_conv_2_con… │ │ (MaxPooling1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_3_batch… │ (None, 256, 736) │ 2,944 │ tower_conv_2_poo… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_3_activ… │ (None, 256, 736) │ 0 │ tower_conv_3_bat… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_3_conv │ (None, 256, 896) │ 3,298,176 │ tower_conv_3_act… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_3_pool │ (None, 128, 896) │ 0 │ tower_conv_3_con… │ │ (MaxPooling1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_4_batch… │ (None, 128, 896) │ 3,584 │ tower_conv_3_poo… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_4_activ… │ (None, 128, 896) │ 0 │ tower_conv_4_bat… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_4_conv │ (None, 128, 1056) │ 4,731,936 │ tower_conv_4_act… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_4_pool │ (None, 64, 1056) │ 0 │ tower_conv_4_con… │ │ (MaxPooling1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_5_batch… │ (None, 64, 1056) │ 4,224 │ tower_conv_4_poo… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_5_activ… │ (None, 64, 1056) │ 0 │ tower_conv_5_bat… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_5_conv │ (None, 64, 1280) │ 6,759,680 │ tower_conv_5_act… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_5_pool │ (None, 32, 1280) │ 0 │ tower_conv_5_con… │ │ (MaxPooling1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_6_batch… │ (None, 32, 1280) │ 5,120 │ tower_conv_5_poo… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_6_activ… │ (None, 32, 1280) │ 0 │ tower_conv_6_bat… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_6_conv │ (None, 32, 1536) │ 9,831,936 │ tower_conv_6_act… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_6_pool │ (None, 16, 1536) │ 0 │ tower_conv_6_con… │ │ (MaxPooling1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_1_… │ (None, 16, 1536) │ 3,072 │ tower_conv_6_poo… │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_1_… │ (None, 16, 1536) │ 6,310,400 │ transformer_mha_… │ │ (MultiheadAttentio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_1_… │ (None, 16, 1536) │ 0 │ transformer_mha_… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_18 (Add) │ (None, 16, 1536) │ 0 │ tower_conv_6_poo… │ │ │ │ │ transformer_mha_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_1_l… │ (None, 16, 1536) │ 3,072 │ add_18[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_1_p… │ (None, 16, 3072) │ 4,721,664 │ transformer_ff_1… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_1_d… │ (None, 16, 3072) │ 0 │ transformer_ff_1… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_1_a… │ (None, 16, 3072) │ 0 │ transformer_ff_1… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_1_p… │ (None, 16, 1536) │ 4,720,128 │ transformer_ff_1… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_1_d… │ (None, 16, 1536) │ 0 │ transformer_ff_1… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_19 (Add) │ (None, 16, 1536) │ 0 │ add_18[0][0], │ │ │ │ │ transformer_ff_1… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_2_… │ (None, 16, 1536) │ 3,072 │ add_19[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_2_… │ (None, 16, 1536) │ 6,310,400 │ transformer_mha_… │ │ (MultiheadAttentio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_2_… │ (None, 16, 1536) │ 0 │ transformer_mha_… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_20 (Add) │ (None, 16, 1536) │ 0 │ add_19[0][0], │ │ │ │ │ transformer_mha_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_2_l… │ (None, 16, 1536) │ 3,072 │ add_20[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_2_p… │ (None, 16, 3072) │ 4,721,664 │ transformer_ff_2… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_2_d… │ (None, 16, 3072) │ 0 │ transformer_ff_2… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_2_a… │ (None, 16, 3072) │ 0 │ transformer_ff_2… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_2_p… │ (None, 16, 1536) │ 4,720,128 │ transformer_ff_2… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_2_d… │ (None, 16, 1536) │ 0 │ transformer_ff_2… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_21 (Add) │ (None, 16, 1536) │ 0 │ add_20[0][0], │ │ │ │ │ transformer_ff_2… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_3_… │ (None, 16, 1536) │ 3,072 │ add_21[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_3_… │ (None, 16, 1536) │ 6,310,400 │ transformer_mha_… │ │ (MultiheadAttentio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_3_… │ (None, 16, 1536) │ 0 │ transformer_mha_… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_22 (Add) │ (None, 16, 1536) │ 0 │ add_21[0][0], │ │ │ │ │ transformer_mha_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_3_l… │ (None, 16, 1536) │ 3,072 │ add_22[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_3_p… │ (None, 16, 3072) │ 4,721,664 │ transformer_ff_3… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_3_d… │ (None, 16, 3072) │ 0 │ transformer_ff_3… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_3_a… │ (None, 16, 3072) │ 0 │ transformer_ff_3… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_3_p… │ (None, 16, 1536) │ 4,720,128 │ transformer_ff_3… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_3_d… │ (None, 16, 1536) │ 0 │ transformer_ff_3… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_23 (Add) │ (None, 16, 1536) │ 0 │ add_22[0][0], │ │ │ │ │ transformer_ff_3… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_4_… │ (None, 16, 1536) │ 3,072 │ add_23[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_4_… │ (None, 16, 1536) │ 6,310,400 │ transformer_mha_… │ │ (MultiheadAttentio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_4_… │ (None, 16, 1536) │ 0 │ transformer_mha_… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_24 (Add) │ (None, 16, 1536) │ 0 │ add_23[0][0], │ │ │ │ │ transformer_mha_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_4_l… │ (None, 16, 1536) │ 3,072 │ add_24[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_4_p… │ (None, 16, 3072) │ 4,721,664 │ transformer_ff_4… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_4_d… │ (None, 16, 3072) │ 0 │ transformer_ff_4… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_4_a… │ (None, 16, 3072) │ 0 │ transformer_ff_4… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_4_p… │ (None, 16, 1536) │ 4,720,128 │ transformer_ff_4… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_4_d… │ (None, 16, 1536) │ 0 │ transformer_ff_4… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_25 (Add) │ (None, 16, 1536) │ 0 │ add_24[0][0], │ │ │ │ │ transformer_ff_4… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_5_… │ (None, 16, 1536) │ 3,072 │ add_25[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_5_… │ (None, 16, 1536) │ 6,310,400 │ transformer_mha_… │ │ (MultiheadAttentio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_5_… │ (None, 16, 1536) │ 0 │ transformer_mha_… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_26 (Add) │ (None, 16, 1536) │ 0 │ add_25[0][0], │ │ │ │ │ transformer_mha_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_5_l… │ (None, 16, 1536) │ 3,072 │ add_26[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_5_p… │ (None, 16, 3072) │ 4,721,664 │ transformer_ff_5… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_5_d… │ (None, 16, 3072) │ 0 │ transformer_ff_5… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_5_a… │ (None, 16, 3072) │ 0 │ transformer_ff_5… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_5_p… │ (None, 16, 1536) │ 4,720,128 │ transformer_ff_5… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_5_d… │ (None, 16, 1536) │ 0 │ transformer_ff_5… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_27 (Add) │ (None, 16, 1536) │ 0 │ add_26[0][0], │ │ │ │ │ transformer_ff_5… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_6_… │ (None, 16, 1536) │ 3,072 │ add_27[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_6_… │ (None, 16, 1536) │ 6,310,400 │ transformer_mha_… │ │ (MultiheadAttentio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_6_… │ (None, 16, 1536) │ 0 │ transformer_mha_… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_28 (Add) │ (None, 16, 1536) │ 0 │ add_27[0][0], │ │ │ │ │ transformer_mha_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_6_l… │ (None, 16, 1536) │ 3,072 │ add_28[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_6_p… │ (None, 16, 3072) │ 4,721,664 │ transformer_ff_6… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_6_d… │ (None, 16, 3072) │ 0 │ transformer_ff_6… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_6_a… │ (None, 16, 3072) │ 0 │ transformer_ff_6… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_6_p… │ (None, 16, 1536) │ 4,720,128 │ transformer_ff_6… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_6_d… │ (None, 16, 1536) │ 0 │ transformer_ff_6… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_29 (Add) │ (None, 16, 1536) │ 0 │ add_28[0][0], │ │ │ │ │ transformer_ff_6… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_7_… │ (None, 16, 1536) │ 3,072 │ add_29[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_7_… │ (None, 16, 1536) │ 6,310,400 │ transformer_mha_… │ │ (MultiheadAttentio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_7_… │ (None, 16, 1536) │ 0 │ transformer_mha_… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_30 (Add) │ (None, 16, 1536) │ 0 │ add_29[0][0], │ │ │ │ │ transformer_mha_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_7_l… │ (None, 16, 1536) │ 3,072 │ add_30[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_7_p… │ (None, 16, 3072) │ 4,721,664 │ transformer_ff_7… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_7_d… │ (None, 16, 3072) │ 0 │ transformer_ff_7… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_7_a… │ (None, 16, 3072) │ 0 │ transformer_ff_7… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_7_p… │ (None, 16, 1536) │ 4,720,128 │ transformer_ff_7… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_7_d… │ (None, 16, 1536) │ 0 │ transformer_ff_7… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_31 (Add) │ (None, 16, 1536) │ 0 │ add_30[0][0], │ │ │ │ │ transformer_ff_7… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_8_… │ (None, 16, 1536) │ 3,072 │ add_31[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_8_… │ (None, 16, 1536) │ 6,310,400 │ transformer_mha_… │ │ (MultiheadAttentio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_8_… │ (None, 16, 1536) │ 0 │ transformer_mha_… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_32 (Add) │ (None, 16, 1536) │ 0 │ add_31[0][0], │ │ │ │ │ transformer_mha_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_8_l… │ (None, 16, 1536) │ 3,072 │ add_32[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_8_p… │ (None, 16, 3072) │ 4,721,664 │ transformer_ff_8… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_8_d… │ (None, 16, 3072) │ 0 │ transformer_ff_8… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_8_a… │ (None, 16, 3072) │ 0 │ transformer_ff_8… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_8_p… │ (None, 16, 1536) │ 4,720,128 │ transformer_ff_8… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_8_d… │ (None, 16, 1536) │ 0 │ transformer_ff_8… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_33 (Add) │ (None, 16, 1536) │ 0 │ add_32[0][0], │ │ │ │ │ transformer_ff_8… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ upsampling_conv_1_… │ (None, 16, 1536) │ 6,144 │ add_33[0][0] │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ upsampling_conv_1_… │ (None, 16, 1536) │ 0 │ upsampling_conv_… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ unet_skip_2_batchn… │ (None, 32, 1536) │ 6,144 │ tower_conv_6_con… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ upsampling_conv_1_… │ (None, 16, 1536) │ 2,360,832 │ upsampling_conv_… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ unet_skip_2_activa… │ (None, 32, 1536) │ 0 │ unet_skip_2_batc… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ up_sampling1d_2 │ (None, 32, 1536) │ 0 │ upsampling_conv_… │ │ (UpSampling1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ unet_skip_2_conv │ (None, 32, 1536) │ 2,360,832 │ unet_skip_2_acti… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_34 (Add) │ (None, 32, 1536) │ 0 │ up_sampling1d_2[… │ │ │ │ │ unet_skip_2_conv… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ upsampling_separab… │ (None, 32, 1536) │ 2,365,440 │ add_34[0][0] │ │ (SeparableConv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ upsampling_conv_2_… │ (None, 32, 1536) │ 6,144 │ upsampling_separ… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ upsampling_conv_2_… │ (None, 32, 1536) │ 0 │ upsampling_conv_… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ unet_skip_1_batchn… │ (None, 64, 1280) │ 5,120 │ tower_conv_5_con… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ upsampling_conv_2_… │ (None, 32, 1536) │ 2,360,832 │ upsampling_conv_… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ unet_skip_1_activa… │ (None, 64, 1280) │ 0 │ unet_skip_1_batc… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ up_sampling1d_3 │ (None, 64, 1536) │ 0 │ upsampling_conv_… │ │ (UpSampling1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ unet_skip_1_conv │ (None, 64, 1536) │ 1,967,616 │ unet_skip_1_acti… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_35 (Add) │ (None, 64, 1536) │ 0 │ up_sampling1d_3[… │ │ │ │ │ unet_skip_1_conv… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ upsampling_separab… │ (None, 64, 1536) │ 2,365,440 │ add_35[0][0] │ │ (SeparableConv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ final_conv_batchno… │ (None, 64, 1536) │ 6,144 │ upsampling_separ… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ final_conv_activat… │ (None, 64, 1536) │ 0 │ final_conv_batch… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ flatten (Flatten) │ (None, 98304) │ 0 │ final_conv_activ… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ dense_out (Dense) │ (None, 19) │ 1,867,795 │ flatten[0][0] │ └─────────────────────┴───────────────────┴────────────┴───────────────────┘
Total params: 170,213,747 (649.31 MB)
Trainable params: 170,188,723 (649.22 MB)
Non-trainable params: 25,024 (97.75 KB)
None
2025-02-13T16:17:01.947660+0100 INFO Loading sequences into memory...
2025-02-13T16:17:10.887463+0100 INFO Loading sequences into memory...
Epoch 1/10
25764/25765 ━━━━━━━━━━━━━━━━━━━━ 0s 31ms/step - concordance_correlation_coefficient: 0.7620 - cosine_similarity: 0.8605 - loss: -0.5721 - mean_absolute_error: 2.5347 - mean_squared_error: 28.2751 - pearson_correlation: 0.8269 - pearson_correlation_log: 0.6319 - zero_penalty_metric: 135.2008
25765/25765 ━━━━━━━━━━━━━━━━━━━━ 0s 33ms/step - concordance_correlation_coefficient: 0.7620 - cosine_similarity: 0.8605 - loss: -0.5721 - mean_absolute_error: 2.5347 - mean_squared_error: 28.2748 - pearson_correlation: 0.8269 - pearson_correlation_log: 0.6319 - zero_penalty_metric: 135.2008
25765/25765 ━━━━━━━━━━━━━━━━━━━━ 969s 34ms/step - concordance_correlation_coefficient: 0.7620 - cosine_similarity: 0.8605 - loss: -0.5721 - mean_absolute_error: 2.5347 - mean_squared_error: 28.2745 - pearson_correlation: 0.8269 - pearson_correlation_log: 0.6319 - zero_penalty_metric: 135.2007 - val_concordance_correlation_coefficient: 0.8757 - val_cosine_similarity: 0.8763 - val_loss: -0.6293 - val_mean_absolute_error: 2.1752 - val_mean_squared_error: 18.0886 - val_pearson_correlation: 0.8776 - val_pearson_correlation_log: 0.6524 - val_zero_penalty_metric: 138.8675 - learning_rate: 1.0000e-05
Epoch 2/10
25765/25765 ━━━━━━━━━━━━━━━━━━━━ 832s 32ms/step - concordance_correlation_coefficient: 0.8893 - cosine_similarity: 0.8847 - loss: -0.6589 - mean_absolute_error: 2.0946 - mean_squared_error: 15.5250 - pearson_correlation: 0.9042 - pearson_correlation_log: 0.6722 - zero_penalty_metric: 133.3652 - val_concordance_correlation_coefficient: 0.8768 - val_cosine_similarity: 0.8804 - val_loss: -0.6392 - val_mean_absolute_error: 2.1034 - val_mean_squared_error: 16.3852 - val_pearson_correlation: 0.8845 - val_pearson_correlation_log: 0.6555 - val_zero_penalty_metric: 139.2296 - learning_rate: 1.0000e-05
Epoch 3/10
25765/25765 ━━━━━━━━━━━━━━━━━━━━ 825s 32ms/step - concordance_correlation_coefficient: 0.9066 - cosine_similarity: 0.8926 - loss: -0.6849 - mean_absolute_error: 1.9653 - mean_squared_error: 13.2411 - pearson_correlation: 0.9170 - pearson_correlation_log: 0.6835 - zero_penalty_metric: 132.5512 - val_concordance_correlation_coefficient: 0.8715 - val_cosine_similarity: 0.8828 - val_loss: -0.6440 - val_mean_absolute_error: 2.0995 - val_mean_squared_error: 16.5556 - val_pearson_correlation: 0.8848 - val_pearson_correlation_log: 0.6619 - val_zero_penalty_metric: 138.6013 - learning_rate: 1.0000e-05
Epoch 4/10
25765/25765 ━━━━━━━━━━━━━━━━━━━━ 818s 32ms/step - concordance_correlation_coefficient: 0.9192 - cosine_similarity: 0.8986 - loss: -0.7056 - mean_absolute_error: 1.8737 - mean_squared_error: 11.8145 - pearson_correlation: 0.9268 - pearson_correlation_log: 0.6924 - zero_penalty_metric: 131.7001 - val_concordance_correlation_coefficient: 0.8761 - val_cosine_similarity: 0.8827 - val_loss: -0.6397 - val_mean_absolute_error: 2.0982 - val_mean_squared_error: 16.2530 - val_pearson_correlation: 0.8864 - val_pearson_correlation_log: 0.6568 - val_zero_penalty_metric: 138.3267 - learning_rate: 1.0000e-05
Epoch 5/10
25765/25765 ━━━━━━━━━━━━━━━━━━━━ 826s 32ms/step - concordance_correlation_coefficient: 0.9276 - cosine_similarity: 0.9040 - loss: -0.7238 - mean_absolute_error: 1.7957 - mean_squared_error: 10.7533 - pearson_correlation: 0.9339 - pearson_correlation_log: 0.7007 - zero_penalty_metric: 130.6058 - val_concordance_correlation_coefficient: 0.8795 - val_cosine_similarity: 0.8823 - val_loss: -0.6346 - val_mean_absolute_error: 2.1381 - val_mean_squared_error: 16.7327 - val_pearson_correlation: 0.8824 - val_pearson_correlation_log: 0.6552 - val_zero_penalty_metric: 139.1519 - learning_rate: 1.0000e-05
Epoch 6/10
6901/25765 ━━━━━━━━━━━━━━━━━━━━ 9:44 31ms/step - concordance_correlation_coefficient: 0.9324 - cosine_similarity: 0.9092 - loss: -0.7411 - mean_absolute_error: 1.7301 - mean_squared_error: 9.9792 - pearson_correlation: 0.9379 - pearson_correlation_log: 0.7097 - zero_penalty_metric: 128.9625
Further finetuning on specific regions#
We found that finetuning on the full peak set, then on the filtered peak set improved performance over training only on either set. Therefore, we’ll filter the peaks to keep only cell type-specific peaks and further finetune the model.
crested.pp.filter_regions_on_specificity(adata, gini_std_threshold=1.0)
2025-02-14T13:09:51.664914+0100 INFO After specificity filtering, kept 91002 out of 543847 regions.
datamodule = crested.tl.data.AnnDataModule(
adata,
genome=genome,
batch_size=32, # lower this if you encounter OOM errors
max_stochastic_shift=3, # optional augmentation
always_reverse_complement=True, # default True. Will double the effective size of the training dataset.
)
optimizer = keras.optimizers.Adam(learning_rate=5e-5)
loss = crested.tl.losses.CosineMSELogLoss(max_weight=100, multiplier=1)
metrics = [
keras.metrics.MeanAbsoluteError(),
keras.metrics.MeanSquaredError(),
keras.metrics.CosineSimilarity(axis=1),
crested.tl.metrics.PearsonCorrelation(),
crested.tl.metrics.ConcordanceCorrelationCoefficient(),
crested.tl.metrics.PearsonCorrelationLog(),
crested.tl.metrics.ZeroPenaltyMetric(),
]
config = crested.tl.TaskConfig(optimizer, loss, metrics)
print(config)
TaskConfig(optimizer=<keras.src.optimizers.adam.Adam object at 0x151101c6d890>, loss=<crested.tl.losses._cosinemse_log.CosineMSELogLoss object at 0x151101c6e010>, metrics=[<MeanAbsoluteError name=mean_absolute_error>, <MeanSquaredError name=mean_squared_error>, <CosineSimilarity name=cosine_similarity>, <PearsonCorrelation name=pearson_correlation>, <ConcordanceCorrelationCoefficient name=concordance_correlation_coefficient>, <PearsonCorrelationLog name=pearson_correlation_log>, <ZeroPenaltyMetric name=zero_penalty_metric>])
model_architecture = keras.models.load_model(
"biccn_borzoi_atac/testrun/checkpoints/03.keras", compile=False
)
# setup the trainer
trainer = crested.tl.Crested(
data=datamodule,
model=model_architecture,
config=config,
project_name="biccn_borzoi_atac",
run_name="testrun_ft",
logger="wandb",
)
# train the model
trainer.fit(epochs=5)
/lustre1/project/stg_00002/lcb/cblaauw/rna_models/borzoi_atac/wandb/run-20250214_131001-sdx53oa8
Model: "Borzoi_scalar"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩ │ input (InputLayer) │ (None, 2048, 4) │ 0 │ - │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ stem_conv (Conv1D) │ (None, 2048, 512) │ 31,232 │ input[0][0] │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ stem_pool │ (None, 1024, 512) │ 0 │ stem_conv[0][0] │ │ (MaxPooling1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_1_batch… │ (None, 1024, 512) │ 2,048 │ stem_pool[0][0] │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_1_activ… │ (None, 1024, 512) │ 0 │ tower_conv_1_bat… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_1_conv │ (None, 1024, 608) │ 1,557,088 │ tower_conv_1_act… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_1_pool │ (None, 512, 608) │ 0 │ tower_conv_1_con… │ │ (MaxPooling1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_2_batch… │ (None, 512, 608) │ 2,432 │ tower_conv_1_poo… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_2_activ… │ (None, 512, 608) │ 0 │ tower_conv_2_bat… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_2_conv │ (None, 512, 736) │ 2,238,176 │ tower_conv_2_act… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_2_pool │ (None, 256, 736) │ 0 │ tower_conv_2_con… │ │ (MaxPooling1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_3_batch… │ (None, 256, 736) │ 2,944 │ tower_conv_2_poo… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_3_activ… │ (None, 256, 736) │ 0 │ tower_conv_3_bat… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_3_conv │ (None, 256, 896) │ 3,298,176 │ tower_conv_3_act… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_3_pool │ (None, 128, 896) │ 0 │ tower_conv_3_con… │ │ (MaxPooling1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_4_batch… │ (None, 128, 896) │ 3,584 │ tower_conv_3_poo… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_4_activ… │ (None, 128, 896) │ 0 │ tower_conv_4_bat… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_4_conv │ (None, 128, 1056) │ 4,731,936 │ tower_conv_4_act… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_4_pool │ (None, 64, 1056) │ 0 │ tower_conv_4_con… │ │ (MaxPooling1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_5_batch… │ (None, 64, 1056) │ 4,224 │ tower_conv_4_poo… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_5_activ… │ (None, 64, 1056) │ 0 │ tower_conv_5_bat… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_5_conv │ (None, 64, 1280) │ 6,759,680 │ tower_conv_5_act… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_5_pool │ (None, 32, 1280) │ 0 │ tower_conv_5_con… │ │ (MaxPooling1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_6_batch… │ (None, 32, 1280) │ 5,120 │ tower_conv_5_poo… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_6_activ… │ (None, 32, 1280) │ 0 │ tower_conv_6_bat… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_6_conv │ (None, 32, 1536) │ 9,831,936 │ tower_conv_6_act… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ tower_conv_6_pool │ (None, 16, 1536) │ 0 │ tower_conv_6_con… │ │ (MaxPooling1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_1_… │ (None, 16, 1536) │ 3,072 │ tower_conv_6_poo… │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_1_… │ (None, 16, 1536) │ 6,310,400 │ transformer_mha_… │ │ (MultiheadAttentio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_1_… │ (None, 16, 1536) │ 0 │ transformer_mha_… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_18 (Add) │ (None, 16, 1536) │ 0 │ tower_conv_6_poo… │ │ │ │ │ transformer_mha_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_1_l… │ (None, 16, 1536) │ 3,072 │ add_18[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_1_p… │ (None, 16, 3072) │ 4,721,664 │ transformer_ff_1… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_1_d… │ (None, 16, 3072) │ 0 │ transformer_ff_1… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_1_a… │ (None, 16, 3072) │ 0 │ transformer_ff_1… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_1_p… │ (None, 16, 1536) │ 4,720,128 │ transformer_ff_1… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_1_d… │ (None, 16, 1536) │ 0 │ transformer_ff_1… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_19 (Add) │ (None, 16, 1536) │ 0 │ add_18[0][0], │ │ │ │ │ transformer_ff_1… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_2_… │ (None, 16, 1536) │ 3,072 │ add_19[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_2_… │ (None, 16, 1536) │ 6,310,400 │ transformer_mha_… │ │ (MultiheadAttentio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_2_… │ (None, 16, 1536) │ 0 │ transformer_mha_… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_20 (Add) │ (None, 16, 1536) │ 0 │ add_19[0][0], │ │ │ │ │ transformer_mha_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_2_l… │ (None, 16, 1536) │ 3,072 │ add_20[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_2_p… │ (None, 16, 3072) │ 4,721,664 │ transformer_ff_2… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_2_d… │ (None, 16, 3072) │ 0 │ transformer_ff_2… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_2_a… │ (None, 16, 3072) │ 0 │ transformer_ff_2… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_2_p… │ (None, 16, 1536) │ 4,720,128 │ transformer_ff_2… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_2_d… │ (None, 16, 1536) │ 0 │ transformer_ff_2… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_21 (Add) │ (None, 16, 1536) │ 0 │ add_20[0][0], │ │ │ │ │ transformer_ff_2… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_3_… │ (None, 16, 1536) │ 3,072 │ add_21[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_3_… │ (None, 16, 1536) │ 6,310,400 │ transformer_mha_… │ │ (MultiheadAttentio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_3_… │ (None, 16, 1536) │ 0 │ transformer_mha_… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_22 (Add) │ (None, 16, 1536) │ 0 │ add_21[0][0], │ │ │ │ │ transformer_mha_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_3_l… │ (None, 16, 1536) │ 3,072 │ add_22[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_3_p… │ (None, 16, 3072) │ 4,721,664 │ transformer_ff_3… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_3_d… │ (None, 16, 3072) │ 0 │ transformer_ff_3… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_3_a… │ (None, 16, 3072) │ 0 │ transformer_ff_3… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_3_p… │ (None, 16, 1536) │ 4,720,128 │ transformer_ff_3… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_3_d… │ (None, 16, 1536) │ 0 │ transformer_ff_3… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_23 (Add) │ (None, 16, 1536) │ 0 │ add_22[0][0], │ │ │ │ │ transformer_ff_3… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_4_… │ (None, 16, 1536) │ 3,072 │ add_23[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_4_… │ (None, 16, 1536) │ 6,310,400 │ transformer_mha_… │ │ (MultiheadAttentio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_4_… │ (None, 16, 1536) │ 0 │ transformer_mha_… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_24 (Add) │ (None, 16, 1536) │ 0 │ add_23[0][0], │ │ │ │ │ transformer_mha_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_4_l… │ (None, 16, 1536) │ 3,072 │ add_24[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_4_p… │ (None, 16, 3072) │ 4,721,664 │ transformer_ff_4… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_4_d… │ (None, 16, 3072) │ 0 │ transformer_ff_4… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_4_a… │ (None, 16, 3072) │ 0 │ transformer_ff_4… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_4_p… │ (None, 16, 1536) │ 4,720,128 │ transformer_ff_4… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_4_d… │ (None, 16, 1536) │ 0 │ transformer_ff_4… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_25 (Add) │ (None, 16, 1536) │ 0 │ add_24[0][0], │ │ │ │ │ transformer_ff_4… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_5_… │ (None, 16, 1536) │ 3,072 │ add_25[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_5_… │ (None, 16, 1536) │ 6,310,400 │ transformer_mha_… │ │ (MultiheadAttentio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_5_… │ (None, 16, 1536) │ 0 │ transformer_mha_… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_26 (Add) │ (None, 16, 1536) │ 0 │ add_25[0][0], │ │ │ │ │ transformer_mha_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_5_l… │ (None, 16, 1536) │ 3,072 │ add_26[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_5_p… │ (None, 16, 3072) │ 4,721,664 │ transformer_ff_5… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_5_d… │ (None, 16, 3072) │ 0 │ transformer_ff_5… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_5_a… │ (None, 16, 3072) │ 0 │ transformer_ff_5… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_5_p… │ (None, 16, 1536) │ 4,720,128 │ transformer_ff_5… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_5_d… │ (None, 16, 1536) │ 0 │ transformer_ff_5… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_27 (Add) │ (None, 16, 1536) │ 0 │ add_26[0][0], │ │ │ │ │ transformer_ff_5… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_6_… │ (None, 16, 1536) │ 3,072 │ add_27[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_6_… │ (None, 16, 1536) │ 6,310,400 │ transformer_mha_… │ │ (MultiheadAttentio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_6_… │ (None, 16, 1536) │ 0 │ transformer_mha_… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_28 (Add) │ (None, 16, 1536) │ 0 │ add_27[0][0], │ │ │ │ │ transformer_mha_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_6_l… │ (None, 16, 1536) │ 3,072 │ add_28[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_6_p… │ (None, 16, 3072) │ 4,721,664 │ transformer_ff_6… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_6_d… │ (None, 16, 3072) │ 0 │ transformer_ff_6… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_6_a… │ (None, 16, 3072) │ 0 │ transformer_ff_6… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_6_p… │ (None, 16, 1536) │ 4,720,128 │ transformer_ff_6… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_6_d… │ (None, 16, 1536) │ 0 │ transformer_ff_6… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_29 (Add) │ (None, 16, 1536) │ 0 │ add_28[0][0], │ │ │ │ │ transformer_ff_6… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_7_… │ (None, 16, 1536) │ 3,072 │ add_29[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_7_… │ (None, 16, 1536) │ 6,310,400 │ transformer_mha_… │ │ (MultiheadAttentio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_7_… │ (None, 16, 1536) │ 0 │ transformer_mha_… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_30 (Add) │ (None, 16, 1536) │ 0 │ add_29[0][0], │ │ │ │ │ transformer_mha_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_7_l… │ (None, 16, 1536) │ 3,072 │ add_30[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_7_p… │ (None, 16, 3072) │ 4,721,664 │ transformer_ff_7… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_7_d… │ (None, 16, 3072) │ 0 │ transformer_ff_7… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_7_a… │ (None, 16, 3072) │ 0 │ transformer_ff_7… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_7_p… │ (None, 16, 1536) │ 4,720,128 │ transformer_ff_7… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_7_d… │ (None, 16, 1536) │ 0 │ transformer_ff_7… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_31 (Add) │ (None, 16, 1536) │ 0 │ add_30[0][0], │ │ │ │ │ transformer_ff_7… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_8_… │ (None, 16, 1536) │ 3,072 │ add_31[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_8_… │ (None, 16, 1536) │ 6,310,400 │ transformer_mha_… │ │ (MultiheadAttentio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_mha_8_… │ (None, 16, 1536) │ 0 │ transformer_mha_… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_32 (Add) │ (None, 16, 1536) │ 0 │ add_31[0][0], │ │ │ │ │ transformer_mha_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_8_l… │ (None, 16, 1536) │ 3,072 │ add_32[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_8_p… │ (None, 16, 3072) │ 4,721,664 │ transformer_ff_8… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_8_d… │ (None, 16, 3072) │ 0 │ transformer_ff_8… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_8_a… │ (None, 16, 3072) │ 0 │ transformer_ff_8… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_8_p… │ (None, 16, 1536) │ 4,720,128 │ transformer_ff_8… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transformer_ff_8_d… │ (None, 16, 1536) │ 0 │ transformer_ff_8… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_33 (Add) │ (None, 16, 1536) │ 0 │ add_32[0][0], │ │ │ │ │ transformer_ff_8… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ upsampling_conv_1_… │ (None, 16, 1536) │ 6,144 │ add_33[0][0] │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ upsampling_conv_1_… │ (None, 16, 1536) │ 0 │ upsampling_conv_… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ unet_skip_2_batchn… │ (None, 32, 1536) │ 6,144 │ tower_conv_6_con… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ upsampling_conv_1_… │ (None, 16, 1536) │ 2,360,832 │ upsampling_conv_… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ unet_skip_2_activa… │ (None, 32, 1536) │ 0 │ unet_skip_2_batc… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ up_sampling1d_2 │ (None, 32, 1536) │ 0 │ upsampling_conv_… │ │ (UpSampling1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ unet_skip_2_conv │ (None, 32, 1536) │ 2,360,832 │ unet_skip_2_acti… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_34 (Add) │ (None, 32, 1536) │ 0 │ up_sampling1d_2[… │ │ │ │ │ unet_skip_2_conv… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ upsampling_separab… │ (None, 32, 1536) │ 2,365,440 │ add_34[0][0] │ │ (SeparableConv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ upsampling_conv_2_… │ (None, 32, 1536) │ 6,144 │ upsampling_separ… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ upsampling_conv_2_… │ (None, 32, 1536) │ 0 │ upsampling_conv_… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ unet_skip_1_batchn… │ (None, 64, 1280) │ 5,120 │ tower_conv_5_con… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ upsampling_conv_2_… │ (None, 32, 1536) │ 2,360,832 │ upsampling_conv_… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ unet_skip_1_activa… │ (None, 64, 1280) │ 0 │ unet_skip_1_batc… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ up_sampling1d_3 │ (None, 64, 1536) │ 0 │ upsampling_conv_… │ │ (UpSampling1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ unet_skip_1_conv │ (None, 64, 1536) │ 1,967,616 │ unet_skip_1_acti… │ │ (Conv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_35 (Add) │ (None, 64, 1536) │ 0 │ up_sampling1d_3[… │ │ │ │ │ unet_skip_1_conv… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ upsampling_separab… │ (None, 64, 1536) │ 2,365,440 │ add_35[0][0] │ │ (SeparableConv1D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ final_conv_batchno… │ (None, 64, 1536) │ 6,144 │ upsampling_separ… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ final_conv_activat… │ (None, 64, 1536) │ 0 │ final_conv_batch… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ flatten (Flatten) │ (None, 98304) │ 0 │ final_conv_activ… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ dense_out (Dense) │ (None, 19) │ 1,867,795 │ flatten[0][0] │ └─────────────────────┴───────────────────┴────────────┴───────────────────┘
Total params: 170,213,747 (649.31 MB)
Trainable params: 170,188,723 (649.22 MB)
Non-trainable params: 25,024 (97.75 KB)
None
2025-02-14T13:10:34.604198+0100 INFO Loading sequences into memory...
2025-02-14T13:10:45.359384+0100 INFO Loading sequences into memory...
Epoch 1/5
4197/4198 ━━━━━━━━━━━━━━━━━━━━ 0s 31ms/step - concordance_correlation_coefficient: 0.7256 - cosine_similarity: 0.8709 - loss: -0.6439 - mean_absolute_error: 1.6038 - mean_squared_error: 12.4116 - pearson_correlation: 0.7944 - pearson_correlation_log: 0.6196 - zero_penalty_metric: 320.1437
4198/4198 ━━━━━━━━━━━━━━━━━━━━ 0s 44ms/step - concordance_correlation_coefficient: 0.7256 - cosine_similarity: 0.8709 - loss: -0.6439 - mean_absolute_error: 1.6038 - mean_squared_error: 12.4115 - pearson_correlation: 0.7944 - pearson_correlation_log: 0.6196 - zero_penalty_metric: 320.1441
4198/4198 ━━━━━━━━━━━━━━━━━━━━ 288s 48ms/step - concordance_correlation_coefficient: 0.7256 - cosine_similarity: 0.8709 - loss: -0.6439 - mean_absolute_error: 1.6038 - mean_squared_error: 12.4114 - pearson_correlation: 0.7944 - pearson_correlation_log: 0.6196 - zero_penalty_metric: 320.1445 - val_concordance_correlation_coefficient: 0.6791 - val_cosine_similarity: 0.8512 - val_loss: -0.5814 - val_mean_absolute_error: 1.7332 - val_mean_squared_error: 14.4574 - val_pearson_correlation: 0.7423 - val_pearson_correlation_log: 0.6029 - val_zero_penalty_metric: 312.9736 - learning_rate: 5.0000e-05
Epoch 2/5
4198/4198 ━━━━━━━━━━━━━━━━━━━━ 139s 33ms/step - concordance_correlation_coefficient: 0.8034 - cosine_similarity: 0.9018 - loss: -0.7173 - mean_absolute_error: 1.4257 - mean_squared_error: 9.5457 - pearson_correlation: 0.8455 - pearson_correlation_log: 0.6463 - zero_penalty_metric: 318.2404 - val_concordance_correlation_coefficient: 0.7234 - val_cosine_similarity: 0.8542 - val_loss: -0.5963 - val_mean_absolute_error: 1.7099 - val_mean_squared_error: 13.3615 - val_pearson_correlation: 0.7549 - val_pearson_correlation_log: 0.6074 - val_zero_penalty_metric: 307.6308 - learning_rate: 5.0000e-05
Epoch 3/5
4198/4198 ━━━━━━━━━━━━━━━━━━━━ 136s 32ms/step - concordance_correlation_coefficient: 0.8495 - cosine_similarity: 0.9229 - loss: -0.7697 - mean_absolute_error: 1.2824 - mean_squared_error: 7.5957 - pearson_correlation: 0.8779 - pearson_correlation_log: 0.6710 - zero_penalty_metric: 312.4368 - val_concordance_correlation_coefficient: 0.7091 - val_cosine_similarity: 0.8496 - val_loss: -0.5783 - val_mean_absolute_error: 1.7447 - val_mean_squared_error: 13.5709 - val_pearson_correlation: 0.7505 - val_pearson_correlation_log: 0.6042 - val_zero_penalty_metric: 309.9329 - learning_rate: 5.0000e-05
Epoch 4/5
4198/4198 ━━━━━━━━━━━━━━━━━━━━ 136s 32ms/step - concordance_correlation_coefficient: 0.8773 - cosine_similarity: 0.9389 - loss: -0.8093 - mean_absolute_error: 1.1699 - mean_squared_error: 6.3396 - pearson_correlation: 0.8985 - pearson_correlation_log: 0.6910 - zero_penalty_metric: 309.6932 - val_concordance_correlation_coefficient: 0.7358 - val_cosine_similarity: 0.8513 - val_loss: -0.5890 - val_mean_absolute_error: 1.7230 - val_mean_squared_error: 13.0675 - val_pearson_correlation: 0.7589 - val_pearson_correlation_log: 0.6259 - val_zero_penalty_metric: 314.7453 - learning_rate: 5.0000e-05
Epoch 5/5
4198/4198 ━━━━━━━━━━━━━━━━━━━━ 136s 32ms/step - concordance_correlation_coefficient: 0.8976 - cosine_similarity: 0.9504 - loss: -0.8389 - mean_absolute_error: 1.0856 - mean_squared_error: 5.4635 - pearson_correlation: 0.9137 - pearson_correlation_log: 0.7117 - zero_penalty_metric: 304.2629 - val_concordance_correlation_coefficient: 0.7264 - val_cosine_similarity: 0.8481 - val_loss: -0.5807 - val_mean_absolute_error: 1.7185 - val_mean_squared_error: 13.2077 - val_pearson_correlation: 0.7581 - val_pearson_correlation_log: 0.6187 - val_zero_penalty_metric: 309.3376 - learning_rate: 5.0000e-05
Run history:
batch/batch_step | ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███ |
batch/concordance_correlation_coefficient | ▁▁▂▂▂▂▂▂▄▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇████████ |
batch/cosine_similarity | ▁▁▁▁▁▁▁▁▄▄▄▄▄▄▄▄▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇████████ |
batch/loss | ████████▅▅▅▅▅▅▅▅▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁ |
batch/mean_absolute_error | ████████▆▅▅▅▅▅▅▅▄▃▄▄▄▄▄▄▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁ |
batch/mean_squared_error | ███▇▇▇▇▇▅▅▅▅▅▅▅▅▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁ |
batch/pearson_correlation | ▁▁▂▂▂▂▂▂▄▄▄▄▄▄▄▄▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇████████ |
batch/pearson_correlation_log | ▁▁▁▁▁▁▁▁▃▃▃▃▃▃▃▃▅▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆████████ |
batch/zero_penalty_metric | █▇▇▇████▇█▇▇▇▇▇▇▄▅▅▅▅▅▅▆▄▄▄▄▄▄▄▄▁▂▃▂▃▃▂▂ |
epoch/concordance_correlation_coefficient | ▁▁▄▄▆▆▇▇██ |
epoch/cosine_similarity | ▁▁▄▄▆▆▇▇██ |
epoch/epoch | ▁▁▃▃▅▅▆▆██ |
epoch/learning_rate | ▁▁▁▁▁▁▁▁▁▁ |
epoch/loss | ██▅▅▄▄▂▂▁▁ |
epoch/mean_absolute_error | ██▆▆▄▄▂▂▁▁ |
epoch/mean_squared_error | ██▅▅▃▃▂▂▁▁ |
epoch/pearson_correlation | ▁▁▄▄▆▆▇▇██ |
epoch/pearson_correlation_log | ▁▁▃▃▅▅▆▆██ |
epoch/val_concordance_correlation_coefficient | ▁▁▆▆▅▅██▇▇ |
epoch/val_cosine_similarity | ▄▄██▃▃▅▅▁▁ |
epoch/val_loss | ▇▇▁▁██▄▄▇▇ |
epoch/val_mean_absolute_error | ▆▆▁▁██▄▄▃▃ |
epoch/val_mean_squared_error | ██▂▂▄▄▁▁▂▂ |
epoch/val_pearson_correlation | ▁▁▆▆▄▄████ |
epoch/val_pearson_correlation_log | ▁▁▂▂▁▁██▆▆ |
epoch/val_zero_penalty_metric | ▆▆▁▁▃▃██▃▃ |
epoch/zero_penalty_metric | ██▇▇▅▅▃▃▁▁ |
Run summary:
batch/batch_step | 20990 |
batch/concordance_correlation_coefficient | 0.89607 |
batch/cosine_similarity | 0.9493 |
batch/loss | -0.83633 |
batch/mean_absolute_error | 1.08951 |
batch/mean_squared_error | 5.49214 |
batch/pearson_correlation | 0.91258 |
batch/pearson_correlation_log | 0.71105 |
batch/zero_penalty_metric | 304.45786 |
epoch/concordance_correlation_coefficient | 0.89606 |
epoch/cosine_similarity | 0.9493 |
epoch/epoch | 4 |
epoch/learning_rate | 5e-05 |
epoch/loss | -0.83633 |
epoch/mean_absolute_error | 1.08949 |
epoch/mean_squared_error | 5.49124 |
epoch/pearson_correlation | 0.91258 |
epoch/pearson_correlation_log | 0.71104 |
epoch/val_concordance_correlation_coefficient | 0.72639 |
epoch/val_cosine_similarity | 0.84814 |
epoch/val_loss | -0.58074 |
epoch/val_mean_absolute_error | 1.71851 |
epoch/val_mean_squared_error | 13.20767 |
epoch/val_pearson_correlation | 0.7581 |
epoch/val_pearson_correlation_log | 0.61873 |
epoch/val_zero_penalty_metric | 309.33755 |
epoch/zero_penalty_metric | 304.3938 |
View project at: https://wandb.ai/cas-blaauw/biccn_borzoi_atac
Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 1 other file(s)
./wandb/run-20250214_131001-sdx53oa8/logs
Evaluate model#
We’ll evaluate both the finetuned and further finetuned models.
model = keras.models.load_model("biccn_borzoi_atac/testrun/checkpoints/03.keras")
model_ft = keras.models.load_model("biccn_borzoi_atac/testrun_ft/checkpoints/02.keras")
# add predictions for model checkpoint to the adata
adata.layers["model"] = crested.tl.predict(adata, model_ft).T
adata.layers["model_ft"] = crested.tl.predict(adata, model).T
91002/91002 ━━━━━━━━━━━━━━━━━━━━ 318s 3ms/step
91002/91002 ━━━━━━━━━━━━━━━━━━━━ 318s 3ms/step
If you don’t want to predict on the entire dataset, you can also predict on a given sequence or region by passing coordinates, a DNA sequence, or a one-hot encoded sequence to predict()
.
Many of the plotting functions in the crested.pl
module can be used to visualize these model predictions.
# Define a dataframe with test set regions
test_df = adata.var[adata.var["split"] == "test"]
test_df
chr | start | end | split | |
---|---|---|---|---|
region | ||||
chr1:3094031-3096079 | chr1 | 3094031 | 3096079 | test |
chr1:3094696-3096744 | chr1 | 3094696 | 3096744 | test |
chr1:3112760-3114808 | chr1 | 3112760 | 3114808 | test |
chr1:3133812-3135860 | chr1 | 3133812 | 3135860 | test |
chr1:3164934-3166982 | chr1 | 3164934 | 3166982 | test |
... | ... | ... | ... | ... |
chrX:20751180-20753228 | chrX | 20751180 | 20753228 | test |
chrX:21388522-21390570 | chrX | 21388522 | 21390570 | test |
chrX:21392726-21394774 | chrX | 21392726 | 21394774 | test |
chrX:21427413-21429461 | chrX | 21427413 | 21429461 | test |
chrX:21433814-21435862 | chrX | 21433814 | 21435862 | test |
10881 rows × 4 columns
# plot predictions vs ground truth for a random region in the test set defined by index
%matplotlib inline
idx = 22
region = test_df.index[idx]
print(region)
crested.pl.bar.region_predictions(
adata, region, title=f"Predictions vs ground truth ({region})"
)
chr1:3899526-3901574
2025-02-14T13:37:41.075121+0100 INFO Plotting bar plots for region: chr1:3899526-3901574, models: ['model', 'model_ft']

crested.pl.heatmap.correlations_self(
adata,
title="Self-correlation heatmap",
x_label_rotation=90,
width=10,
height=10,
vmin=-1,
vmax=1,
)

crested.pl.heatmap.correlations_predictions(
adata,
split="test",
title="Correlations between ground truths and predictions",
x_label_rotation=90,
width=20,
height=10,
log_transform=False,
# vmin = -1,
# vmax = 1,
)
2025-02-14T13:38:10.984708+0100 INFO Plotting heatmap correlations for split: test, models: ['model', 'model_ft']

crested.pl.scatter.class_density(
adata,
split="test",
log_transform=True,
width=10,
height=5,
)
2025-02-14T13:39:20.714235+0100 INFO Plotting density scatter for all targets and predictions, models: ['model', 'model_ft'], split: test

Besides looking at prediction scores, we can also use these models to explain the features in the sequence that contributed to predicted accessibility in a certain cell type.
Here, we’ll look at three regions, expected to be active in microglia (Micro_PVM
), Sst/Chodl GABAergic neurons (SstChodl
), or in layer 6b glutamatergic neurons (L6b
) respectively.
regions_of_interest = [
"chr18:61107803-61109851",
"chr13:92952218-92954266",
"chr9:56036511-56038559",
]
classes_of_interest = ["Micro_PVM", "SstChodl", "L6b"]
class_idx = list(adata.obs_names.get_indexer(classes_of_interest))
scores, one_hot_encoded_sequences = crested.tl.contribution_scores(
regions_of_interest,
target_idx=class_idx,
model=model_ft,
)
2025-02-14T13:57:00.326922+0100 INFO Calculating contribution scores for 3 class(es) and 3 region(s).
# Plot attribution scores
crested.pl.patterns.contribution_scores(
scores,
one_hot_encoded_sequences,
sequence_labels=regions_of_interest,
class_labels=classes_of_interest,
zoom_n_bases=500,
)
