
Why Machine Learning Models Learn

When I started working on AI projects in early 2023, I thought AI was simply a matter of understanding PyTorch syntax and experimenting with different hyperparameters until I got a high enough accuracy score. Looking back, this view was (embarrassingly) naïve, but it did get me very excited about the potential and applications of AI systems.

Since then, I have been on a personal journey to learn as much as I can and to reason about AI from deep, math- and physics-informed principles. A week ago, I got a book covering the math behind virtually every ML model. Reading just a few select passages has already paid off significantly: I now feel that I understand the mathematical framework behind machine learning on a much deeper and clearer level.

Distributions as Beliefs

At their core, ML problems reduce to optimizing conditional probability distributions, which, in Bayesian statistics, encode beliefs about hypotheses. Suppose you want to know tomorrow's weather; for simplicity, assume it can be sunny, rainy, or cloudy (mutually exclusive outcomes, exactly one of which occurs).

A naïve approach would be to train a model to output a point estimate of the weather: "Based on my observation, my model predicts that it will rain tomorrow." This claim is flawed: the model might be wrong, and the prediction says nothing about how uncertain it is. A better approach is to estimate a conditional probability distribution over the outcomes, which makes that uncertainty explicit: "Based on my observation, I am 65% certain that it will rain tomorrow, 30% certain that it will be cloudy, and 5% certain that it will be sunny."
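The difference between the two approaches can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions: the `logits` values are hypothetical scores chosen so that a softmax reproduces the 65/30/5 split above, and `softmax` is a hand-written helper rather than a library call. The point estimate is just the argmax of the distribution, so it keeps the top outcome and throws away everything the distribution says about uncertainty.

```python
import math

# Hypothetical logits (unnormalized scores) a model might output for one observation.
# They are set to log-probabilities so that softmax reproduces the 65/30/5 split.
OUTCOMES = ["rain", "cloudy", "sunny"]
logits = [math.log(0.65), math.log(0.30), math.log(0.05)]

def softmax(scores):
    """Turn arbitrary real-valued scores into a probability distribution."""
    m = max(scores)  # subtract the max before exponentiating for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

probs = softmax(logits)
belief = dict(zip(OUTCOMES, probs))

# A point estimate keeps only the most likely outcome and discards the rest.
point_estimate = max(belief, key=belief.get)

print("point estimate:", point_estimate)
for outcome, p in belief.items():
    print(f"  P({outcome}) = {p:.2f}")
```

Both outputs say "rain" is most likely, but only the distribution tells you that clouds are a real possibility and sun is unlikely.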

The objective shifts from "Based on the data, which weather event is most likely to occur?" to "How does receiving more information update my beliefs about the weather?"

How Models Learn

As we train the model on our dataset, our goal is to find parameters under which the model approximates the underlying distribution of the data. According to Bayes' theorem, our updated belief in a hypothesis (the posterior) is proportional to how well that hypothesis explains the data (the likelihood) multiplied by how strongly we believed it beforehand (the prior): P(h | D) ∝ P(D | h) P(h). Bayes' theorem provides the theoretical framework for why probabilistic models learn, while the KL divergence gives us a way to compute how they learn.
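A small Python sketch of a single Bayesian update over the three weather outcomes. The prior, the observation ("the barometer is falling"), and its likelihoods are all hypothetical numbers chosen for illustration, and `bayes_update` is a helper defined here, not a library function. It multiplies each likelihood by its prior and divides by their sum, which is the normalizing constant P(D).

```python
# Hypothetical prior belief about tomorrow's weather, before seeing any evidence.
prior = {"rain": 0.30, "cloudy": 0.30, "sunny": 0.40}

# Hypothetical likelihoods P(observation | weather) for one observation,
# e.g. "the barometer is falling". These numbers are illustrative, not measured.
likelihood = {"rain": 0.80, "cloudy": 0.40, "sunny": 0.10}

def bayes_update(prior, likelihood):
    """Return the posterior: likelihood * prior for each hypothesis, then normalized."""
    unnormalized = {h: likelihood[h] * prior[h] for h in prior}
    evidence = sum(unnormalized.values())  # P(observation), the normalizing constant
    return {h: v / evidence for h, v in unnormalized.items()}

posterior = bayes_update(prior, likelihood)
for h, p in posterior.items():
    print(f"P({h} | observation) = {p:.3f}")
```

With these numbers the posterior works out to 0.6 for rain, 0.3 for clouds, and 0.1 for sun: the observation shifts belief toward rain without ruling anything out. Feeding `posterior` back in as the next `prior` models a second observation, assuming the observations are conditionally independent given the weather.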

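In each training step, we use the KL divergence, which quantifies the cost of holding the wrong belief (the expected extra information needed when the model's distribution is used in place of the true one), to compute a scalar loss. We then minimize this loss by computing its gradient with respect to the model's parameters and updating them via gradient descent. In practice, the loss is usually written as the cross-entropy, which differs from the KL divergence only by the entropy of the data distribution; since that term does not depend on the model's parameters, both losses have the same gradients.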
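The training loop described above can be sketched in plain Python on a deliberately tiny problem. Assumptions: the "true" distribution `target` is hypothetical, and the model's only parameters are three logits `theta`, whereas a real model would compute its logits from the observation with a network. For a softmax output q, the gradient of KL(p ‖ q) with respect to the logits is q − p, which is exactly what the update step uses.

```python
import math

def softmax(scores):
    """Turn real-valued logits into a probability distribution."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p, q):
    """KL(p || q) in nats; terms with p_i = 0 contribute nothing."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

# Hypothetical "true" distribution of outcomes (rain, cloudy, sunny) for one observation.
target = [0.65, 0.30, 0.05]

# The model's parameters are just three logits, starting from a uniform belief.
theta = [0.0, 0.0, 0.0]
learning_rate = 0.5

for step in range(500):
    q = softmax(theta)
    # For q = softmax(theta), d KL(target || q) / d theta = q - target.
    grad = [qi - pi for qi, pi in zip(q, target)]
    # Gradient descent: move each parameter against its gradient.
    theta = [t - learning_rate * g for t, g in zip(theta, grad)]
    if step % 100 == 0:
        print(f"step {step:3d}  KL = {kl_divergence(target, q):.5f}")

final_q = softmax(theta)
print("learned belief:", [round(x, 3) for x in final_q])
```

The printed KL values should shrink toward zero and the learned belief should approach `target`. Because this loss is convex in the logits, a fixed learning rate is enough here; real models would typically rely on a framework such as PyTorch for gradients and optimization instead of a hand-coded update.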

Calibration Over Confidence

As informative evidence accumulates, the entropy of the model's distribution, which quantifies its uncertainty, tends to decrease. While this reduction in uncertainty is a natural consequence of considering more information, calibration matters more: a model that is confidently wrong is not useful. If, after further training, the model assigns 95% to rain, 4% to clouds, and 1% to sunny weather for a given observation, then across all the cases where it makes that prediction, roughly 95% of the real outcomes should be rain, 4% cloudy, and 1% sunny.
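Both ideas can be made concrete with a short Python sketch. The entropy part compares the uniform distribution with the two beliefs from this post (65/30/5 and 95/4/1). The calibration part uses one common check, expected calibration error (ECE) with equal-width confidence bins, on simulated data: `gap` is a hypothetical amount by which an overconfident model overstates its accuracy, and `calibration_table` and `expected_calibration_error` are helpers defined here. None of these numbers come from a real model.

```python
import math
import random

def entropy(dist):
    """Shannon entropy in nats; higher means more uncertain."""
    return -sum(p * math.log(p) for p in dist if p > 0)

uniform = [1/3, 1/3, 1/3]
early = [0.65, 0.30, 0.05]
late = [0.95, 0.04, 0.01]
for name, d in [("uniform", uniform), ("early", early), ("late", late)]:
    print(f"{name:8s} entropy = {entropy(d):.3f} nats")

def calibration_table(confidences, correct, n_bins=10):
    """Group predictions by top-class confidence and compare mean confidence
    with observed accuracy in each bin (the data behind a reliability diagram)."""
    bins = [[] for _ in range(n_bins)]
    for c, ok in zip(confidences, correct):
        idx = min(int(c * n_bins), n_bins - 1)  # confidence 1.0 goes in the last bin
        bins[idx].append((c, ok))
    rows = []
    for b in bins:
        if b:
            mean_conf = sum(c for c, _ in b) / len(b)
            accuracy = sum(ok for _, ok in b) / len(b)
            rows.append((mean_conf, accuracy, len(b)))
    return rows

def expected_calibration_error(rows, total):
    """Weighted average gap between confidence and accuracy across bins."""
    return sum(n / total * abs(conf - acc) for conf, acc, n in rows)

# Simulated predictions: confidences drawn uniformly from [0.5, 1.0].
# Calibrated world: the top prediction is correct with probability c.
# Overconfident world: it is correct with probability c - gap.
rng = random.Random(0)
n = 10_000
results = {}
for world, gap in [("calibrated", 0.0), ("overconfident", 0.2)]:
    confidences = [rng.uniform(0.5, 1.0) for _ in range(n)]
    correct = [rng.random() < c - gap for c in confidences]
    rows = calibration_table(confidences, correct)
    results[world] = expected_calibration_error(rows, n)
    print(f"{world:13s} ECE = {results[world]:.3f}")
```

The entropy drops from ln 3 ≈ 1.10 nats as the belief sharpens, and the calibrated simulation should give an ECE near zero while the overconfident one lands near its 0.2 gap. Low entropy alone says nothing about which of the two worlds a model lives in.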