crested.utils.permute_model



crested.utils.permute_model(model, new_input_shape)

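Add a permutation layer to the input of a model, swapping the expected input layout from (B, W, C) to (B, C, W) or vice versa. Here B is the batch size, W the width (e.g. sequence length) and C the number of channels.

Useful for converting a model from the TensorFlow channels-last convention to the PyTorch channels-first convention (e.g. to use it with tangermeme).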
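A minimal sketch of the idea, assuming Keras and a model with a single input that has two non-batch axes. A new `Input` with the new shape is followed by a `keras.layers.Permute((2, 1))` layer, which swaps the two non-batch axes before the tensor reaches the original model. `permute_model_sketch` is a hypothetical name used only for illustration. It is not necessarily how `crested.utils.permute_model` is implemented.

```python
import keras
import numpy as np


def permute_model_sketch(model, new_input_shape):
    """Hypothetical illustration (not crested's source): wrap `model` so it
    accepts inputs of `new_input_shape` (excluding the batch axis).

    The two non-batch axes of the new input are swapped before reaching
    `model`, so `model` must accept the reversed shape.
    """
    inputs = keras.layers.Input(shape=new_input_shape)
    # Permute indexes dimensions from 1 and excludes the batch axis,
    # so (2, 1) turns (B, C, W) into (B, W, C) and vice versa.
    permuted = keras.layers.Permute((2, 1))(inputs)
    outputs = model(permuted)
    return keras.models.Model(inputs=inputs, outputs=outputs)


# Usage: a (linear, i.e. identity) model expecting (B, 4, 500)
# is wrapped so that it accepts (B, 500, 4) instead.
base_inputs = keras.layers.Input(shape=(4, 500))
base_outputs = keras.layers.Activation("linear")(base_inputs)
base_model = keras.models.Model(inputs=base_inputs, outputs=base_outputs)
new_model = permute_model_sketch(base_model, (500, 4))

# Feed a batch in the new layout; the identity base model returns
# the data in its own layout, so the axes come back swapped.
x = np.arange(2 * 500 * 4, dtype="float32").reshape(2, 500, 4)
y = new_model.predict(x, verbose=0)
```

Because swapping two axes twice is the identity, wrapping the result again with the original input shape restores the original layout, which covers the "vice versa" case.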

Parameters:
  • model (Model) – The Keras model to add the permutation layer to.

  • new_input_shape (tuple[int, int]) – The new input shape of the model, excluding the batch dimension (e.g. (4, 500)).
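To choose `new_input_shape`, it can help to see both layouts side by side. The sketch below assumes a batch of one-hot encoded sequences with 4 channels, as the (4, 500) example suggests. The random data is purely illustrative. `new_input_shape` is then the per-sample shape of the data you plan to feed, i.e. the array shape without its batch axis.

```python
import numpy as np

# Hypothetical batch: 2 one-hot encoded sequences of length 500 with
# 4 channels, stored channels last as (B, W, C) = (2, 500, 4).
rng = np.random.default_rng(0)
indices = rng.integers(0, 4, size=(2, 500))
x_channels_last = np.eye(4, dtype=np.float32)[indices]

# The same batch in channels-first layout (B, C, W) = (2, 4, 500),
# obtained by swapping the last two axes.
x_channels_first = np.transpose(x_channels_last, (0, 2, 1))

# new_input_shape describes one sample, without the batch axis.
new_input_shape = x_channels_first.shape[1:]
```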

Return type:

Model

Returns:

The new model with the permutation layer added to the input.

Example

>>> import keras
>>> import crested
>>> inputs = keras.layers.Input(shape=(4, 500))
>>> model = keras.models.Model(inputs=inputs, outputs=inputs)
>>> # new_model accepts inputs of shape (batch, 500, 4)
>>> new_model = crested.utils.permute_model(model, (500, 4))