crested.tl.losses.CosineMSELogLoss

class crested.tl.losses.CosineMSELogLoss(max_weight=1.0, name='CosineMSELogLoss', reduction='sum_over_batch_size', multiplier=1000)

Custom loss function combining logarithmic transformation, cosine similarity, and mean squared error (MSE).

This loss function applies a logarithmic transformation to predictions and true values, normalizes these values, and computes both MSE and cosine similarity. A dynamic weight based on the MSE is used to balance these two components.

Parameters:
  • max_weight (float (default: 1.0)) – The maximum weight applied to the cosine similarity loss component. Lower values will emphasize the MSE component, while higher values will emphasize the cosine similarity component.

  • name (str | None (default: 'CosineMSELogLoss')) – Name of the loss function.

  • reduction (str (default: 'sum_over_batch_size')) – Type of reduction to apply to the loss.

  • multiplier (float (default: 1000)) – Scalar by which the predicted value is multiplied. When predicting mean coverage, use 1000 to obtain actual counts. Keep at 1 when predicting insertion counts.

Notes

  • The log transformation is log(1 + multiplier * y) for non-negative values and -log(1 + abs(multiplier * y)) for negative values, where multiplier defaults to 1000.

  • The cosine similarity is computed between L2-normalized true and predicted values.

  • The dynamic weight for the cosine similarity component is constrained between 1.0 and max_weight.
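
The notes above can be sketched in plain Python. This is an illustrative reading, not the library's implementation. The helper names (signed_log, l2_normalize, cosine_mse_log_sketch) are hypothetical, and three details are assumptions not stated on this page: the cosine similarity is taken on the untransformed values, the weight is the MSE clamped to [1.0, max_weight], and the cosine term enters the total as a negative similarity.

```python
import math


def signed_log(y, multiplier=1000.0):
    """Signed log transform: log(1 + m*y) for y >= 0, -log(1 + |m*y|) for y < 0."""
    scaled = multiplier * y
    return math.log1p(scaled) if scaled >= 0 else -math.log1p(abs(scaled))


def l2_normalize(values, eps=1e-12):
    """Scale a vector to unit L2 norm (eps guards against an all-zero vector)."""
    norm = math.sqrt(sum(v * v for v in values))
    return [v / max(norm, eps) for v in values]


def cosine_mse_log_sketch(y_true, y_pred, max_weight=1.0, multiplier=1000.0):
    # MSE between the log-transformed true and predicted values.
    log_true = [signed_log(v, multiplier) for v in y_true]
    log_pred = [signed_log(v, multiplier) for v in y_pred]
    mse = sum((p - t) ** 2 for p, t in zip(log_pred, log_true)) / len(y_true)

    # ASSUMPTION: cosine similarity is computed on the untransformed values.
    cos = sum(t * p for t, p in zip(l2_normalize(y_true), l2_normalize(y_pred)))

    # ASSUMPTION: the dynamic weight is the MSE clamped to [1.0, max_weight].
    weight = min(max(mse, 1.0), max_weight)

    # ASSUMPTION: the cosine term is a negative similarity added to the MSE.
    return weight * (-cos) + mse


loss_value = cosine_mse_log_sketch([1.0, 0.0, -1.0], [1.2, -0.1, -0.9], max_weight=2.0)
```

With the default max_weight=1.0 the clamp leaves the weight fixed at 1.0, so under these assumptions the weight only becomes dynamic when max_weight is raised above 1.0.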

Examples

>>> import numpy as np
>>> from crested.tl.losses import CosineMSELogLoss
>>> loss = CosineMSELogLoss(max_weight=2.0)
>>> y_true = np.array([1.0, 0.0, -1.0])
>>> y_pred = np.array([1.2, -0.1, -0.9])
>>> loss(y_true, y_pred)

Attributes table

Methods table

  • call(y_true, y_pred) – Compute the loss value.

  • from_config(config) – Create a loss function from the configuration.

  • get_config() – Return the configuration of the loss function.

Attributes

CosineMSELogLoss.dtype

Methods

CosineMSELogLoss.call(y_true, y_pred)

Compute the loss value.

classmethod CosineMSELogLoss.from_config(config)

Create a loss function from the configuration.

CosineMSELogLoss.get_config()

Return the configuration of the loss function.
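
The get_config / from_config pair is typically used to serialize a loss and rebuild it later. The round trip can be illustrated with a minimal stand-in class. SketchLoss is hypothetical and only mirrors the constructor signature documented above. The exact keys returned by the real get_config() are not listed on this page, so the dictionary below is an assumption.

```python
class SketchLoss:
    """Hypothetical stand-in mirroring the documented constructor; not the library class."""

    def __init__(self, max_weight=1.0, name="CosineMSELogLoss",
                 reduction="sum_over_batch_size", multiplier=1000):
        self.max_weight = max_weight
        self.name = name
        self.reduction = reduction
        self.multiplier = multiplier

    def get_config(self):
        # ASSUMPTION: the configuration holds exactly the constructor arguments.
        return {
            "max_weight": self.max_weight,
            "name": self.name,
            "reduction": self.reduction,
            "multiplier": self.multiplier,
        }

    @classmethod
    def from_config(cls, config):
        # Rebuild an equivalent instance from a configuration dictionary.
        return cls(**config)


original = SketchLoss(max_weight=2.0, multiplier=1)
restored = SketchLoss.from_config(original.get_config())
```

Under these assumptions, restored carries the same max_weight and multiplier as original, which is the property that lets a configured loss be saved and reloaded.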