
For many machine learning practitioners, the training loop is a universally agreed-upon concept: numerous documentation pages and conference papers use the term without defining it. It is a helpful concept for beginners to get familiar with before diving down the rabbit holes of the many deep learning tools.

The field of machine learning has become vast, diverse and ever more open over the past 10 years. We have all kinds of open-source software, from XGBoost, LightGBM and Theano to TensorFlow, Keras and PyTorch, that simplifies various machine learning tasks. We have supervised learning, unsupervised learning, generative networks and reinforcement learning: choose your own pill. It can be dazzling for beginners. It doesn’t help that much of the popular software we use hides many details from beginners behind abstractions like classes, functions and callbacks.

But fear not: for machine learning beginners, there is one universal template. Once you understand it, it is straightforward to fit existing training programs into this template and start digging into how they implement the details. I call this the universal training loop.

An Extremely Simplified Interface

Many high-level frameworks provide an extremely simplified interface that looks like this:

func train(training_dataset: Dataset, validation_dataset: Dataset) -> Classifier

When you do:

let trainedClassifier = Classifier.train(training_dataset: training_dataset, validation_dataset: validation_dataset)

You somehow get a trained classifier back from that interface. This is what you would find on FastAI’s Quick Start page, or in Apple’s Create ML framework.

However, this doesn’t tell you much about what it does. It also doesn’t help that some of these frameworks provide callbacks or hooks into the training loop at various stages without showing the loop itself. The natural question is: what are the stages?

A Supervised Learning Training Loop

It is actually not hard to imagine what a supervised learning training loop would look like underneath the extremely simplified interface. It may look like this:

var classifier = Classifier()
for example in training_dataset {
  classifier.fit(input: example.data, target: example.target)
}

It goes through all examples in the training dataset and tries to fit them one by one.

For stochastic gradient descent methods, we make a few modifications to make the training process more stable and less sensitive to the input order:

var classifier = Classifier()
for minibatch in training_dataset.shuffled().grouped(by: 32) {
  classifier.fit(inputs: minibatch.data, targets: minibatch.targets)
}

We randomize the order of the inputs (shuffled()), group them into mini-batches of 32, and pass each mini-batch to the classifier, assuming the classifier can operate on a group of examples directly.

For many types of neural networks, shuffled mini-batches are an essential part of the training loop, for both efficiency and stability reasons.
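Swift’s standard library has shuffled(), but no method that cuts a collection into fixed-size groups, so the grouped(by:) step has to be written by hand. Here is a minimal sketch; chunked(into:) is a hypothetical helper name, and the integers 0..<10 stand in for real training examples.

```swift
extension Array {
    // Hypothetical helper: split the array into consecutive chunks of at
    // most `size` elements. The last chunk may be smaller.
    func chunked(into size: Int) -> [[Element]] {
        precondition(size > 0, "chunk size must be positive")
        return stride(from: 0, to: count, by: size).map { start in
            Array(self[start..<Swift.min(start + size, count)])
        }
    }
}

// Ten toy "examples" standing in for a training dataset.
let trainingDataset = Array(0..<10)

// Shuffle, then cut the shuffled order into mini-batches of 4.
let minibatches = trainingDataset.shuffled().chunked(into: 4)
for (i, minibatch) in minibatches.enumerated() {
    print("mini-batch \(i): \(minibatch)")
}
```

Shuffling again at the start of every pass gives each epoch different mini-batches. The last mini-batch is simply smaller when the dataset size is not a multiple of the batch size; some frameworks drop it instead.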

4 Steps to Fit

The magical fit function doesn’t inspire any deeper understanding of what’s going on. It looks like we simply lift the train function into a for-loop.

However, if we can assume that our machine learning model is differentiable and based on gradient methods (e.g. neural networks), we can break down the fit function into 4 steps.

var classifier = Classifier()
for minibatch in training_dataset.shuffled().grouped(by: 32) {
  // Given inputs, a machine learning model can guess what the
  // outputs would be. (Labels of the images, positions of the faces
  // or text translated from the original.)
  let guesses = classifier.apply(inputs: minibatch.data)
  // Loss measures how far our guesses are from the targets we know
  // from the training data. This is supervised learning: we know
  // the answers already.
  let loss = classifier.loss(guesses: guesses, targets: minibatch.targets)
  // Based on the loss, gradients give us the direction and magnitude
  // to update our model parameters.
  let gradients = loss.gradients()
  // Update the parameters with gradients from this mini-batch.
  // The optimizer specifies the particular algorithm we use to update
  // parameters, such as stochastic gradient descent or Adam.
  optimizer.apply_gradients(gradients, classifier.parameters)
}
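To make the four steps concrete without any framework, here is a small runnable sketch that fits a line y = w·x + b with mini-batch SGD. The hand-derived gradients function stands in for loss.gradients(), and plain SGD with an assumed learning rate of 0.1 stands in for the optimizer; the synthetic data and all names are illustrative, not taken from any library.

```swift
// Toy model: a line y = w * x + b with two parameters.
struct LinearModel {
    var w = 0.0
    var b = 0.0

    // Step 1: guess the outputs for a batch of inputs.
    func apply(_ inputs: [Double]) -> [Double] {
        inputs.map { w * $0 + b }
    }
}

// Step 2: mean squared error between guesses and targets.
func meanSquaredError(_ guesses: [Double], _ targets: [Double]) -> Double {
    let total = zip(guesses, targets).map { ($0 - $1) * ($0 - $1) }.reduce(0, +)
    return total / Double(guesses.count)
}

// Step 3: gradients of the mean squared error w.r.t. w and b, derived by hand.
func gradients(_ inputs: [Double], _ guesses: [Double], _ targets: [Double]) -> (dw: Double, db: Double) {
    let n = Double(inputs.count)
    var dw = 0.0, db = 0.0
    for i in inputs.indices {
        let residual = guesses[i] - targets[i]
        dw += 2 * residual * inputs[i] / n
        db += 2 * residual / n
    }
    return (dw: dw, db: db)
}

// Synthetic, noise-free data from y = 3x + 1 with x in [0, 1).
let xs = (0..<64).map { Double($0) / 64 }
let ys = xs.map { 3 * $0 + 1 }

var model = LinearModel()
let learningRate = 0.1  // assumed; plain SGD stands in for the optimizer
var lastLoss = Double.infinity
for _ in 0..<500 {
    let data = Array(zip(xs, ys)).shuffled()
    for start in stride(from: 0, to: data.count, by: 16) {
        let batch = data[start..<min(start + 16, data.count)]
        let inputs = batch.map { $0.0 }
        let targets = batch.map { $0.1 }
        let guesses = model.apply(inputs)                   // 1. guess
        lastLoss = meanSquaredError(guesses, targets)       // 2. loss
        let (dw, db) = gradients(inputs, guesses, targets)  // 3. gradients
        model.w -= learningRate * dw                        // 4. update
        model.b -= learningRate * db
    }
}
print("w = \(model.w), b = \(model.b), last mini-batch loss = \(lastLoss)")
```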

In any supervised learning setup, you will be able to find these 4 steps. There are variations: some models accumulate gradients over several mini-batches and only call apply_gradients after several rounds. Others clip the gradients before applying them.
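Both variations fit into the same loop. The sketch below assumes gradients are flat arrays of Doubles: it averages the gradients of accumulationSteps mini-batches, clips the average by its global L2 norm (one of several clipping schemes), and only then takes one plain SGD step. The gradientOf closure is a hypothetical stand-in for computing loss.gradients() on one mini-batch.

```swift
// Gradients as flat vectors, one entry per parameter.
typealias Gradient = [Double]

// Scale the whole gradient down if its L2 norm exceeds maxNorm.
func clipByGlobalNorm(_ g: Gradient, maxNorm: Double) -> Gradient {
    let norm = g.map { $0 * $0 }.reduce(0, +).squareRoot()
    guard norm > maxNorm else { return g }
    return g.map { $0 * maxNorm / norm }
}

// Accumulate `accumulationSteps` mini-batch gradients, then clip the
// average and apply one SGD step. Returns the number of updates made.
// Leftover mini-batches that do not fill a window are dropped.
func trainWithAccumulation(
    parameters: inout [Double],
    numMinibatches: Int,
    accumulationSteps: Int,
    learningRate: Double,
    maxNorm: Double,
    gradientOf: ([Double], Int) -> Gradient
) -> Int {
    var accumulated = Gradient(repeating: 0, count: parameters.count)
    var pending = 0
    var updates = 0
    for batch in 0..<numMinibatches {
        let g = gradientOf(parameters, batch)
        for i in accumulated.indices { accumulated[i] += g[i] }
        pending += 1
        guard pending == accumulationSteps else { continue }
        let averaged = accumulated.map { $0 / Double(accumulationSteps) }
        let clipped = clipByGlobalNorm(averaged, maxNorm: maxNorm)
        for i in parameters.indices { parameters[i] -= learningRate * clipped[i] }
        accumulated = Gradient(repeating: 0, count: parameters.count)
        pending = 0
        updates += 1
    }
    return updates
}

// Hypothetical constant gradient [3, 4] (L2 norm 5) for every mini-batch.
var params = [0.0, 0.0]
let updates = trainWithAccumulation(
    parameters: &params, numMinibatches: 8, accumulationSteps: 4,
    learningRate: 0.1, maxNorm: 1.0
) { _, _ in [3.0, 4.0] }
print(updates, params)
```

With maxNorm 1, each of the two updates moves the parameters by 0.1 × [0.6, 0.8], so eight mini-batches produce only two optimizer steps.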

You can find these exact 4 steps in frameworks like Keras or PyTorch.

Validation Dataset and Epoch

We haven’t talked about the validation_dataset parameter you saw earlier for the train method!

For first-order gradient-based methods (e.g. neural networks), we need to go over the whole training dataset multiple times to reach a local minimum (a reasonable model). One pass over the whole training dataset is called an epoch. Our models can also suffer from overfitting. The validation dataset is data the model never uses when updating its parameters, so it tells us how our model behaves on data it has never seen.

To incorporate the above two insights, our training loop can be further modified to:

var classifier = Classifier()
for epoch in 0..<max_epoch {
  for minibatch in training_dataset.shuffled().grouped(by: 32) {
    let guesses = classifier.apply(inputs: minibatch.data)
    let loss = classifier.loss(guesses: guesses, targets: minibatch.targets)
    let gradients = loss.gradients()
    optimizer.apply_gradients(gradients, classifier.parameters)
  }
  var stats = Stats()
  for example in validation_dataset {
    // Only gather guesses, never update the parameters.
    let guess = classifier.apply(input: example.data)
    // Stats will compare the guess to the target, and return some
    // helpful statistics.
    stats.accumulate(guess: guess, target: example.target)
  }
  print("Epoch \(epoch), validation dataset stats: \(stats)")
}

Now I can claim that, for any supervised learning task, you will find the above training loop if you dig deep enough through its abstractions. We can call this the universal supervised training loop.
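Here is a toy, runnable instance of that loop. It assumes a linear model with two parameters, w and b, and a hypothetical Stats type that tracks mean squared error; every fourth example is held out as the validation dataset, and the validation pass only reads w and b, never updates them.

```swift
// Hypothetical validation statistics: running mean squared error over
// (guess, target) pairs, mirroring stats.accumulate(guess:target:).
struct Stats: CustomStringConvertible {
    var sumSquaredError = 0.0
    var count = 0

    mutating func accumulate(guess: Double, target: Double) {
        sumSquaredError += (guess - target) * (guess - target)
        count += 1
    }

    var meanSquaredError: Double { count == 0 ? 0 : sumSquaredError / Double(count) }
    var description: String { "MSE \(meanSquaredError) over \(count) examples" }
}

// Toy, noise-free data from y = 2x - 1; every 4th example is held out.
let examples = (0..<80).map { i -> (x: Double, y: Double) in
    let x = Double(i) / 80
    return (x: x, y: 2 * x - 1)
}
let validationDataset = examples.enumerated().filter { $0.offset % 4 == 0 }.map { $0.element }
let trainingDataset = examples.enumerated().filter { $0.offset % 4 != 0 }.map { $0.element }

var w = 0.0, b = 0.0
let learningRate = 0.1  // assumed
let maxEpoch = 300
var validationHistory: [Double] = []

for epoch in 0..<maxEpoch {
    // One epoch: a full pass over the shuffled training data, 10 at a time.
    let shuffled = trainingDataset.shuffled()
    for start in stride(from: 0, to: shuffled.count, by: 10) {
        let batch = shuffled[start..<min(start + 10, shuffled.count)]
        var dw = 0.0, db = 0.0
        for (x, y) in batch {
            let residual = (w * x + b) - y
            dw += 2 * residual * x / Double(batch.count)
            db += 2 * residual / Double(batch.count)
        }
        w -= learningRate * dw
        b -= learningRate * db
    }
    // Validation: only gather guesses, never update w or b.
    var stats = Stats()
    for (x, y) in validationDataset {
        stats.accumulate(guess: w * x + b, target: y)
    }
    validationHistory.append(stats.meanSquaredError)
    if epoch % 100 == 0 {
        print("Epoch \(epoch), validation dataset stats: \(stats)")
    }
}
```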

Unsupervised Learning and Generative Networks

The main difference between unsupervised learning and supervised learning, as far as our training loop is concerned, is that the targets are not provided by the training dataset. We derive them from somewhere else. In unsupervised learning, we derive the targets from some transformation of the inputs. In generative networks, we derive the targets from random noise (hence generating something from nothing).

var model = Model()
for epoch in 0..<max_epoch {
  for minibatch in training_dataset.shuffled().grouped(by: 32) {
    let guesses = model.apply(inputs: minibatch.data)
    // Unsupervised learning.
    let targets = model.targets(from: minibatch.data)
    // Generative networks
    // let targets = model.targets(from: noise)
    let loss = model.loss(guesses: guesses, targets: targets)
    let gradients = loss.gradients()
    optimizer.apply_gradients(gradients, model.parameters)
  }
}

Oftentimes, for these types of tasks, the targets are derived from another set of neural networks that is updated jointly. Because of that, many frameworks add more bells and whistles when they implement the above training loop. You can find an example from Keras of how they derive targets from the input data alone (get_masked_input_and_labels) for BERT (a popular unsupervised natural language processing model). Or you can find an example from PyTorch of how they generate adversarial examples from noise for DCGAN (deep convolutional generative adversarial network).
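As an illustration of deriving targets from the inputs alone, here is a simplified masking sketch in the spirit of BERT-style pretraining. It is not Keras’s get_masked_input_and_labels: it masks each token independently with a given probability and omits BERT’s rule of sometimes keeping or randomizing the chosen tokens instead of masking them. The token ids, maskTokenID and ignoreTarget are hypothetical, and a small seeded SplitMix64 generator makes runs reproducible.

```swift
// A small seeded generator (SplitMix64) so runs are reproducible.
struct SplitMix64: RandomNumberGenerator {
    var state: UInt64
    mutating func next() -> UInt64 {
        state &+= 0x9E3779B97F4A7C15
        var z = state
        z = (z ^ (z >> 30)) &* 0xBF58476D1CE4E5B9
        z = (z ^ (z >> 27)) &* 0x94D049BB133111EB
        return z ^ (z >> 31)
    }
}

let maskTokenID = 0    // hypothetical id of the [MASK] token
let ignoreTarget = -1  // hypothetical marker: no loss at this position

// Derive (inputs, targets) from unlabeled tokens alone: masked positions
// become [MASK] in the inputs and keep the original token as the target.
func maskedInputsAndTargets<G: RandomNumberGenerator>(
    tokens: [Int], maskProbability: Double, using rng: inout G
) -> (inputs: [Int], targets: [Int]) {
    var inputs = tokens
    var targets = [Int](repeating: ignoreTarget, count: tokens.count)
    for i in tokens.indices {
        if Double.random(in: 0..<1, using: &rng) < maskProbability {
            targets[i] = tokens[i]
            inputs[i] = maskTokenID
        }
    }
    return (inputs: inputs, targets: targets)
}

var rng = SplitMix64(state: 42)
let sentence = [17, 42, 8, 99, 23, 5, 61, 12, 77, 30]
let (inputs, targets) = maskedInputsAndTargets(tokens: sentence, maskProbability: 0.15, using: &rng)
print("inputs:  \(inputs)")
print("targets: \(targets)")
```

The loss is then computed only at positions whose target is not ignoreTarget, so the model learns to reconstruct tokens from their context.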

Deep Reinforcement Learning

Deep reinforcement learning generates the training data by having an agent interact with the environment. It has its own loop that looks like this:

var lastObservation = environment.reset()
while true {
  let action = agent.respond(to: lastObservation)
  let (observation, reward, done) = environment.interact(action: action)
  lastObservation = observation
  if done {
    break
  }
}

The agent takes an action based on the last observation. The environment reacts to the action and returns a new observation, a reward, and whether the episode is done.

This loop is independent of our training loop. In contrast to supervised learning, in deep reinforcement learning we use the interaction loop to generate the training data.
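To see the interaction loop run, here is a minimal sketch with a hypothetical environment: a 1-D corridor where the agent starts at position 0 and the episode ends at the goal position 4 (reward 1) or after 20 steps. The agent is a stand-in that always steps right; a real agent would be a policy network. Each step is recorded, which is the raw material the training loop consumes.

```swift
// Hypothetical environment: a 1-D corridor of positions 0, 1, 2, ...
struct Corridor {
    let goal = 4
    let maxSteps = 20
    var position = 0
    var steps = 0

    mutating func reset() -> Int {
        position = 0
        steps = 0
        return position
    }

    // Move by `action` (-1 or +1) and report (observation, reward, done).
    mutating func interact(action: Int) -> (observation: Int, reward: Double, done: Bool) {
        position = max(0, position + action)
        steps += 1
        let reachedGoal = position == goal
        return (observation: position, reward: reachedGoal ? 1.0 : 0.0,
                done: reachedGoal || steps >= maxSteps)
    }
}

// Stand-in agent: always step toward the goal.
func respond(to observation: Int) -> Int {
    return 1
}

var environment = Corridor()
var lastObservation = environment.reset()
var trajectory: [(observation: Int, action: Int, reward: Double)] = []
while true {
    let action = respond(to: lastObservation)
    let (observation, reward, done) = environment.interact(action: action)
    trajectory.append((observation: lastObservation, action: action, reward: reward))
    lastObservation = observation
    if done {
        break
    }
}
print(trajectory)
```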

Our training loop can be modified to look like this:

var policy = Policy()
var training_dataset = Dataset()
for epoch in 0..<max_epoch {
  var data_in_episode = Dataset()
  var lastObservation = environment.reset()
  while true {
    let action = policy.apply(inputs: lastObservation)
    let (observation, reward, done) = environment.interact(action: action)
    data_in_episode.append((action: action, reward: reward, observation: lastObservation))
    lastObservation = observation
    if done {
      for (i, data) in data_in_episode.enumerated() {
        // Use all future rewards to compute our target.
        let target = target_from_future_rewards(data_in_episode[i..<])
        // Our input will be the last observation (Q-learning), and
        // potentially also include the action (Actor-Critic model),
        // or the next observation (model-based methods).
        training_dataset.append((input: (data.action, data.observation), target: target))
      }
      break
    }
  }
  // Rather than shuffling the whole training dataset, we just randomly
  // sample a subset.
  for minibatch in training_dataset.randomSampled().grouped(by: 32) {
    let guesses = policy.apply(inputs: minibatch.data)
    let loss = policy.loss(guesses: guesses, targets: minibatch.targets)
    let gradients = loss.gradients()
    optimizer.apply_gradients(gradients, policy.parameters)
  }
}
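target_from_future_rewards is left abstract in the loop. One common choice is the discounted reward-to-go, G_i = r_i + γ·G_{i+1}, which can be computed in a single backward pass over the episode. The sketch below assumes that choice; value-based methods such as Q-learning often bootstrap from a learned value estimate instead of summing over the whole episode.

```swift
// Discounted reward-to-go: for each step i, the sum over k >= i of
// gamma^(k - i) * r_k, computed backwards as G_i = r_i + gamma * G_{i+1}.
func rewardsToGo(_ rewards: [Double], gamma: Double) -> [Double] {
    var targets = [Double](repeating: 0, count: rewards.count)
    var runningReturn = 0.0
    for i in rewards.indices.reversed() {
        runningReturn = rewards[i] + gamma * runningReturn
        targets[i] = runningReturn
    }
    return targets
}

// An episode whose only reward arrives at the final step.
let targets = rewardsToGo([0, 0, 0, 1], gamma: 0.9)
print(targets)
```

Earlier steps receive a smaller share of the final reward, which is how the discount factor γ expresses a preference for rewards that arrive sooner.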

The training_dataset in the training loop above is referred to as the replay memory in the literature. If we retain the old data from previous episodes for training, this is often referred to as off-policy training. If we instead remove all training data after each episode, this can be called on-policy training. OpenAI Spinning Up has a better and more detailed explanation of the differences between on-policy and off-policy training.
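In practice a replay memory is usually bounded. Here is a minimal sketch, assuming a fixed capacity where the oldest entries are overwritten first and mini-batches are sampled uniformly with replacement; ReplayMemory is a hypothetical type, not an API from PyTorch or OpenAI Baselines. Keeping entries across episodes corresponds to off-policy training, while calling removeAll() after each episode corresponds to on-policy training.

```swift
struct ReplayMemory<Element> {
    let capacity: Int
    private(set) var storage: [Element] = []
    private var nextIndex = 0  // slot the next append writes to

    init(capacity: Int) {
        precondition(capacity > 0, "capacity must be positive")
        self.capacity = capacity
    }

    mutating func append(_ element: Element) {
        if storage.count < capacity {
            storage.append(element)
        } else {
            storage[nextIndex] = element  // overwrite the oldest entry
        }
        nextIndex = (nextIndex + 1) % capacity
    }

    // Uniform sampling with replacement; empty while nothing is stored.
    func sample<G: RandomNumberGenerator>(batchSize: Int, using rng: inout G) -> [Element] {
        guard !storage.isEmpty else { return [] }
        return (0..<batchSize).map { _ in storage.randomElement(using: &rng)! }
    }

    // On-policy training discards everything after each episode.
    mutating func removeAll() {
        storage.removeAll()
        nextIndex = 0
    }
}

var memory = ReplayMemory<Int>(capacity: 3)
for step in 1...5 {
    memory.append(step)
}
print(memory.storage)  // [4, 5, 3]
var rng = SystemRandomNumberGenerator()
let batch = memory.sample(batchSize: 8, using: &rng)
```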

You can find examples of the above training loops in PyTorch or in OpenAI Baselines.

Distributed Learning

Following the same template, we can extend the training loop to train on multiple machines:

var classifier = Classifier()
let machineID = MPI_Comm_rank()
for epoch in 0..<max_epoch {
  for minibatch in training_dataset.shuffled().grouped(by: 32, on: machineID) {
    let guesses = classifier.apply(inputs: minibatch.data)
    let loss = classifier.loss(guesses: guesses, targets: minibatch.targets)
    // Compute gradients from the data on this machine only.
    let machineGradients = loss.gradients()
    // Use the allreduce primitive to sum the gradients from all
    // machines.
    let allGradients = allreduce(op: +, value: machineGradients, on: machineID)
    // Every machine applies the same summed gradients, so the
    // parameters stay the same on all machines.
    optimizer.apply_gradients(allGradients, classifier.parameters)
  }
}
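The claim that every machine ends up with the same parameters can be checked in a single process. The sketch below simulates four hypothetical machines, each holding a disjoint shard of the data and its own copy of a linear model; allreduceSum is an in-process stand-in for allreduce(op: +). Because the loss here is a sum of squared errors, the summed shard gradients equal the full-batch gradient (with a mean loss, one would also divide by the number of machines).

```swift
// One machine's replica of a linear model y = w * x + b, plus its data shard.
struct Replica {
    var w = 0.0
    var b = 0.0
    let shard: [(x: Double, y: Double)]

    // Gradients of the summed squared error over this shard only.
    func localGradients() -> (dw: Double, db: Double) {
        var dw = 0.0, db = 0.0
        for (x, y) in shard {
            let residual = w * x + b - y
            dw += 2 * residual * x
            db += 2 * residual
        }
        return (dw: dw, db: db)
    }
}

// In-process stand-in for allreduce(op: +): the sum over all machines.
func allreduceSum(_ values: [(dw: Double, db: Double)]) -> (dw: Double, db: Double) {
    values.reduce((dw: 0.0, db: 0.0)) { acc, g in (dw: acc.dw + g.dw, db: acc.db + g.db) }
}

let data = (0..<32).map { i -> (x: Double, y: Double) in
    let x = Double(i) / 32
    return (x: x, y: 3 * x + 1)
}
let worldSize = 4
// Machine `rank` gets every 4th example, starting at index `rank`.
var replicas = (0..<worldSize).map { rank in
    Replica(shard: data.enumerated().filter { $0.offset % worldSize == rank }.map { $0.element })
}

let learningRate = 0.01  // assumed
for _ in 0..<1000 {
    let summed = allreduceSum(replicas.map { $0.localGradients() })
    // Same summed gradients on every machine, so the replicas stay identical.
    for k in replicas.indices {
        replicas[k].w -= learningRate * summed.dw
        replicas[k].b -= learningRate * summed.db
    }
}
print(replicas.map { ($0.w, $0.b) })
```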

The allreduce primitive goes over all machines and sums their gradients. In practice, it is often implemented with a ring-based communication pattern to maximize throughput. Naively, it can look like this:

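func allreduce(op: Op, value: Tensor, on: Int) -> Tensor {
  if on == 0 {
    // Machine 0 gathers the values from all other machines...
    var tensors = [value]
    for i in 1..<MPI_Comm_size() {
      tensors.append(MPI_Recv(from: i))
    }
    // ...sums them (op is + throughout this article)...
    let sum = tensors.sum()
    // ...and sends the sum back to every other machine.
    for i in 1..<MPI_Comm_size() {
       MPI_Send(sum, to: i)
    }
    return sum
  } else {
    // Every other machine sends its value to machine 0 and waits for
    // the sum to come back.
    MPI_Send(value, to: 0)
    return MPI_Recv(from: 0)
  }
}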
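The ring-based pattern can be simulated without MPI. In this sketch each of the n hypothetical workers splits its vector into n chunks. A reduce-scatter phase of n - 1 steps leaves each worker holding the complete sum of one chunk, and an allgather phase of another n - 1 steps passes those finished chunks around the ring. In every step, data only flows from worker r to worker (r + 1) % n.

```swift
// Simulated ring allreduce (sum) across buffers.count hypothetical workers.
func ringAllreduce(_ buffers: [[Double]]) -> [[Double]] {
    let n = buffers.count
    guard n > 1 else { return buffers }
    let length = buffers[0].count
    precondition(buffers.allSatisfy { $0.count == length }, "buffers must match")
    var data = buffers

    // Index range of chunk c (taken modulo n); sizes differ by at most one.
    func chunk(_ c: Int) -> Range<Int> {
        let k = ((c % n) + n) % n
        return (k * length / n)..<((k + 1) * length / n)
    }

    // One ring step: worker r sends chunk chunkOf(r) to worker (r + 1) % n.
    // Messages are collected first because all workers send simultaneously.
    func ringStep(chunkOf: (Int) -> Int, combine: (Double, Double) -> Double) {
        let messages = (0..<n).map { r -> (to: Int, range: Range<Int>, values: [Double]) in
            let range = chunk(chunkOf(r))
            return (to: (r + 1) % n, range: range, values: Array(data[r][range]))
        }
        for message in messages {
            for (offset, i) in message.range.enumerated() {
                data[message.to][i] = combine(data[message.to][i], message.values[offset])
            }
        }
    }

    // Phase 1, reduce-scatter: afterwards worker r owns the full sum of chunk r + 1.
    for step in 0..<(n - 1) {
        ringStep(chunkOf: { r in r - step }, combine: +)
    }
    // Phase 2, allgather: the finished chunks travel once around the ring.
    for step in 0..<(n - 1) {
        ringStep(chunkOf: { r in r + 1 - step }, combine: { _, incoming in incoming })
    }
    return data
}

let workers: [[Double]] = [
    [1, 2, 3, 4, 5],
    [10, 20, 30, 40, 50],
    [100, 200, 300, 400, 500],
]
let reduced = ringAllreduce(workers)
print(reduced)  // every row is [111.0, 222.0, 333.0, 444.0, 555.0]
```

Each worker sends about 2(n - 1)/n of its vector in total, independent of n, whereas in the naive version above machine 0 alone receives and sends n - 1 full vectors. That difference is why the ring pattern scales better.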

This naive data-parallel training loop can be extended to more sophisticated distributed training regimes. For example, the distributed strategy of ZeRO-Offload can be represented as:

var classifier = Classifier()
let machineID = MPI_Comm_rank()
for epoch in 0..<max_epoch {
  for minibatch in training_dataset.shuffled().grouped(by: 32, on: machineID) {
    let guesses = classifier.apply(inputs: minibatch.data)
    let loss = classifier.loss(guesses: guesses, targets: minibatch.targets)
    let gradients = loss.gradients()
    for (i, gradient) in gradients.enumerated() {
      // Each machine only sums the gradients it is responsible for.
      // This method returns nil for a gradient this machine is not
      // responsible for.
      if let reducedGradient = reduce(op: +, id: i, value: gradient, on: machineID) {
        // Copy the summed gradient to the CPU.
        cpuGradients[machineID, i] = reducedGradient.copy(to: .CPU)
      }
    }
    // Apply the gradients to this machine's share of the model, on the CPU.
    optimizer.apply_gradients(cpuGradients[machineID], classifier.parameters[machineID])
    // Broadcast the latest parameters to all machines.
    broadcast(classifier.parameters[machineID])
  }
}
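The ownership idea can also be simulated in a single process. In this hypothetical sketch there are two machines and four parameters: parameter i is owned by machine i % 2, reduceOnOwner mirrors the pseudo-code’s reduce by returning nil on non-owners, a dictionary plays the role of the CPU copies, and plain SGD stands in for the Adam step that ZeRO-Offload runs on the CPU. It is a simplification of the partitioning idea, not ZeRO-Offload’s actual implementation.

```swift
let worldSize = 2

// Mirrors reduce(op: +, id:value:on:): the gradient for parameter `id`
// summed over all machines, or nil on machines that do not own it.
func reduceOnOwner(id: Int, gradients: [[Double]], on machine: Int) -> Double? {
    guard id % worldSize == machine else { return nil }
    return gradients.reduce(0) { sum, machineGradient in sum + machineGradient[id] }
}

// Every machine starts with the same parameters (its "GPU" copy).
var gpuParameters = Array(repeating: [1.0, 2.0, 3.0, 4.0], count: worldSize)
let machineGradients: [[Double]] = [
    [0.5, 0.5, 0.5, 0.5],   // hypothetical gradients from machine 0's mini-batch
    [1.5, -0.5, 0.5, 2.5],  // hypothetical gradients from machine 1's mini-batch
]
let learningRate = 0.1

// Machines run in parallel in reality; here we visit them one by one.
for machine in 0..<worldSize {
    // "CPU" copies of only the parameters this machine owns.
    var cpuParameters: [Int: Double] = [:]
    for i in gpuParameters[machine].indices {
        if let reduced = reduceOnOwner(id: i, gradients: machineGradients, on: machine) {
            // Plain SGD stands in for the CPU-side Adam step.
            cpuParameters[i] = gpuParameters[machine][i] - learningRate * reduced
        }
    }
    // Broadcast the owned, updated parameters to every machine.
    for (i, value) in cpuParameters {
        for other in 0..<worldSize { gpuParameters[other][i] = value }
    }
}
print(gpuParameters)
```

Each machine only keeps optimizer work for its own slice of the parameters, yet after the broadcast every machine holds the full, updated model.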

The Universal Training Loop

Finally, we arrive at our universal training loop template:

var model = Model()
var training_dataset = Dataset()
for epoch in 0..<max_epoch {
  // Collect the training dataset either from agent-environment
  // interaction, or from the disk.
  training_dataset = collect(...)
  // Go over mini-batches either from the whole training dataset, or
  // from a subset of it.
  for minibatch in training_dataset.extract(...) {
    // Apply the model to generate some guesses.
    let guesses = model.apply(inputs: minibatch.data)
    // Generate targets either from the inputs, from noise, or take them
    // from the training dataset where they already exist. Or a
    // combination of the above.
    let targets = targets_from(...)
    // Compute the loss from the model's guesses with respect to the
    // targets.
    let loss = model.loss(guesses: guesses, targets: targets)
    // First-order gradients from the loss.
    let gradients = loss.gradients()
    // Sum gradients from everywhere (other machines, other GPUs) for
    // this particular node to process.
    if let (i, collected) = gradients.collect(...) {
      optimizer.apply_gradients(collected, model.parameters[i])
      // Broadcast the updated parameters to everywhere.
      broadcast(model.parameters[i])
    }
  }
}

*: The pseudo-code in this article uses Swift. One piece of syntax needs explaining: a[x..<y] is semantically the same as a[x:y] in Python, including when part of the subscript is missing, so a[x..<] should be read as a[x:].
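In compilable Swift, the one-sided form is written a[x...] rather than a[x..<]. This runnable snippet shows the Python-to-Swift correspondence for the three common slice forms.

```swift
let a = [10, 20, 30, 40, 50]
// Python a[1:3]  ->  Swift a[1..<3]
let middle = Array(a[1..<3])  // [20, 30]
// Python a[2:]   ->  Swift a[2...]  (the article's a[2..<])
let tail = Array(a[2...])     // [30, 40, 50]
// Python a[:2]   ->  Swift a[..<2]
let head = Array(a[..<2])     // [10, 20]
print(middle, tail, head)
```

One difference from Python: a Swift slice (ArraySlice) keeps the indices of the original array, so a[2...].startIndex is 2, not 0. Wrapping the slice in Array(...), as above, renumbers it from 0.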
