fit()

The olympic.fit function is the heart of this library and where all the good stuff happens. Its aim is to save you from hand-rolling your own training loops and instead present a much cleaner interface in the style of Keras or scikit-learn.

The pseudocode for fit is very simple:

def fit(model, optimiser, loss_fn, epochs, dataloader, prepare_batch, metrics, callbacks, update_fn, update_fn_kwargs):

    callbacks.on_train_begin()

    for epoch in range(1, epochs+1):
        callbacks.on_epoch_begin(epoch)

        epoch_logs = dict()
        for batch_index, batch in enumerate(dataloader):
            batch_logs = dict(batch=batch_index)

            callbacks.on_batch_begin(batch_index, batch_logs)

            x, y = prepare_batch(batch)

            loss, y_pred = update_fn(model, optimiser, loss_fn, x, y, epoch, **update_fn_kwargs)
            batch_logs['loss'] = loss.item()

            # Loops through all metrics
            batch_logs = batch_metrics(model, y_pred, y, metrics, batch_logs)

            callbacks.on_batch_end(batch_index, batch_logs)

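        # epoch_logs is expected to be populated by callbacks
        # (e.g. with averaged metrics or validation results)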
        callbacks.on_epoch_end(epoch, epoch_logs)

    callbacks.on_train_end()

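The pseudocode above leans on two helpers, prepare_batch and batch_metrics. The sketch below only illustrates the contract they are assumed to follow (the dataloader yields (inputs, targets) pairs and metrics is a mapping of named callables); it is not the library's actual implementation:

import torch

def prepare_batch(batch, device='cpu'):
    # Assumed contract: the dataloader yields (inputs, targets) pairs,
    # which are moved onto the training device
    x, y = batch
    return x.to(device), y.to(device)

def batch_metrics(model, y_pred, y, metrics, batch_logs):
    # Assumed contract: metrics maps a metric name to a callable taking
    # (y_pred, y); each result is recorded in the batch logs
    with torch.no_grad():
        for name, metric_fn in metrics.items():
            batch_logs[name] = metric_fn(y_pred, y)
    return batch_logs
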
The default update_fn is just a regular gradient descent step (see below), but any callable with the right signature can be passed. Alternative ``update_fn`` implementations could be more involved, such as adversarial training or the Model-Agnostic Meta-Learning (MAML) algorithm. For an example, see fit/usage.
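
As a reference, a plain gradient-descent ``update_fn`` consistent with how it is called in the pseudocode might look like the sketch below (the name gradient_step is illustrative); it shows the expected signature and (loss, y_pred) return value rather than the library's exact implementation:

def gradient_step(model, optimiser, loss_fn, x, y, epoch, **kwargs):
    # Standard supervised update: forward pass, loss, backward pass, optimiser step
    model.train()
    optimiser.zero_grad()
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    loss.backward()
    optimiser.step()
    # fit expects (loss, y_pred) back so it can log the loss and
    # compute batch metrics on the predictions
    return loss, y_pred

Any replacement, such as an adversarial or MAML-style update, only needs to honour the same signature and return a (loss, y_pred) pair.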