crested.tl.zoo.utils.MultiheadAttention#

class crested.tl.zoo.utils.MultiheadAttention(value_size, key_size, heads, scaling=True, attention_dropout_rate=0.0, relative_position_symmetric=False, relative_position_functions='borzoi', relative_position_absolute=False, num_position_features=None, absolute_positions=False, positional_dropout_rate=0.0, zero_initialize=True, initializer='he_normal', l2_scale=0.0, name='mhsa', **kwargs)#

Creates a MultiheadAttention module.

Adapted from Baskerville’s MultiheadAttention, original version written by Ziga Avsec.

Parameters:
  • value_size – The size of each value embedding per head.

  • key_size – The size of each key and query embedding per head.

  • heads – The number of independent queries per timestep.

  • scaling (bool (default: True)) – Whether to scale the attention logits.

  • attention_dropout_rate (float (default: 0.0)) – Dropout rate for attention logits.

  • relative_position_symmetric (bool (default: False)) – If True, only the symmetric version of the basis functions is used. If False, both the symmetric and asymmetric versions are used.

  • relative_position_functions (str (default: 'borzoi')) – Relative position functions to use. Can be 'enformer' or 'borzoi'. The Enformer default, 'enformer', uses the exponential, central mask (scaling factor 2), and gamma functions. The Borzoi default, 'borzoi', uses the central mask only, with a scaling factor that depends on the sequence length.

  • relative_position_absolute (bool (default: False)) – Whether to take absolute values before calculating the relative position encoding.

  • num_position_features (int | None (default: None)) – Number of relative positional features to compute. If None, value_size * heads is used.

  • absolute_positions (bool (default: False)) – If True, use absolute positional encodings instead of relative positional encodings. The Borzoi default is False.

  • positional_dropout_rate (float (default: 0.0)) – Dropout rate for the positional encodings if relative positions are used.

  • zero_initialize (bool (default: True)) – If True, the final linear layer is zero-initialized.

  • initializer (str (default: 'he_normal')) – Initializer for the projection layers. If unspecified, VarianceScaling is used with scale = 2.0.
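The relative-position options above can be illustrated with a small NumPy sketch of a central-mask basis. This is not the CREsted implementation: the function name central_mask_features is hypothetical, and the Borzoi-style width formula (a base chosen so that the widest mask spans the whole sequence) is an assumption based on the "scaling factor depending on length" description. The last lines show one plausible reading of relative_position_symmetric=False, where a sign-weighted copy of the features is concatenated to the symmetric ones.

```python
import numpy as np


def central_mask_features(positions, feature_size, seq_length=None):
    """Binary central-mask basis: feature i is 1 where |position| < width_i.

    Hypothetical helper for illustration only.
    """
    exponents = np.arange(1, feature_size + 1, dtype=np.float64)
    if seq_length is None:
        # Enformer-style: fixed scaling factor 2, widths 1, 3, 7, 15, ...
        base = 2.0
    else:
        # Borzoi-style (assumed): base depends on the sequence length
        base = np.exp(np.log(seq_length + 1) / feature_size)
    widths = base**exponents - 1.0
    return (widths > np.abs(positions)[..., np.newaxis]).astype(np.float32)


seq_length = 8
positions = np.arange(-seq_length + 1, seq_length)  # relative offsets -7..7

enformer_like = central_mask_features(positions, feature_size=4)
borzoi_like = central_mask_features(positions, feature_size=4, seq_length=seq_length)

# Asymmetric variant (assumed): weight the symmetric features by the sign
# of the offset and concatenate, doubling the feature dimension.
signed = np.sign(positions)[..., np.newaxis] * enformer_like
both = np.concatenate([enformer_like, signed], axis=-1)

print(enformer_like.shape, borzoi_like.shape, both.shape)
```

With the fixed base, the number of features determines the largest offset the masks can distinguish; with the length-dependent base, the same number of features is spread across the full sequence.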

Attributes table#

Methods table#

build(input_shape) – Build layer weights.

call(inputs[, training]) – Calculate the multihead attention for a given input.

get_config() – Get the config for the MultiheadAttention layer.

Attributes#

Methods#

MultiheadAttention.build(input_shape)#

Build layer weights.

MultiheadAttention.call(inputs, training=False)#

Calculate the multihead attention for a given input.
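As a rough illustration of what a multi-head attention call computes, the NumPy sketch below implements only the content-based part: per-head query, key, and value projections, optional logit scaling by key_size ** -0.5, softmax, and a final output projection. It deliberately omits the relative positional terms and dropout, and the function and weight names (multihead_attention, w_q, w_k, w_v, w_out) are hypothetical rather than taken from the layer. It also shows why a zero-initialized final projection makes the output start at exactly zero.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def multihead_attention(x, w_q, w_k, w_v, w_out, scaling=True):
    """Content-only multi-head attention for one sequence.

    x:        (seq_len, channels)
    w_q, w_k: (heads, channels, key_size)
    w_v:      (heads, channels, value_size)
    w_out:    (heads * value_size, out_channels)
    """
    q = np.einsum("tc,hcd->htd", x, w_q)
    k = np.einsum("tc,hcd->htd", x, w_k)
    v = np.einsum("tc,hcd->htd", x, w_v)
    if scaling:
        q = q * w_q.shape[-1] ** -0.5  # scale by 1 / sqrt(key_size)
    logits = np.einsum("hqd,hkd->hqk", q, k)
    weights = softmax(logits, axis=-1)  # each query row sums to 1
    per_head = np.einsum("hqk,hkd->hqd", weights, v)
    # Concatenate heads: (heads, seq, value) -> (seq, heads * value)
    concat = per_head.transpose(1, 0, 2).reshape(x.shape[0], -1)
    return concat @ w_out, weights


rng = np.random.default_rng(0)
seq_len, channels, heads, key_size, value_size = 6, 16, 4, 8, 12
x = rng.normal(size=(seq_len, channels))
w_q = rng.normal(size=(heads, channels, key_size))
w_k = rng.normal(size=(heads, channels, key_size))
w_v = rng.normal(size=(heads, channels, value_size))
w_out = rng.normal(size=(heads * value_size, channels))

out, weights = multihead_attention(x, w_q, w_k, w_v, w_out)
zero_out, _ = multihead_attention(x, w_q, w_k, w_v, np.zeros_like(w_out))
print(out.shape, weights.shape)
```

When such a block sits inside a residual connection, a zero output at initialization means the block initially behaves like the identity.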

MultiheadAttention.get_config()#

Get the config for the MultiheadAttention layer.
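A minimal construction sketch, using only argument names from the signature documented on this page. The sizes chosen here (192, 64, 8) and the dropout rates are illustrative assumptions, not library defaults, and the import is guarded so that the snippet still runs when CREsted is not installed.

```python
# Keyword arguments follow the documented signature; the values are
# illustrative assumptions, not defaults of the library.
config = dict(
    value_size=192,
    key_size=64,
    heads=8,
    scaling=True,
    attention_dropout_rate=0.05,
    relative_position_functions="borzoi",
    relative_position_symmetric=False,
    positional_dropout_rate=0.01,
    zero_initialize=True,
    initializer="he_normal",
    name="mhsa",
)

try:
    from crested.tl.zoo.utils import MultiheadAttention
except ImportError:
    MultiheadAttention = None  # CREsted not available; skip instantiation

if MultiheadAttention is not None:
    layer = MultiheadAttention(**config)
    # Inspect which settings the layer reports back
    print(sorted(layer.get_config()))
```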