Topic Classification

We can use the outputs of pycistopic to train a model to predict topic probabilities for a given sequence.

Since we plan on adding detailed use cases describing topic classification later on, we will only provide a brief overview of the workflow here. Refer to the introductory notebook for a more detailed explanation of the CREsted workflow.

Import Data

For this tutorial, we will use the Mouse BICCN dataset. We will use the preprocessed, binarized outputs of pycistopic as input data for the topic classification model.

To train a topic classification model, we need the following data:

  1. A folder containing one BED file per topic (output of pycistopic).

  2. A genome FASTA file and, optionally, a chromosome sizes file.

import os

import crested

# Download the tutorial data
os.environ["CRESTED_DATA_DIR"] = (
    "../../../Crested_testing/data/tmp"  # Change this to your desired directory
)
beds_folder, regions_file = crested.get_dataset("mouse_cortex_bed")

We can import a folder of BED files using the crested.import_beds() function.
This returns an AnnData object with the BED file names (here: our topics) as .obs and the regions as .var.
In this case, the adata.X values are binary, indicating whether a region is associated with a topic or not.

# Import the beds into an AnnData object
adata = crested.import_beds(
    beds_folder=beds_folder, regions_file=regions_file
)  # the regions file is optional for import_beds
adata
2024-08-14T11:45:13.940704+0200 WARNING Chromsizes file not provided. Will not check if regions are within chromosomes
2024-08-14T11:45:14.496482+0200 INFO Reading bed files from /lustre1/project/stg_00002/lcb/lmahieu/projects/Crested_testing/data/tmp/data/mouse_biccn/beds.tar.gz.untar and using /lustre1/project/stg_00002/lcb/lmahieu/projects/Crested_testing/data/tmp/data/mouse_biccn/consensus_peaks_biccn.bed as var_names...
2024-08-14T11:47:31.431482+0200 WARNING 107610 consensus regions are not open in any class. Removing them from the AnnData object. Disable this behavior by setting 'remove_empty_regions=False'
View of AnnData object with n_obs × n_vars = 80 × 439383
    obs: 'file_path', 'n_open_regions'
    var: 'n_classes', 'chr', 'start', 'end'

After removing the regions that are not open in any class, we have 80 classes (topics) and 439383 regions in the dataset.
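
To make the binary topic-by-region matrix concrete, here is a minimal pure-Python sketch of the idea. It is not the crested.import_beds() implementation: the helper names, the 'chr:start-end' region naming and the toy file contents are assumptions made for illustration only.

```python
import tempfile
from pathlib import Path


def read_bed(path):
    """Return the set of 'chr:start-end' region names found in a BED file."""
    regions = set()
    with open(path) as handle:
        for line in handle:
            # Skip empty lines and common BED header lines
            if not line.strip() or line.startswith(("#", "track", "browser")):
                continue
            chrom, start, end = line.split()[:3]
            regions.add(f"{chrom}:{start}-{end}")
    return regions


def binary_matrix(beds_folder, consensus_regions):
    """Map each BED file (topic) to a 0/1 list over the consensus regions."""
    matrix = {}
    for bed in sorted(Path(beds_folder).glob("*.bed")):
        topic_regions = read_bed(bed)
        matrix[bed.stem] = [int(r in topic_regions) for r in consensus_regions]
    return matrix


# Toy data (hypothetical file names and coordinates, not the BICCN files)
with tempfile.TemporaryDirectory() as tmp:
    Path(tmp, "Topic_1.bed").write_text("chr1\t100\t600\nchr2\t50\t550\n")
    Path(tmp, "Topic_2.bed").write_text("chr2\t50\t550\n")
    consensus = ["chr1:100-600", "chr2:50-550", "chr3:0-500"]
    toy = binary_matrix(tmp, consensus)

print(toy)
```

In this toy example the third consensus region is not part of any topic. Regions like that are the ones the warning above reports as removed from the AnnData object.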

Preprocessing

Topic classification requires less preprocessing than peak regression.
The data does not need to be normalized, since the values are binary. We also don’t filter regions on specificity: by the nature of topic modelling, the selected regions should already be ‘meaningful’.
You could change the width of the regions, but we tend to keep them at 500bp for topic classification.

The only preprocessing step we need to perform is to split the data into training and testing sets.

# Standard train/val/test split
crested.pp.train_val_test_split(
    adata, strategy="chr", val_chroms=["chr8", "chr10"], test_chroms=["chr9", "chr18"]
)
print(adata.var["split"].value_counts())
split
train    354013
val       45113
test      40257
Name: count, dtype: int64
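
The chromosome-based strategy keeps all regions of a held-out chromosome in the same split, so the validation and test sets contain no regions from the training chromosomes. The sketch below illustrates that logic in plain Python; it is not the crested.pp.train_val_test_split() implementation, and the function name and the 'chr:start-end' region names are assumptions.

```python
from collections import Counter


def split_by_chromosome(region_names, val_chroms, test_chroms):
    """Assign 'train', 'val' or 'test' to each 'chr:start-end' region name."""
    val_chroms, test_chroms = set(val_chroms), set(test_chroms)
    splits = []
    for name in region_names:
        chrom = name.split(":")[0]
        if chrom in val_chroms:
            splits.append("val")
        elif chrom in test_chroms:
            splits.append("test")
        else:
            splits.append("train")
    return splits


# Toy regions (hypothetical coordinates)
regions = [
    "chr1:100-600",
    "chr8:10-510",
    "chr9:0-500",
    "chr10:5-505",
    "chr18:20-520",
    "chr2:0-500",
]
splits = split_by_chromosome(
    regions, val_chroms=["chr8", "chr10"], test_chroms=["chr9", "chr18"]
)
counts = Counter(splits)
print(counts)
```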

Model Training

Model training has the same workflow as peak regression. The only differences are:

  1. We select a different model architecture. Since we’re training on 500bp regions we don’t need the dilated convolutions of chrombpnet.

  2. We select a different config, since we’re monitoring other metrics and are using a different loss for classification.
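
The classification loss in the default topic classification config is binary cross entropy, which treats every topic as an independent yes/no label for a region. The following is a small worked sketch of that loss for a single region; the function and the toy numbers are illustrative and not taken from CREsted or Keras.

```python
import math


def binary_cross_entropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross entropy over the topics of one region."""
    # Clip predictions so that log(0) can never occur
    clipped = [min(max(p, eps), 1 - eps) for p in y_pred]
    losses = [
        -(t * math.log(p) + (1 - t) * math.log(1 - p))
        for t, p in zip(y_true, clipped)
    ]
    return sum(losses) / len(losses)


labels = [1, 0, 0, 1]  # region belongs to the first and last topic
confident = [0.9, 0.1, 0.2, 0.8]  # predictions close to the labels
uninformed = [0.5, 0.5, 0.5, 0.5]  # no preference for any topic

loss_confident = binary_cross_entropy(labels, confident)
loss_uninformed = binary_cross_entropy(labels, uninformed)
print(loss_confident, loss_uninformed)
```

A model that outputs 0.5 for every topic scores ln 2 (about 0.693), so a loss clearly below that value indicates that the model has started to separate the topics.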

# Datamodule
datamodule = crested.tl.data.AnnDataModule(
    adata,
    genome="../../../Crested_testing/data/tmp/mm10.fa",
    batch_size=128,  # lower this if you encounter OOM errors
    max_stochastic_shift=3,  # optional augmentation
    always_reverse_complement=True,  # default True. Will double the effective size of the training dataset.
)

# Architecture: we will use the DeepTopic CNN model
model_architecture = crested.tl.zoo.deeptopic_cnn(seq_len=500, num_classes=80)

# Config: we will use the default topic classification config (binary cross entropy loss; auROC, auPR and categorical accuracy metrics)
config = crested.tl.default_configs("topic_classification")
print(config)
2024-08-14T11:47:53.209678+0200 WARNING Chromsizes file not provided when shifting. Will not check if shifted regions are within chromosomes
TaskConfig(optimizer=<keras.src.backend.torch.optimizers.torch_adam.Adam object at 0x145af25458e0>, loss=<keras.src.losses.losses.BinaryCrossentropy object at 0x145ad99a18b0>, metrics=[<AUC name=auROC>, <AUC name=auPR>, <CategoricalAccuracy name=categorical_accuracy>])
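
The two augmentation options passed to the datamodule above can be pictured as follows. This is only a sketch: we assume the shift is drawn uniformly from the integers between -max_stochastic_shift and +max_stochastic_shift, which may differ from what CREsted does internally, and the helper names are hypothetical.

```python
import random

COMPLEMENT = str.maketrans("ACGTN", "TGCAN")


def reverse_complement(seq):
    """Return the reverse complement of a DNA sequence."""
    return seq.translate(COMPLEMENT)[::-1]


def shifted_window(start, end, max_shift, rng):
    """Shift a region by a random offset while keeping its width."""
    shift = rng.randint(-max_shift, max_shift)  # assumed sampling scheme
    return start + shift, end + shift


rng = random.Random(0)  # fixed seed to make the sketch reproducible
start, end = shifted_window(1000, 1500, max_shift=3, rng=rng)
print(start, end, reverse_complement("ACGTN"))
```

Since a shifted window can extend past the end of a chromosome, the warning above notes that this cannot be checked without a chromosome sizes file.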

Set up the Trainer

# setup the trainer
trainer = crested.tl.Crested(
    data=datamodule,
    model=model_architecture,
    config=config,
    project_name="mouse_biccn_topics",  # change to your liking
    logger=None,  # or 'wandb', 'tensorboard'
)
# train the model
trainer.fit(epochs=100)
Model: "functional"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)        ┃ Output Shape      ┃    Param # ┃ Connected to      ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ sequence            │ (None, 500, 4)    │          0 │ -                 │
│ (InputLayer)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ conv1d (Conv1D)     │ (None, 500, 1024) │     69,632 │ sequence[0][0]    │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ batch_normalization │ (None, 500, 1024) │      4,096 │ conv1d[0][0]      │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ activation          │ (None, 500, 1024) │          0 │ batch_normalizat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ max_pooling1d       │ (None, 125, 1024) │          0 │ activation[0][0]  │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dropout (Dropout)   │ (None, 125, 1024) │          0 │ max_pooling1d[0]… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ conv1d_1 (Conv1D)   │ (None, 125, 512)  │  5,767,168 │ dropout[0][0]     │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ batch_normalizatio… │ (None, 125, 512)  │      2,048 │ conv1d_1[0][0]    │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ activation_1        │ (None, 125, 512)  │          0 │ batch_normalizat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ max_pooling1d_1     │ (None, 32, 512)   │          0 │ activation_1[0][… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dropout_1 (Dropout) │ (None, 32, 512)   │          0 │ max_pooling1d_1[… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ conv1d_2 (Conv1D)   │ (None, 32, 512)   │  2,883,584 │ dropout_1[0][0]   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ batch_normalizatio… │ (None, 32, 512)   │      2,048 │ conv1d_2[0][0]    │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ activation_2        │ (None, 32, 512)   │          0 │ batch_normalizat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ max_pooling1d_2     │ (None, 8, 512)    │          0 │ activation_2[0][… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dropout_2 (Dropout) │ (None, 8, 512)    │          0 │ max_pooling1d_2[… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ conv1d_3 (Conv1D)   │ (None, 8, 512)    │  1,310,720 │ dropout_2[0][0]   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ batch_normalizatio… │ (None, 8, 512)    │      2,048 │ conv1d_3[0][0]    │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ activation_3        │ (None, 8, 512)    │          0 │ batch_normalizat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add (Add)           │ (None, 8, 512)    │          0 │ activation_3[0][… │
│                     │                   │            │ dropout_2[0][0]   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ max_pooling1d_3     │ (None, 2, 512)    │          0 │ add[0][0]         │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dropout_3 (Dropout) │ (None, 2, 512)    │          0 │ max_pooling1d_3[… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ conv1d_4 (Conv1D)   │ (None, 2, 512)    │    524,288 │ dropout_3[0][0]   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ batch_normalizatio… │ (None, 2, 512)    │      2,048 │ conv1d_4[0][0]    │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ activation_4        │ (None, 2, 512)    │          0 │ batch_normalizat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_1 (Add)         │ (None, 2, 512)    │          0 │ activation_4[0][… │
│                     │                   │            │ dropout_3[0][0]   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ flatten (Flatten)   │ (None, 1024)      │          0 │ add_1[0][0]       │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dropout_4 (Dropout) │ (None, 1024)      │          0 │ flatten[0][0]     │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ denseblock_dense    │ (None, 1024)      │  1,048,576 │ dropout_4[0][0]   │
│ (Dense)             │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ denseblock_batchno… │ (None, 1024)      │      4,096 │ denseblock_dense… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ denseblock_activat… │ (None, 1024)      │          0 │ denseblock_batch… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ denseblock_dropout  │ (None, 1024)      │          0 │ denseblock_activ… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense (Dense)       │ (None, 80)        │     82,000 │ denseblock_dropo… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ activation_5        │ (None, 80)        │          0 │ dense[0][0]       │
│ (Activation)        │                   │            │                   │
└─────────────────────┴───────────────────┴────────────┴───────────────────┘
 Total params: 11,702,352 (44.64 MB)
 Trainable params: 11,694,160 (44.61 MB)
 Non-trainable params: 8,192 (32.00 KB)
None
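
As a sanity check, the parameter totals in the summary can be reproduced by hand. The kernel sizes and the absence of bias terms in the sketch below are not stated anywhere in this notebook; they are inferred from the reported parameter counts and should be read as assumptions.

```python
def conv1d_params(kernel_size, in_channels, filters):
    """Weights of a Conv1D layer without a bias term (assumed)."""
    return kernel_size * in_channels * filters


def dense_params(inputs, units, use_bias):
    """Weights of a Dense layer, optionally with one bias per unit."""
    return inputs * units + (units if use_bias else 0)


def batchnorm_params(channels):
    """Return (trainable, non_trainable) counts: gamma/beta and moving mean/variance."""
    return 2 * channels, 2 * channels


# (kernel_size, in_channels, filters); kernel sizes inferred from the summary
conv_layers = [
    (17, 4, 1024),
    (11, 1024, 512),
    (11, 512, 512),
    (5, 512, 512),
    (2, 512, 512),
]
# One batch normalization per conv layer plus the one in the dense block
batchnorm_channels = [1024, 512, 512, 512, 512, 1024]

trainable = sum(conv1d_params(*layer) for layer in conv_layers)
trainable += dense_params(1024, 1024, use_bias=False)  # denseblock_dense
trainable += dense_params(1024, 80, use_bias=True)  # output layer, one unit per topic
non_trainable = 0
for channels in batchnorm_channels:
    bn_trainable, bn_non_trainable = batchnorm_params(channels)
    trainable += bn_trainable
    non_trainable += bn_non_trainable

print(trainable, non_trainable, trainable + non_trainable)
```

Under these assumptions the counts add up to the totals reported in the summary above.
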
2024-08-14T11:48:17.282138+0200 INFO Loading sequences into memory...
2024-08-14T11:48:27.847601+0200 INFO Loading sequences into memory...
Epoch 1/100
  53/5532 ━━━━━━━━━━━━━━━━━━━━ 12:06 133ms/step - auPR: 0.0385 - auROC: 0.5086 - categorical_accuracy: 0.0105 - loss: 0.6002
2024-08-14T11:49:25.448633+0200 WARNING Training interrupted by user.
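
The number of steps per epoch in the progress bar follows from the training split size, the batch size and the reverse complement augmentation. A short sketch of that arithmetic, assuming that reverse complementing doubles the number of training examples (as the datamodule comment states) and that the final partial batch is kept:

```python
import math


def steps_per_epoch(n_regions, batch_size, reverse_complement=True):
    """Number of batches needed to see every training example once."""
    n_examples = n_regions * (2 if reverse_complement else 1)
    return math.ceil(n_examples / batch_size)


# 354013 training regions (from the split counts) and batch_size=128
train_steps = steps_per_epoch(354013, 128)
print(train_steps)
```

Lowering batch_size to avoid OOM errors increases the number of steps per epoch proportionally.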

Evaluation and Prediction

Evaluation and prediction work the same way as for peak regression.

The next steps you could take are to:

  1. Evaluate the model on the test set.

  2. Predict topic probabilities for a given sequence or region.

  3. Run tfmodisco to find motifs associated with each topic.

  4. Generate synthetic sequences for each topic using in silico evolution.

  5. Plot contribution scores per topic for interesting regions or sequences.
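
For the first step, it helps to know what the auROC metric in the config measures. The sketch below computes it exactly for a single topic as the fraction of (positive, negative) region pairs that the model ranks correctly. It is a didactic stand-in with toy numbers: Keras approximates the metric with a fixed set of thresholds, so reported values can differ slightly.

```python
def auroc(labels, scores):
    """Exact auROC: probability that a positive outranks a negative (ties count half)."""
    positives = [s for label, s in zip(labels, scores) if label == 1]
    negatives = [s for label, s in zip(labels, scores) if label == 0]
    if not positives or not negatives:
        raise ValueError("auROC needs at least one positive and one negative")
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(positives) * len(negatives))


# Toy test set for one topic: binary labels and predicted probabilities
labels = [1, 0, 1, 0, 0]
scores = [0.9, 0.2, 0.6, 0.7, 0.1]
topic_auroc = auroc(labels, scores)
print(topic_auroc)
```

A value of 0.5 corresponds to random ranking, which matches the auROC of about 0.51 seen in the first training steps above.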

Refer to the introductory notebook for more details.