1
580.704 Mathematical Foundations of BME Reza Shadmehr
The loss function, the normal equation, cross validation, LMS algorithm, Steepest descent algorithm
2
Review of the linear classification problem
Hypothesis class: we assume that the function we are trying to approximate belongs to some space of functions F. We don’t know the true function f, but we hypothesize that it too belongs to F. Hypothesis:
3
Estimation: we are given a training set of examples and labels and, using some adaptation algorithm, we find an estimate of the weights w. Whenever our estimate was wrong, change the weight for each “expert”.
(Figure axes: trial number, pixel number.)
Evaluation: we measure how well our estimate generalizes to novel examples.
4
Loss function
The loss function provides a cost for being wrong. The objective of adaptation is to minimize the loss function. In the case of the image labeling problem, we might have:
We might want to minimize the loss over the training set (the empirical loss function). This is a function of the parameters w, and we can minimize it directly. Find w that minimizes:
5
Issues about minimization of the loss function
Why should we minimize the loss over the training set when we are actually interested in minimizing the loss over the test set? We assume that each training and test example-label pair (x,y) is drawn independently and at random from the same but unknown population of examples and labels. We represent this population as a joint probability distribution p(x,y), so that each training/test example is a sample from this distribution. The training loss, based on a few sampled examples and labels, serves as a proxy for the test performance measured over the whole population.
Issue: it is quite possible that the training sample is a poor proxy for the test set.
Empirical loss: computed on the training set only.
Expected loss: computed over all the data (training and test together).
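The two quantities contrasted on this slide can be written out as follows. This is a sketch in assumed notation: $f(x; w)$ for the model, $\mathrm{Loss}$ for the per-example loss, and $n$ for the number of training pairs are symbols chosen here, not taken from the slide.

$$
\begin{aligned}
\text{Empirical loss:}\quad & J_n(w) = \frac{1}{n} \sum_{i=1}^{n} \mathrm{Loss}\big(y^{(i)}, f(x^{(i)}; w)\big) \\
\text{Expected loss:}\quad & J(w) = E_{(x,y) \sim p}\big[\mathrm{Loss}\big(y, f(x; w)\big)\big]
\end{aligned}
$$

Because the training pairs are independent samples from $p(x,y)$, $J_n(w)$ for a fixed $w$ is a sample average whose expectation is $J(w)$. With few samples that average can be far from its expectation, which is the issue raised above.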
6
Regression
The goal is to make quantitative (real valued) predictions on the basis of a (vector of) features or attributes. Example: years to onset of Huntington’s disease in genetically at risk individuals.
Table columns: Current age, CAG repeats, Mother HD?, Father HD?, HD onset
Table values: 43 37 55 1 51 33 37 1 49 39 43 1 1
We need to
– specify the hypothesized class of functions (e.g., linear)
– select how to measure prediction loss (the loss function)
– solve the resulting minimization problem
7
Regression: Hypothesized class and the loss function
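Univariate regression: $f(x; w) = w_0 + w_1 x$
Multivariate regression: $f(\mathbf{x}; \mathbf{w}) = w_0 + w_1 x_1 + \dots + w_d x_d$
Parameters we need to find: $\mathbf{w} = [w_0, w_1, \dots, w_d]^T$
Loss function (squared error): $\mathrm{Loss}(y, f(\mathbf{x}; \mathbf{w})) = (y - f(\mathbf{x}; \mathbf{w}))^2$
Empirical loss (mean squared error): $J_n(\mathbf{w}) = \frac{1}{n} \sum_{i=1}^{n} \big(y^{(i)} - f(\mathbf{x}^{(i)}; \mathbf{w})\big)^2$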
8
Regression: estimation of parameters
We have to minimize the empirical loss function: This function is quadratic in terms of w. It has a minimum at some point in the w space. We find the w that minimizes the loss function by finding the conditions where the derivative of the loss function is zero.
9
Optimality conditions: Finding conditions that minimize the empirical loss function
We end up with 2 equations and 2 unknowns, and we can solve for them exactly. Chain rule reminder
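The two optimality conditions can be sketched as follows, assuming the univariate model $f(x; w) = w_0 + w_1 x$ and the mean squared error $J_n(w) = \frac{1}{n}\sum_i (y^{(i)} - w_0 - w_1 x^{(i)})^2$. The symbols $\bar{x}$ and $\bar{y}$ (sample means) are introduced here for convenience.

$$
\begin{aligned}
\text{Chain rule:}\quad & \frac{\partial}{\partial w}\, g(w)^2 = 2\, g(w)\, \frac{\partial g(w)}{\partial w} \\
\frac{\partial J_n}{\partial w_0} &= -\frac{2}{n} \sum_{i=1}^{n} \big(y^{(i)} - w_0 - w_1 x^{(i)}\big) = 0 \\
\frac{\partial J_n}{\partial w_1} &= -\frac{2}{n} \sum_{i=1}^{n} \big(y^{(i)} - w_0 - w_1 x^{(i)}\big)\, x^{(i)} = 0 \\
\Rightarrow\quad w_1 &= \frac{\sum_i (x^{(i)} - \bar{x})(y^{(i)} - \bar{y})}{\sum_i (x^{(i)} - \bar{x})^2}, \qquad w_0 = \bar{y} - w_1 \bar{x}
\end{aligned}
$$

The solution for $w_1$ exists as long as the inputs $x^{(i)}$ are not all equal.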
10
Expected behavior of the model errors (residuals)
Error in the prediction (model residual): $\tilde{y}^{(i)} = y^{(i)} - f(x^{(i)}; w)$
The prediction error should have zero mean and should not contain any linear trend, i.e., it should be uncorrelated with any linear function of the inputs.
Suppose the residuals did contain a linear trend, $\tilde{y}^{(i)} = K x^{(i)}$ with $K \neq 0$. Then $\sum_i \tilde{y}^{(i)} x^{(i)} = \sum_i K x^{(i)} x^{(i)} = K \sum_i (x^{(i)})^2$, which cannot be zero unless every $x^{(i)}$ is zero. Therefore, if $\sum_i \tilde{y}^{(i)} x^{(i)} = 0$, there can be no linear trend in the residuals. There may still exist some non-linear function of the inputs that can account for the residuals.
(Figure axes: y versus x.)
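The residual properties described above can be checked numerically. The sketch below fits a line in closed form and then evaluates the two sums that the optimality conditions force to zero. The data values and the variable names (`sum_resid`, `sum_resid_x`) are made up for illustration.

```python
# Univariate least squares in closed form, followed by a check on the residuals.
xs = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
ys = [1.1, 2.9, 5.2, 6.8, 9.1, 10.9]  # made-up data, roughly y = 1 + 2x

n = len(xs)
mean_x = sum(xs) / n
mean_y = sum(ys) / n

# Solution of the two optimality conditions (dJ/dw0 = 0 and dJ/dw1 = 0).
sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
sxx = sum((x - mean_x) ** 2 for x in xs)
w1 = sxy / sxx
w0 = mean_y - w1 * mean_x

residuals = [y - (w0 + w1 * x) for x, y in zip(xs, ys)]

# Zero-mean condition: the residuals sum to zero.
sum_resid = sum(residuals)
# No-linear-trend condition: the residuals are orthogonal to the inputs.
sum_resid_x = sum(r * x for r, x in zip(residuals, xs))

print(w0, w1, sum_resid, sum_resid_x)
```

Both sums come out at floating-point round-off level. This says nothing about non-linear structure left in the residuals.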
11
Loss function: matrix notation
“L2” norm
12
Optimality condition: minimize mean squared error
the “normal” equation
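A sketch of how the normal equation follows from the matrix form of the loss. It assumes $X$ is the $n \times (d+1)$ matrix whose $i$-th row is $[1, x_1^{(i)}, \dots, x_d^{(i)}]$ and $\mathbf{y}$ is the column vector of labels; these names are chosen here.

$$
\begin{aligned}
J_n(\mathbf{w}) &= \frac{1}{n} \lVert \mathbf{y} - X\mathbf{w} \rVert^2 = \frac{1}{n} (\mathbf{y} - X\mathbf{w})^T (\mathbf{y} - X\mathbf{w}) \\
\nabla_{\mathbf{w}} J_n &= -\frac{2}{n} X^T (\mathbf{y} - X\mathbf{w}) = 0 \\
\Rightarrow\quad X^T X \mathbf{w} &= X^T \mathbf{y}
\end{aligned}
$$

When $X^T X$ is invertible, the unique minimizer is $\hat{\mathbf{w}} = (X^T X)^{-1} X^T \mathbf{y}$.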
13
The pseudo-inverse
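When X has full column rank, the pseudo-inverse is $X^+ = (X^T X)^{-1} X^T$, and applying it to y amounts to solving the normal equation. The sketch below does this with a small hand-written Gaussian elimination instead of a linear algebra library. The helper names (`transpose`, `solve`, `fit_normal_equation`) and the data are hypothetical.

```python
def transpose(M):
    """Return the transpose of a matrix stored as a list of rows."""
    return [list(col) for col in zip(*M)]

def solve(A, b):
    """Solve A z = b by Gaussian elimination with partial pivoting."""
    n = len(A)
    M = [row[:] + [bi] for row, bi in zip(A, b)]  # augmented matrix
    for c in range(n):
        p = max(range(c, n), key=lambda r: abs(M[r][c]))
        M[c], M[p] = M[p], M[c]
        for r in range(c + 1, n):
            factor = M[r][c] / M[c][c]
            for k in range(c, n + 1):
                M[r][k] -= factor * M[c][k]
    z = [0.0] * n
    for r in range(n - 1, -1, -1):
        tail = sum(M[r][k] * z[k] for k in range(r + 1, n))
        z[r] = (M[r][n] - tail) / M[r][r]
    return z

def fit_normal_equation(X, y):
    """Return w that solves (X^T X) w = X^T y; assumes X^T X is invertible."""
    Xt = transpose(X)
    XtX = [[sum(a * b for a, b in zip(ri, rj)) for rj in Xt] for ri in Xt]
    Xty = [sum(a * b for a, b in zip(ri, y)) for ri in Xt]
    return solve(XtX, Xty)

xs = [0.0, 1.0, 2.0, 3.0]
ys = [1.0, 3.0, 5.0, 7.0]      # exactly y = 1 + 2x
X = [[1.0, x] for x in xs]     # a column of ones carries the offset w0
w = fit_normal_equation(X, ys)
print(w)
```

If the columns of X are linearly dependent, $X^T X$ is singular and this solver fails with a division by zero. The general pseudo-inverse handles that case, but it is outside this sketch.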
14
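Review of regression
Univariate regression: $f(x; w) = w_0 + w_1 x$
Multivariate regression: $f(\mathbf{x}; \mathbf{w}) = w_0 + w_1 x_1 + \dots + w_d x_d$
Parameters we need to find: $\mathbf{w} = [w_0, w_1, \dots, w_d]^T$
Loss function: $\mathrm{Loss}(y, f(\mathbf{x}; \mathbf{w})) = (y - f(\mathbf{x}; \mathbf{w}))^2$
Empirical loss: $J_n(\mathbf{w}) = \frac{1}{n} \sum_{i=1}^{n} \big(y^{(i)} - f(\mathbf{x}^{(i)}; \mathbf{w})\big)^2$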
15
Regression with polynomials
univariate regression with m-th order polynomials:
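A sketch of the polynomial model in assumed notation ($\Phi$ for the design matrix is a name chosen here). The model is non-linear in the input $x$ but still linear in the weights, so the normal equation applies unchanged once the inputs are expanded into powers.

$$
\begin{aligned}
f(x; \mathbf{w}) &= w_0 + w_1 x + w_2 x^2 + \dots + w_m x^m \\
\Phi &= \begin{bmatrix} 1 & x^{(1)} & (x^{(1)})^2 & \cdots & (x^{(1)})^m \\ \vdots & \vdots & \vdots & & \vdots \\ 1 & x^{(n)} & (x^{(n)})^2 & \cdots & (x^{(n)})^m \end{bmatrix}, \qquad \Phi^T \Phi\, \mathbf{w} = \Phi^T \mathbf{y}
\end{aligned}
$$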
16
Regression with polynomials: fit improves with increased order
17
Over-fitting
We want to fit the training set, but as model complexity increases, we run the risk of over-fitting. When the model order is high enough to over-fit, leaving a single data point out of the training set can drastically change the fit.
(Figure legend: train set, left-out point.)
18
Cross validation
We want to fit the training set, but we also want to generalize correctly. To measure generalization, we leave out a data point (named the test point), fit the remaining data, and then measure the error on the test point. The average error over all possible test points is the cross validation error: $\mathrm{CV} = \frac{1}{n} \sum_{i=1}^{n} \big(y^{(i)} - f(x^{(i)}; \hat{\mathbf{w}}^{(-i)})\big)^2$, where $\hat{\mathbf{w}}^{(-i)}$ are the weights estimated from a training set that does not include the i-th data point.
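The leave-one-out procedure described above can be sketched as follows for polynomial models of increasing order. Everything specific here is an assumption made for illustration: the synthetic data (a 2nd order polynomial plus Gaussian noise), the noise level, the seed, and the helper names `fit_poly`, `mse` and `loo_cv_error`.

```python
import random

def solve(A, b):
    """Solve A z = b by Gaussian elimination with partial pivoting."""
    n = len(A)
    M = [row[:] + [bi] for row, bi in zip(A, b)]
    for c in range(n):
        p = max(range(c, n), key=lambda r: abs(M[r][c]))
        M[c], M[p] = M[p], M[c]
        for r in range(c + 1, n):
            factor = M[r][c] / M[c][c]
            for k in range(c, n + 1):
                M[r][k] -= factor * M[c][k]
    z = [0.0] * n
    for r in range(n - 1, -1, -1):
        tail = sum(M[r][k] * z[k] for k in range(r + 1, n))
        z[r] = (M[r][n] - tail) / M[r][r]
    return z

def fit_poly(xs, ys, order):
    """Least-squares polynomial weights from the normal equation."""
    d = order + 1
    rows = [[x ** k for k in range(d)] for x in xs]
    XtX = [[sum(r[i] * r[j] for r in rows) for j in range(d)] for i in range(d)]
    Xty = [sum(r[i] * y for r, y in zip(rows, ys)) for i in range(d)]
    return solve(XtX, Xty)

def predict(w, x):
    return sum(wk * x ** k for k, wk in enumerate(w))

def mse(w, xs, ys):
    return sum((y - predict(w, x)) ** 2 for x, y in zip(xs, ys)) / len(xs)

def loo_cv_error(xs, ys, order):
    """Average squared error on each point when it is left out of the fit."""
    total = 0.0
    for i in range(len(xs)):
        w_minus_i = fit_poly(xs[:i] + xs[i + 1:], ys[:i] + ys[i + 1:], order)
        total += (ys[i] - predict(w_minus_i, xs[i])) ** 2
    return total / len(xs)

rng = random.Random(0)
xs = [-1.0 + 2.0 * i / 11 for i in range(12)]
ys = [1.0 + 2.0 * x - 3.0 * x ** 2 + rng.gauss(0.0, 0.2) for x in xs]

orders = [1, 2, 3, 4, 5]
train_err = [mse(fit_poly(xs, ys, m), xs, ys) for m in orders]
cv_err = [loo_cv_error(xs, ys, m) for m in orders]
for m, tr, cv in zip(orders, train_err, cv_err):
    print(m, tr, cv)
```

The training error cannot increase with model order, because each model contains the previous one as a special case. For least squares, the leave-one-out error is also never smaller than the training error of the same model. Where the cross validation error reaches its minimum depends on the particular noise sample.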
19
Cross validation
Cross validation error will often increase when the model structure is over-fitting the data.
(Figure: mean-squared error on the training set and cross-validation error, each plotted against model order. The actual data was generated with a 2nd order polynomial process.)
20
Batch vs. online learning algorithms
In “batch” learning, we don’t have to make any predictions until we see all of the data. At that point, we make a model to fit all the data. In “online” learning, data points are given to us one at a time. We use each example pair to update our model:
– We are given an x, and with our current model we predict a y.
– The teacher tells us our error.
– We modify our model.
21
Online learning: the LMS algorithm
Assume we have the model: When we project w onto x, we get a scalar p: What we want is to change w so that when we project onto x we get: Anywhere along the dashed line is the solution we’re looking for.
22
The LMS algorithm
w changes along a vector parallel to the input x in that trial, with a magnitude proportional to the prediction error in that trial. With this step size, we change w to completely account for the error in that trial.
(Equation labels: unit vector along x; “step size”.)
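A minimal sketch of one such update, assuming the linear model $\hat{y} = w^T x$ and a step normalized by $x^T x$, so that a step size of 1 removes the whole error on the current trial. The function name `lms_update`, the parameter `eta` and the numbers are hypothetical. The slide's own equation was lost in extraction, so this form is an assumption consistent with the description.

```python
def lms_update(w, x, y, eta=1.0):
    """One normalized LMS step: move w along x in proportion to the error.

    Assumes x is not the zero vector.
    """
    prediction = sum(wi * xi for wi, xi in zip(w, x))
    error = y - prediction
    x_norm_sq = sum(xi * xi for xi in x)
    # error / |x| is the distance to move; x / |x| is the unit vector along x.
    return [wi + eta * error * xi / x_norm_sq for wi, xi in zip(w, x)]

w = [0.0, 0.0]
x, y = [1.0, 2.0], 3.0

w_new = lms_update(w, x, y, eta=1.0)
new_prediction = sum(wi * xi for wi, xi in zip(w_new, x))

w_half = lms_update(w, x, y, eta=0.5)
half_prediction = sum(wi * xi for wi, xi in zip(w_half, x))

print(w_new, new_prediction, half_prediction)
```

With `eta=1.0` the new prediction on this trial equals the target, and with `eta=0.5` half of the error remains. Smaller steps trade speed for less sensitivity to noise in any single trial.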
23
The LMS algorithm attempts to minimize a squared error loss function by approximating the gradient of the loss function.
Steepest descent algorithm: average error over all data points.
LMS: local error as a rough estimate of the average error.
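The relation between the two algorithms can be sketched in code. The assumptions: a linear model, an unnormalized step, and the constant factor from differentiating the squared error absorbed into the step size `eta`. The function names and data are hypothetical. The sketch checks that the steepest descent step equals the average of the single-trial LMS steps taken from the same starting weights.

```python
def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))

def steepest_descent_step(w, X, y, eta):
    """Step along the error-weighted input averaged over all data points."""
    n = len(X)
    avg = [0.0] * len(w)
    for x_i, y_i in zip(X, y):
        err = y_i - dot(w, x_i)
        for k, x_ik in enumerate(x_i):
            avg[k] += err * x_ik / n
    return [wk + eta * ak for wk, ak in zip(w, avg)]

def lms_step(w, x_i, y_i, eta):
    """Same kind of step, using only the error on the current trial."""
    err = y_i - dot(w, x_i)
    return [wk + eta * err * x_ik for wk, x_ik in zip(w, x_i)]

X = [[1.0, 0.0], [1.0, 1.0], [1.0, 2.0], [1.0, 3.0]]
y = [1.0, 3.0, 5.0, 7.0]  # exactly y = 1 + 2x, so the minimum is w = [1, 2]

# From a common starting point, average the LMS steps over all trials.
w_start = [0.5, -0.5]
sd = steepest_descent_step(w_start, X, y, eta=0.1)
lms_all = [lms_step(w_start, x_i, y_i, 0.1) for x_i, y_i in zip(X, y)]
lms_mean = [sum(w[k] for w in lms_all) / len(lms_all) for k in range(2)]

# Iterating steepest descent approaches the least-squares solution.
w = [0.0, 0.0]
for _ in range(2000):
    w = steepest_descent_step(w, X, y, eta=0.1)

print(sd, lms_mean, w)
```

Each LMS step is therefore a noisy sample of the steepest descent step. The step size 0.1 was picked to be small enough for this data set; a larger value can make the iteration diverge.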
24
Convergence of LMS-algorithm
(Figure panels: iterating over two data points; iterating over three data points. Label: equilibrium point.)
With 3 data points, the solution will not move to a single point and stay put. It converges to a small region of the parameter space but will bounce around, as long as η > 0.
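One way to see why the weights keep moving: when no single line fits all the points, even the least-squares solution is not a fixed point of the LMS update. The sketch below starts at that solution and measures how far a single trial pushes the weights. The three data points, the unnormalized step, and the names `lms_step` and `displacement` are assumptions made for illustration.

```python
def lms_step(w, x, y, eta):
    """One LMS step: w moves along x in proportion to the error on this trial."""
    err = y - (w[0] * x[0] + w[1] * x[1])
    return [w[0] + eta * err * x[0], w[1] + eta * err * x[1]]

# Three made-up points that no single line fits exactly (inputs are [1, x]).
data = [([1.0, 0.0], 0.0), ([1.0, 1.0], 1.0), ([1.0, 2.0], 1.0)]

# Least-squares line for these points, worked out by hand: y = 1/6 + x/2.
w_star = [1.0 / 6.0, 0.5]

def displacement(eta):
    """Largest distance that a single LMS step moves w away from w_star."""
    moves = []
    for x, y in data:
        w_next = lms_step(w_star, x, y, eta)
        dist = ((w_next[0] - w_star[0]) ** 2 + (w_next[1] - w_star[1]) ** 2) ** 0.5
        moves.append(dist)
    return max(moves)

print(displacement(0.1), displacement(0.01), displacement(0.0))
```

The displacement shrinks in proportion to the step size and vanishes only when the step size is zero.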
25
Convergence of LMS-algorithm
It is difficult to prove “convergence” of LMS because the weights keep bouncing around. But we can prove convergence for the steepest descent algorithm and then use the fact that LMS is a stochastic approximation to it. The proof rests on a geometric series.
26
Convergence of a geometric series of scalars
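For reference, the standard scalar result, written with symbols chosen here ($a$ for the ratio, $S_t$ for the partial sum):

$$
\begin{aligned}
S_t &= \sum_{k=0}^{t-1} a^k = 1 + a + a^2 + \dots + a^{t-1} \\
S_t - a S_t &= 1 - a^t \quad\Rightarrow\quad S_t = \frac{1 - a^t}{1 - a} \qquad (a \neq 1) \\
\lim_{t \to \infty} S_t &= \frac{1}{1 - a} \qquad \text{if } |a| < 1
\end{aligned}
$$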
27
Convergence of a geometric series of matrices
See homework for this:
28
Convergence of steepest descent algorithm
See homework for proof of this.
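A sketch of the argument, not the homework solution. It assumes the matrix form of the squared error loss, constant factors absorbed into the step size $\eta$, an invertible $X^T X$, and the matrix analogue of the scalar series: $\sum_{k=0}^{\infty} A^k = (I - A)^{-1}$ when every eigenvalue of $A$ has magnitude less than 1.

$$
\begin{aligned}
\mathbf{w}^{(t+1)} &= \mathbf{w}^{(t)} + \eta X^T (\mathbf{y} - X \mathbf{w}^{(t)}) = (I - \eta X^T X)\, \mathbf{w}^{(t)} + \eta X^T \mathbf{y} \\
\mathbf{w}^{(t)} &= (I - \eta X^T X)^t\, \mathbf{w}^{(0)} + \eta \sum_{k=0}^{t-1} (I - \eta X^T X)^k X^T \mathbf{y} \\
\lim_{t \to \infty} \mathbf{w}^{(t)} &= \eta\, (\eta X^T X)^{-1} X^T \mathbf{y} = (X^T X)^{-1} X^T \mathbf{y}
\end{aligned}
$$

The limit holds when $0 < \eta < 2 / \lambda_{\max}(X^T X)$, which keeps every eigenvalue of $I - \eta X^T X$ strictly between $-1$ and $1$. The limit is the solution of the normal equation and does not depend on $\mathbf{w}^{(0)}$.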
29
We have shown convergence of the steepest descent algorithm to the solution of the normal equations. The LMS is a stochastic approximation to steepest descent, thus it “converges” as well, but will jump around stochastically, as long as the learning rate is greater than zero. Convergence can be reached when the learning rate is systematically made smaller on each step. We will call changes of the learning rate “adaptive learning” and will see a principled approach to this problem when we consider Bayesian approaches to learning.
30
Summary: Linear Regression
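Univariate regression: $f(x; w) = w_0 + w_1 x$
Multivariate regression: $f(\mathbf{x}; \mathbf{w}) = w_0 + w_1 x_1 + \dots + w_d x_d$
Parameters we need to find: $\mathbf{w} = [w_0, w_1, \dots, w_d]^T$
Loss function: $\mathrm{Loss}(y, f(\mathbf{x}; \mathbf{w})) = (y - f(\mathbf{x}; \mathbf{w}))^2$
Empirical loss: $J_n(\mathbf{w}) = \frac{1}{n} \sum_{i=1}^{n} \big(y^{(i)} - f(\mathbf{x}^{(i)}; \mathbf{w})\big)^2$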
31
Summary: Iterative learning
Increased model complexity reduces error over the training data but can increase the leave-one-out cross validation error. We want a model that fits the training data and generalizes correctly.
LMS algorithm: w changes along a vector parallel to the input x in that trial, with a magnitude proportional to the error in that trial.
Steepest descent algorithm: w changes according to the average error over all data points.