crested.tl.zoo.utils.MultiheadAttention
- class crested.tl.zoo.utils.MultiheadAttention(value_size, key_size, heads, scaling=True, attention_dropout_rate=0.0, relative_position_symmetric=False, relative_position_functions='borzoi', relative_position_absolute=False, num_position_features=None, absolute_positions=False, positional_dropout_rate=0.0, zero_initialize=True, initializer='he_normal', l2_scale=0.0, name='mhsa', **kwargs)
Creates a MultiheadAttention module.
Adapted from Baskerville’s MultiheadAttention, original version written by Ziga Avsec.
- Parameters:
value_size – The size of each value embedding per head.
key_size – The size of each key and query embedding per head.
heads – The number of independent queries per timestep.
scaling (bool (default: True)) – Whether to scale the attention logits.
attention_dropout_rate (float (default: 0.0)) – Dropout rate for attention logits.
relative_position_symmetric (bool (default: False)) – If True, only the symmetric version of the basis functions will be used. If False, both the symmetric and asymmetric versions will be used.
relative_position_functions (str (default: 'borzoi')) – Relative position functions to use. Can be ‘enformer’ or ‘borzoi’. Enformer default is ‘enformer’ (exponential & central_mask (scaling factor 2) & gamma). Borzoi default is ‘borzoi’ (central mask only (scaling factor depending on length)).
relative_position_absolute (bool (default: False)) – Whether to use the absolute values before calculating the relative position encoding.
num_position_features (int | None (default: None)) – Number of relative positional features to compute. If None, value_size * num_heads is used.
absolute_positions (bool (default: False)) – If True, use absolute positional encodings instead of relative positional encodings. Default in Borzoi is False.
positional_dropout_rate (float (default: 0.0)) – Dropout rate for the positional encodings if relative positions are used.
zero_initialize (bool (default: True)) – If True, the final linear layer will be zero-initialized.
initializer (str (default: 'he_normal')) – Initializer for the projection layers. If unspecified, VarianceScaling is used with scale = 2.0.
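The shape-related parameters (value_size, key_size, heads, scaling) can be illustrated with a minimal NumPy sketch of multi-head self-attention. This is not the layer's implementation: the function name multihead_attention_sketch, the random projection weights, and the scaling factor key_size ** -0.5 are assumptions made for illustration. Positional encodings, dropout, and the final linear layer are left out.

```python
import numpy as np


def multihead_attention_sketch(x, value_size, key_size, heads, scaling=True, seed=0):
    """Illustrative multi-head self-attention with random (untrained) weights."""
    rng = np.random.default_rng(seed)
    batch, seq_len, channels = x.shape

    # Hypothetical projection weights; a real layer would learn these.
    w_q = rng.normal(size=(channels, heads * key_size)) / np.sqrt(channels)
    w_k = rng.normal(size=(channels, heads * key_size)) / np.sqrt(channels)
    w_v = rng.normal(size=(channels, heads * value_size)) / np.sqrt(channels)

    def split_heads(t, size):
        # (batch, seq, heads * size) -> (batch, heads, seq, size)
        return t.reshape(batch, seq_len, heads, size).transpose(0, 2, 1, 3)

    q = split_heads(x @ w_q, key_size)    # queries: key_size per head
    k = split_heads(x @ w_k, key_size)    # keys: key_size per head
    v = split_heads(x @ w_v, value_size)  # values: value_size per head

    if scaling:
        # Assumed scaling: shrink queries so logits do not grow with key_size.
        q = q * key_size ** -0.5

    logits = q @ k.transpose(0, 1, 3, 2)  # (batch, heads, seq, seq)

    # Numerically stable softmax over the key axis.
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=-1, keepdims=True)

    out = weights @ v  # (batch, heads, seq, value_size)
    # Concatenate the heads back into one feature axis.
    out = out.transpose(0, 2, 1, 3).reshape(batch, seq_len, heads * value_size)
    return out, weights


x = np.random.default_rng(1).normal(size=(2, 16, 32))
out, weights = multihead_attention_sketch(x, value_size=8, key_size=4, heads=3)
print(out.shape)      # (2, 16, 24)
print(weights.shape)  # (2, 3, 16, 16)
```

In this sketch the concatenated head outputs have heads * value_size features, and each head holds its own (seq, seq) attention map. Whether the actual layer projects this back to a different width is not stated in this reference.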
Methods
- MultiheadAttention.build(input_shape)
Build layer weights.
- MultiheadAttention.call(inputs, training=False)
Calculate the multihead attention for a given input.
- MultiheadAttention.get_config()
Get the config for the MultiheadAttention layer.