# Improvements
There are many ways we could improve the performance of our network, but they would add a lot of complexity to the simple class that we wrote. Fortunately, there are a lot of machine learning libraries that provide these features and work efficiently, so for real applications we would want to use one of those libraries (we'll explore some of them next).
## Batching
Right now, we did our training as:

- Loop over the training pairs, $(x^k, y^k)$:
  - Propagate $x^k$ through the network.
  - Compute the corrections, $\partial C/\partial {\bf W}$, for each weight matrix ${\bf W}$.
  - Update the matrices immediately:

    $${\bf W} \leftarrow {\bf W} - \eta \frac{\partial C}{\partial {\bf W}}$$

    where $\eta$ is the learning rate.

In this manner, each training pair sees slightly different matrices.
We could instead divide our training set into batches of size $S$ and do the update as:

- Loop over the batches:
  - Loop over the pairs, $(x^k, y^k)$, in the current batch:
    - Propagate $x^k$ through the network.
    - Compute the gradients, $\partial C/\partial {\bf W}\,|_k$, from the current pair.
    - Accumulate the gradients.
  - Apply a single update to the matrices for this batch:

    $${\bf W} \leftarrow {\bf W} - \frac{\eta}{S} \sum_{k \in \mathrm{batch}} \frac{\partial C}{\partial {\bf W}}\bigg|_k$$
> **Note:** We normalize the accumulated gradients by the batch size, $S$, so the update uses the average gradient over the batch.
The advantage of this is that the update uses a gradient averaged over several training pairs, which smooths out the noise from any single pair, and since every pair in a batch sees the same matrices, the propagation of the pairs can be done in parallel.
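Here's a rough sketch of how one batched epoch might look in code. The `net` object, its `weights` list, and its `gradients()` method are hypothetical stand-ins for whatever our class actually provides:

```python
import numpy as np

def train_batched(net, xs, ys, batch_size, eta):
    """One epoch of batched gradient descent (a sketch).

    Assumes a hypothetical `net` with:
      * net.weights         : a list of weight matrices (NumPy arrays)
      * net.gradients(x, y) : a list of dC/dW, one per matrix, for one pair
    """
    for start in range(0, len(xs), batch_size):
        xb = xs[start:start + batch_size]
        yb = ys[start:start + batch_size]

        # accumulate the gradients over all pairs in this batch
        accum = [np.zeros_like(w) for w in net.weights]
        for x, y in zip(xb, yb):
            for a, g in zip(accum, net.gradients(x, y)):
                a += g

        # a single update for the whole batch, normalized by the batch size
        for w, a in zip(net.weights, accum):
            w -= eta * a / len(xb)
```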
## Different activation or cost functions
We used a simple cost function: the sum of the squares of the errors. This is analogous to the $L_2$ norm of the error vector. There are many other cost functions we could choose (cross-entropy is a common choice for classification problems), and changing the cost function would require us to redo our derivatives.
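As a sketch (the function names here are just illustrative), the cost we used and a common alternative, the binary cross-entropy, might look like:

```python
import numpy as np

def sum_of_squares(y_pred, y_true):
    """The cost we used: the sum of the squares of the errors."""
    return np.sum((y_pred - y_true)**2)

def cross_entropy(y_pred, y_true, eps=1.e-12):
    """A common alternative for classification problems.

    Assumes y_pred holds probabilities in (0, 1); eps guards the logs.
    """
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return -np.sum(y_true * np.log(y_pred) +
                   (1.0 - y_true) * np.log(1.0 - y_pred))
```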
Likewise, there is a wide variety of activation functions, some of which are not differentiable everywhere. The choice of activation function can depend on the type of data you are working with, and you might also want to use a different activation function on each layer. Again, this would require us to redo our derivatives.
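For instance (again just a sketch with illustrative names), a smooth sigmoid and the popular ReLU, along with the derivatives that backpropagation would need, could look like:

```python
import numpy as np

def sigmoid(x):
    """A smooth activation that maps everything into (0, 1)."""
    return 1.0 / (1.0 + np.exp(-x))

def sigmoid_derivative(x):
    s = sigmoid(x)
    return s * (1.0 - s)

def relu(x):
    """ReLU is popular for hidden layers, but is not differentiable at x = 0;
    in practice we simply pick a slope (0 or 1) to use there."""
    return np.maximum(0.0, x)

def relu_derivative(x):
    # conventional choice: slope 1 for x > 0, slope 0 otherwise
    return (x > 0).astype(np.float64)
```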
## Use a different minimization technique
We only explored gradient descent. But there are improvements to this (like momentum, which we mentioned previously) as well as alternative minimization techniques we could use (some of which don't need the gradient at all).
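To illustrate the momentum idea (a sketch; the parameter names and default values here are arbitrary):

```python
import numpy as np

def momentum_step(w, grad, velocity, eta=0.1, beta=0.9):
    """One gradient-descent-with-momentum update for a single weight matrix.

    The velocity is an exponentially-decaying running sum of past gradients,
    which damps oscillations and can speed up convergence.  eta is the
    learning rate and beta the momentum parameter.
    """
    velocity = beta * velocity - eta * grad
    return w + velocity, velocity

# usage: keep one velocity array per weight matrix, initialized to zeros
# v = np.zeros_like(w)
# w, v = momentum_step(w, dC_dw, v)
```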
## Different types of layers / connections
We only considered a dense network: every node on one layer was connected to every node on the adjacent layer. But there are alternatives.
For example, a convolutional neural network convolves a layer with some kernel. This helps the network identify features in the input (edge detection is a classic example).
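To give a feel for the operation (a sketch of the math, not how a real library implements it), a single convolution of an image with a small kernel can be written as:

```python
import numpy as np

def convolve(image, kernel):
    """'Valid' 2-d convolution of an image with a small kernel.

    Each output pixel is a weighted sum of the kernel-sized window of
    input pixels around it -- this is the operation a convolutional layer
    applies (followed by a bias and an activation).  Strictly speaking,
    this is a cross-correlation (no kernel flip), which is what most
    'convolutional' layers actually compute.
    """
    kh, kw = kernel.shape
    out = np.zeros((image.shape[0] - kh + 1, image.shape[1] - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(image[i:i+kh, j:j+kw] * kernel)
    return out

# example: a kernel that responds to vertical edges
edge_kernel = np.array([[1.0, 0.0, -1.0],
                        [1.0, 0.0, -1.0],
                        [1.0, 0.0, -1.0]])
```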
## Auto-differentiation libraries
At some point, with all of these options, doing the differentiation and chain rule by hand becomes burdensome and prone to errors. For this reason, machine learning libraries often rely on automatic differentiation tools, like JAX, which can compute the derivatives of our Python functions for us.
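For example, with JAX we can get the gradient of a (toy) cost function directly from the Python code that defines it:

```python
import jax
import jax.numpy as jnp

def cost(w, x, y):
    """A toy cost: the sum of squared errors of a linear model."""
    err = jnp.dot(w, x) - y
    return jnp.sum(err**2)

# jax.grad() returns a new function that evaluates dcost/dw
dcost_dw = jax.grad(cost, argnums=0)

w = jnp.array([1.0, 2.0])
x = jnp.array([3.0, 4.0])
print(dcost_dw(w, x, 5.0))   # the gradient with respect to w -- no hand derivation needed
```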