Statistical Learning

Biostat 203B

Author

Dr. Hua Zhou @ UCLA

Published

February 23, 2023

In the next few lectures, we focus on the modeling step. For the modeling step, we take statistical/machine learning in its narrow sense.

1 Overview of statistical/machine learning

In this class, we use the phrases statistical learning, machine learning, or simply learning interchangeably.

1.1 Supervised vs unsupervised learning

  • Supervised learning: input(s) -> output.
    • Prediction: the output is continuous (income, weight, BMI, …).
    • Classification: the output is categorical (disease or not, pattern recognition, …).
  • Unsupervised learning: no output. We learn relationships and structure in the data.
    • Clustering.
    • Dimension reduction.

1.2 Supervised learning

  • Predictors \[ X = \begin{pmatrix} X_1 \\ \vdots \\ X_p \end{pmatrix}. \] Also called inputs, covariates, regressors, features, independent variables.

  • Outcome \(Y\) (also called output, response variable, dependent variable, target).

    • In the regression problem, \(Y\) is quantitative (price, weight, BMI).
    • In the classification problem, \(Y\) is categorical. That is, \(Y\) takes values in a finite, unordered set (survived/died, whether a customer buys a product, digit 0-9, object in an image, cancer class of a tissue sample).
  • We have training data \((\mathbf{x}_1, y_1), \ldots, (\mathbf{x}_n, y_n)\). These are observations (also called samples, instances, cases). Training data is often represented by a predictor matrix \[ \mathbf{X} = \begin{pmatrix} x_{11} & \cdots & x_{1p} \\ \vdots & \ddots & \vdots \\ x_{n1} & \cdots & x_{np} \end{pmatrix} = \begin{pmatrix} \mathbf{x}_1^T \\ \vdots \\ \mathbf{x}_n^T \end{pmatrix} \tag{1}\] and a response vector \[ \mathbf{y} = \begin{pmatrix} y_1 \\ \vdots \\ y_n \end{pmatrix} \]

  • Based on the training data, our goal is to

    • Accurately predict the unseen outcomes of test cases from their predictors.
    • Understand which predictors affect the outcome, and how.
    • Assess the quality of our predictions and inferences.
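To make the notation concrete, here is a minimal NumPy sketch that stacks observations \((\mathbf{x}_i, y_i)\) into the predictor matrix \(\mathbf{X}\) of Equation 1 and the response vector \(\mathbf{y}\), fits a model, and predicts the outcome of an unseen test case. The toy numbers are made up (they follow \(y = 1 + 2x_1 + x_2\) exactly), and least-squares linear regression is just one illustrative choice of model.

```python
import numpy as np

# Toy training data: n = 5 observations (x_i, y_i) with p = 2 predictors (made-up values)
observations = [
    (np.array([1.0, 2.0]), 5.0),
    (np.array([2.0, 0.0]), 5.0),
    (np.array([3.0, 1.0]), 8.0),
    (np.array([0.0, 3.0]), 4.0),
    (np.array([4.0, 2.0]), 11.0),
]

# Stack each x_i^T as a row to form the n x p predictor matrix X
X = np.vstack([x for x, _ in observations])
# Collect the outcomes into the response vector y
y = np.array([y_i for _, y_i in observations])

# Fit a linear model with an intercept by least squares
X1 = np.column_stack([np.ones(len(y)), X])
beta, *_ = np.linalg.lstsq(X1, y, rcond=None)

# Predict the unseen outcome of a test case from its predictors
x_test = np.array([2.0, 2.0])
y_hat = beta[0] + x_test @ beta[1:]

print(X.shape, y.shape)        # (5, 2) (5,)
print(round(float(y_hat), 6))  # 7.0
```

Assessing prediction quality would then mean comparing such predictions with the true outcomes of held-out test cases.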

1.2.1 Example: salary

  • The Wage data set contains wages and other data for a group of 3,000 male workers in the Mid-Atlantic region from 2003 to 2009.

  • Our goal is to establish the relationship between salary and demographic variables in population survey data.

  • Since wage is a quantitative variable, it is a regression problem.

library(gtsummary)
library(ISLR2)
library(tidyverse)

# Convert to tibble
Wage <- as_tibble(Wage) %>% print(width = Inf)
# A tibble: 3,000 × 11
    year   age maritl           race     education       region            
   <int> <int> <fct>            <fct>    <fct>           <fct>             
 1  2006    18 1. Never Married 1. White 1. < HS Grad    2. Middle Atlantic
 2  2004    24 1. Never Married 1. White 4. College Grad 2. Middle Atlantic
 3  2003    45 2. Married       1. White 3. Some College 2. Middle Atlantic
 4  2003    43 2. Married       3. Asian 4. College Grad 2. Middle Atlantic
 5  2005    50 4. Divorced      1. White 2. HS Grad      2. Middle Atlantic
 6  2008    54 2. Married       1. White 4. College Grad 2. Middle Atlantic
 7  2009    44 2. Married       4. Other 3. Some College 2. Middle Atlantic
 8  2008    30 1. Never Married 3. Asian 3. Some College 2. Middle Atlantic
 9  2006    41 1. Never Married 2. Black 3. Some College 2. Middle Atlantic
10  2004    52 2. Married       1. White 2. HS Grad      2. Middle Atlantic
   jobclass       health         health_ins logwage  wage
   <fct>          <fct>          <fct>        <dbl> <dbl>
 1 1. Industrial  1. <=Good      2. No         4.32  75.0
 2 2. Information 2. >=Very Good 2. No         4.26  70.5
 3 1. Industrial  1. <=Good      1. Yes        4.88 131. 
 4 2. Information 2. >=Very Good 1. Yes        5.04 155. 
 5 2. Information 1. <=Good      1. Yes        4.32  75.0
 6 2. Information 2. >=Very Good 1. Yes        4.85 127. 
 7 1. Industrial  2. >=Very Good 1. Yes        5.13 170. 
 8 2. Information 1. <=Good      1. Yes        4.72 112. 
 9 2. Information 2. >=Very Good 1. Yes        4.78 119. 
10 2. Information 2. >=Very Good 1. Yes        4.86 129. 
# … with 2,990 more rows
# Summary statistics
Wage %>% tbl_summary()
Characteristic N = 3,000¹
year
    2003 513 (17%)
    2004 485 (16%)
    2005 447 (15%)
    2006 392 (13%)
    2007 386 (13%)
    2008 388 (13%)
    2009 389 (13%)
age 42 (34, 51)
maritl
    1. Never Married 648 (22%)
    2. Married 2,074 (69%)
    3. Widowed 19 (0.6%)
    4. Divorced 204 (6.8%)
    5. Separated 55 (1.8%)
race
    1. White 2,480 (83%)
    2. Black 293 (9.8%)
    3. Asian 190 (6.3%)
    4. Other 37 (1.2%)
education
    1. < HS Grad 268 (8.9%)
    2. HS Grad 971 (32%)
    3. Some College 650 (22%)
    4. College Grad 685 (23%)
    5. Advanced Degree 426 (14%)
region
    1. New England 0 (0%)
    2. Middle Atlantic 3,000 (100%)
    3. East North Central 0 (0%)
    4. West North Central 0 (0%)
    5. South Atlantic 0 (0%)
    6. East South Central 0 (0%)
    7. West South Central 0 (0%)
    8. Mountain 0 (0%)
    9. Pacific 0 (0%)
jobclass
    1. Industrial 1,544 (51%)
    2. Information 1,456 (49%)
health
    1. <=Good 858 (29%)
    2. >=Very Good 2,142 (71%)
health_ins
    1. Yes 2,083 (69%)
    2. No 917 (31%)
logwage 4.65 (4.45, 4.86)
wage 105 (85, 129)
¹ n (%); Median (IQR)
# Plot wage ~ age
Wage %>%
  ggplot(mapping = aes(x = age, y = wage)) + 
  geom_point() + 
  geom_smooth() +
  labs(title = "Wage changes nonlinearly with age",
       x = "Age",
       y = "Wage (k$)")

# Plot wage ~ year
Wage %>%
  ggplot(mapping = aes(x = year, y = wage)) + 
  geom_point() + 
  geom_smooth(method = "lm") +
  labs(title = "Average wage increases by $10k in 2003-2009",
       x = "Year",
       y = "Wage (k$)")

# Plot wage ~ education
Wage %>%
  ggplot(mapping = aes(x = education, y = wage)) + 
  geom_point() + 
  geom_boxplot() +
  labs(title = "Wage increases with education level",
       x = "Education",
       y = "Wage (k$)")

Summary statistics:

# Load the pandas library
import pandas as pd
# Load numpy for array manipulation
import numpy as np
# Load seaborn plotting library
import seaborn as sns
import matplotlib.pyplot as plt

# Set font size in plots
sns.set(font_scale = 2)
# Display all columns
pd.set_option('display.max_columns', None)

# Import Wage data
Wage = pd.read_csv(
  "../data/Wage.csv",
  dtype =  {
    'maritl':'category', 
    'race':'category',
    'education':'category',
    'region':'category',
    'jobclass':'category',
    'health':'category',
    'health_ins':'category'
    }
  )
Wage
      year  age            maritl      race        education  \
0     2006   18  1. Never Married  1. White     1. < HS Grad   
1     2004   24  1. Never Married  1. White  4. College Grad   
2     2003   45        2. Married  1. White  3. Some College   
3     2003   43        2. Married  3. Asian  4. College Grad   
4     2005   50       4. Divorced  1. White       2. HS Grad   
...    ...  ...               ...       ...              ...   
2995  2008   44        2. Married  1. White  3. Some College   
2996  2007   30        2. Married  1. White       2. HS Grad   
2997  2005   27        2. Married  2. Black     1. < HS Grad   
2998  2005   27  1. Never Married  1. White  3. Some College   
2999  2009   55      5. Separated  1. White       2. HS Grad   

                  region        jobclass          health health_ins   logwage  \
0     2. Middle Atlantic   1. Industrial       1. <=Good      2. No  4.318063   
1     2. Middle Atlantic  2. Information  2. >=Very Good      2. No  4.255273   
2     2. Middle Atlantic   1. Industrial       1. <=Good     1. Yes  4.875061   
3     2. Middle Atlantic  2. Information  2. >=Very Good     1. Yes  5.041393   
4     2. Middle Atlantic  2. Information       1. <=Good     1. Yes  4.318063   
...                  ...             ...             ...        ...       ...   
2995  2. Middle Atlantic   1. Industrial  2. >=Very Good     1. Yes  5.041393   
2996  2. Middle Atlantic   1. Industrial  2. >=Very Good      2. No  4.602060   
2997  2. Middle Atlantic   1. Industrial       1. <=Good      2. No  4.193125   
2998  2. Middle Atlantic   1. Industrial  2. >=Very Good     1. Yes  4.477121   
2999  2. Middle Atlantic   1. Industrial       1. <=Good     1. Yes  4.505150   

            wage  
0      75.043154  
1      70.476020  
2     130.982177  
3     154.685293  
4      75.043154  
...          ...  
2995  154.685293  
2996   99.689464  
2997   66.229408  
2998   87.981033  
2999   90.481913  

[3000 rows x 11 columns]
Wage.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 3000 entries, 0 to 2999
Data columns (total 11 columns):
 #   Column      Non-Null Count  Dtype   
---  ------      --------------  -----   
 0   year        3000 non-null   int64   
 1   age         3000 non-null   int64   
 2   maritl      3000 non-null   category
 3   race        3000 non-null   category
 4   education   3000 non-null   category
 5   region      3000 non-null   category
 6   jobclass    3000 non-null   category
 7   health      3000 non-null   category
 8   health_ins  3000 non-null   category
 9   logwage     3000 non-null   float64 
 10  wage        3000 non-null   float64 
dtypes: category(7), float64(2), int64(2)
memory usage: 115.2 KB
# summary statistics
Wage.describe(include = "all")
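`describe()` reports quartiles but not gtsummary's "n (%)" and "Median (IQR)" cells. Below is a rough pandas sketch of those cells; `summarize_column` is a hypothetical helper, and its quartiles may differ slightly from gtsummary's, which can use a different quantile definition. Unlike the R summary above, which tabulated `year` as categorical, this sketch treats every numeric column as continuous.

```python
import pandas as pd

def summarize_column(s: pd.Series) -> dict:
    """Rough pandas analogue of gtsummary's default cells for one column (hypothetical helper)."""
    if pd.api.types.is_numeric_dtype(s):
        # Continuous variables: "Median (Q1, Q3)"
        q1, med, q3 = s.quantile([0.25, 0.5, 0.75])
        return {"Median (IQR)": f"{med:g} ({q1:g}, {q3:g})"}
    # Categorical variables: "n (%)" per level; for category dtype this includes empty levels
    counts = s.value_counts(sort=False)
    pct = 100 * counts / counts.sum()
    return {level: f"{n} ({p:.1f}%)" for level, n, p in zip(counts.index, counts, pct)}

# Usage on the Wage data frame loaded above:
# summary = {col: summarize_column(Wage[col]) for col in Wage.columns}
```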
# Plot wage ~ age
sns.lmplot(
  data = Wage, 
  x = "age", 
  y = "wage", 
  lowess = True,
  scatter_kws = {'alpha' : 0.1},
  height = 8
  ).set(
  xlabel = 'Age', 
  ylabel = 'Wage (k$)'
  )

Figure 1: Wage changes nonlinearly with age.

# Plot wage ~ year
sns.lmplot(
  data = Wage, 
  x = "year", 
  y = "wage", 
  scatter_kws = {'alpha' : 0.1},
  height = 8
  ).set(
  xlabel = 'Year', 
  ylabel = 'Wage (k$)'
  )

Figure 2: Average wage increases by $10k in 2003-2009.
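The caption's "$10k" can be read as the change in the least-squares line between the first and the last year. A minimal sketch of that calculation, using a hypothetical helper `fitted_change` built on `np.polyfit`; we do not quote its output for `Wage` here.

```python
import numpy as np

def fitted_change(x, y):
    """Change in the least-squares line of y on x from min(x) to max(x) (hypothetical helper)."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # np.polyfit with deg=1 returns the coefficients [slope, intercept]
    slope, intercept = np.polyfit(x, y, deg=1)
    return slope * (x.max() - x.min())

# Usage on the Wage data frame loaded above:
# fitted_change(Wage["year"], Wage["wage"])
```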

# Plot wage ~ education
ax = sns.boxplot(
  data = Wage, 
  x = "education", 
  y = "wage"
  )
ax.set(
  xlabel = 'Education', 
  ylabel = 'Wage (k$)'
  )
ax.set_xticklabels(ax.get_xticklabels(), rotation = 15)

Figure 3: Wage increases with education level.

# Plot wage ~ race
ax = sns.boxplot(
  data = Wage, 
  x = "race", 
  y = "wage"
  )
ax.set(
  xlabel = 'Race', 
  ylabel = 'Wage (k$)'
  )
ax.set_xticklabels(ax.get_xticklabels(), rotation = 15)

Figure 4: Any income inequality?

using AlgebraOfGraphics, CairoMakie, CSV, DataFrames
ENV["DATAFRAMES_COLUMNS"] = 1000
CairoMakie.activate!(type = "png")

# Import Wage data
Wage = DataFrame(CSV.File("../data/Wage.csv"))

# Summary statistics
describe(Wage)

# Plot Wage ~ age
data(Wage) * mapping(:age, :wage) * (smooth() + visual(Scatter)) |> draw

1.2.2 Example: stock market

library(quantmod)

SP500 <- getSymbols(
  "^GSPC", 
  src = "yahoo", 
  auto.assign = FALSE, 
  from = "2022-01-01",
  to = "2022-12-31")

chartSeries(SP500, theme = chartTheme("white"),
            type = "line", log.scale = FALSE, TA = NULL)

  • The Smarket data set contains daily percentage returns for the S&P 500 stock index between 2001 and 2005.

  • Our goal is to predict whether the index will increase or decrease on a given day, using the past 5 days’ percentage changes in the index.

  • Since the outcome is binary (increase or decrease), it is a classification problem.

  • From the boxplots in Figure 5, it seems that the previous 5 days' percentage returns do not discriminate whether today's return is positive or negative.
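The Smarket data set already includes the lagged columns, but the sketch below shows how such predictors could be built from a raw series of daily returns with `pandas.Series.shift`. The helper name `make_lag_features` is hypothetical, and labeling a zero return as "Down" is our assumption, not necessarily Smarket's convention.

```python
import numpy as np
import pandas as pd

def make_lag_features(today: pd.Series, n_lags: int = 5) -> pd.DataFrame:
    """Build Lag1..Lag{n_lags} predictors and an Up/Down outcome from daily returns."""
    df = pd.DataFrame({"Today": today})
    for k in range(1, n_lags + 1):
        # LagK is the percentage return k trading days earlier
        df[f"Lag{k}"] = today.shift(k)
    # Binary outcome; a zero return is labeled "Down" (assumption)
    df["Direction"] = np.where(df["Today"] > 0, "Up", "Down")
    # The first n_lags days lack a full history, so drop them
    return df.dropna().reset_index(drop=True)
```

Each row then pairs one day's outcome (`Direction`) with the five preceding returns as predictors, which is the classification setup described above.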

# Data information
help(Smarket)

# Convert to tibble
Smarket <- as_tibble(Smarket) %>% print(width = Inf)
# A tibble: 1,250 × 9
    Year   Lag1   Lag2   Lag3   Lag4   Lag5 Volume  Today Direction
   <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl> <fct>    
 1  2001  0.381 -0.192 -2.62  -1.06   5.01    1.19  0.959 Up       
 2  2001  0.959  0.381 -0.192 -2.62  -1.06    1.30  1.03  Up       
 3  2001  1.03   0.959  0.381 -0.192 -2.62    1.41 -0.623 Down     
 4  2001 -0.623  1.03   0.959  0.381 -0.192   1.28  0.614 Up       
 5  2001  0.614 -0.623  1.03   0.959  0.381   1.21  0.213 Up       
 6  2001  0.213  0.614 -0.623  1.03   0.959   1.35  1.39  Up       
 7  2001  1.39   0.213  0.614 -0.623  1.03    1.44 -0.403 Down     
 8  2001 -0.403  1.39   0.213  0.614 -0.623   1.41  0.027 Up       
 9  2001  0.027 -0.403  1.39   0.213  0.614   1.16  1.30  Up       
10  2001  1.30   0.027 -0.403  1.39   0.213   1.23  0.287 Up       
# … with 1,240 more rows
# Summary statistics
summary(Smarket)
      Year           Lag1                Lag2                Lag3          
 Min.   :2001   Min.   :-4.922000   Min.   :-4.922000   Min.   :-4.922000  
 1st Qu.:2002   1st Qu.:-0.639500   1st Qu.:-0.639500   1st Qu.:-0.640000  
 Median :2003   Median : 0.039000   Median : 0.039000   Median : 0.038500  
 Mean   :2003   Mean   : 0.003834   Mean   : 0.003919   Mean   : 0.001716  
 3rd Qu.:2004   3rd Qu.: 0.596750   3rd Qu.: 0.596750   3rd Qu.: 0.596750  
 Max.   :2005   Max.   : 5.733000   Max.   : 5.733000   Max.   : 5.733000  
      Lag4                Lag5              Volume           Today          
 Min.   :-4.922000   Min.   :-4.92200   Min.   :0.3561   Min.   :-4.922000  
 1st Qu.:-0.640000   1st Qu.:-0.64000   1st Qu.:1.2574   1st Qu.:-0.639500  
 Median : 0.038500   Median : 0.03850   Median :1.4229   Median : 0.038500  
 Mean   : 0.001636   Mean   : 0.00561   Mean   :1.4783   Mean   : 0.003138  
 3rd Qu.: 0.596750   3rd Qu.: 0.59700   3rd Qu.:1.6417   3rd Qu.: 0.596750  
 Max.   : 5.733000   Max.   : 5.73300   Max.   :3.1525   Max.   : 5.733000  
 Direction 
 Down:602  
 Up  :648  
           
           
           
           
# Plot Direction ~ Lag1, Direction ~ Lag2, ...
Smarket %>%
  pivot_longer(cols = Lag1:Lag5, names_to = "Lag", values_to = "Perc") %>%
  ggplot() + 
  geom_boxplot(mapping = aes(x = Direction, y = Perc)) +
  labs(
    x = "Today's Direction", 
    y = "Percentage change in S&P",
    title = "Up and down of S&P doesn't depend on previous days' percentage changes."
    ) +
  facet_wrap(~ Lag)

# Import S&P500 data
Smarket = pd.read_csv("../data/Smarket.csv")
Smarket
      Year   Lag1   Lag2   Lag3   Lag4   Lag5   Volume  Today Direction
0     2001  0.381 -0.192 -2.624 -1.055  5.010  1.19130  0.959        Up
1     2001  0.959  0.381 -0.192 -2.624 -1.055  1.29650  1.032        Up
2     2001  1.032  0.959  0.381 -0.192 -2.624  1.41120 -0.623      Down
3     2001 -0.623  1.032  0.959  0.381 -0.192  1.27600  0.614        Up
4     2001  0.614 -0.623  1.032  0.959  0.381  1.20570  0.213        Up
...    ...    ...    ...    ...    ...    ...      ...    ...       ...
1245  2005  0.422  0.252 -0.024 -0.584 -0.285  1.88850  0.043        Up
1246  2005  0.043  0.422  0.252 -0.024 -0.584  1.28581 -0.955      Down
1247  2005 -0.955  0.043  0.422  0.252 -0.024  1.54047  0.130        Up
1248  2005  0.130 -0.955  0.043  0.422  0.252  1.42236 -0.298      Down
1249  2005 -0.298  0.130 -0.955  0.043  0.422  1.38254 -0.489      Down

[1250 rows x 9 columns]
Smarket.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1250 entries, 0 to 1249
Data columns (total 9 columns):
 #   Column     Non-Null Count  Dtype  
---  ------     --------------  -----  
 0   Year       1250 non-null   int64  
 1   Lag1       1250 non-null   float64
 2   Lag2       1250 non-null   float64
 3   Lag3       1250 non-null   float64
 4   Lag4       1250 non-null   float64
 5   Lag5       1250 non-null   float64
 6   Volume     1250 non-null   float64
 7   Today      1250 non-null   float64
 8   Direction  1250 non-null   object 
dtypes: float64(7), int64(1), object(1)
memory usage: 88.0+ KB
# summary statistics
Smarket.describe(include = "all")
# Pivot to long format for facet plotting
Smarket_long = pd.melt(
  Smarket, 
  id_vars = ['Year', 'Volume', 'Today', 'Direction'], 
  value_vars = ['Lag1', 'Lag2', 'Lag3', 'Lag4', 'Lag5'],
  var_name = 'Lag',
  value_name = 'Perc'
  )
Smarket_long  
      Year   Volume  Today Direction   Lag   Perc
0     2001  1.19130  0.959        Up  Lag1  0.381
1     2001  1.29650  1.032        Up  Lag1  0.959
2     2001  1.41120 -0.623      Down  Lag1  1.032
3     2001  1.27600  0.614        Up  Lag1 -0.623
4     2001  1.20570  0.213        Up  Lag1  0.614
...    ...      ...    ...       ...   ...    ...
6245  2005  1.88850  0.043        Up  Lag5 -0.285
6246  2005  1.28581 -0.955      Down  Lag5 -0.584
6247  2005  1.54047  0.130        Up  Lag5 -0.024
6248  2005  1.42236 -0.298      Down  Lag5  0.252
6249  2005  1.38254 -0.489      Down  Lag5  0.422

[6250 rows x 6 columns]
g = sns.FacetGrid(Smarket_long, col = "Lag", col_wrap = 3, height = 10)
g.map_dataframe(sns.boxplot, x = "Direction", y = "Perc")

Figure 5: LagX is the percentage return for the previous X days.

plt.clf()

1.2.3 Example: handwritten digit recognition

Figure 6: Examples of handwritten digits from the MNIST corpus (ISL Figure 10.3).

  • Input: 784 pixel values from \(28 \times 28\) grayscale images. Output: a digit 0, 1, …, 9, so this is a 10-class classification problem.

  • On the MNIST data set (60,000 training images, 10,000 test images), the following error rates have been reported:

    Method                                                 Error rate
    Tangent distance with 1-nearest neighbor classifier    1.1%
    Degree-9 polynomial SVM                                0.8%
    LeNet-5                                                0.8%
    Boosted LeNet-4                                        0.7%
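The input representation and the error rate can be sketched in NumPy: flatten each \(28 \times 28\) image into one row of a 784-column predictor matrix, classify with 1-nearest neighbor, and report the fraction of misclassified test images. This sketch uses plain Euclidean distance, not the tangent distance in the table, and the helper names are hypothetical.

```python
import numpy as np

def flatten_images(images: np.ndarray) -> np.ndarray:
    """Turn an (n, 28, 28) stack of grayscale images into an (n, 784) predictor matrix."""
    return images.reshape(len(images), -1)

def one_nn_predict(X_train, y_train, X_test):
    """1-nearest-neighbor classifier with plain Euclidean distance."""
    # Squared distances between every test row and every training row, shape (n_test, n_train)
    d2 = ((X_test[:, None, :] - X_train[None, :, :]) ** 2).sum(axis=2)
    # Each test image takes the label of its closest training image
    return y_train[d2.argmin(axis=1)]

def error_rate(y_true, y_pred):
    """Fraction of test cases whose predicted label is wrong."""
    return float(np.mean(y_true != y_pred))
```

The broadcast distance array is fine for toy inputs, but for the full 10,000 × 60,000 MNIST comparison it would need several terabytes of memory, so one would compute distances in batches.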

1.2.4 Example: more computer vision tasks

Some popular data sets from computer vision.

1.3 Unsupervised learning

  • No outcome variable, just predictors.

  • The objective is fuzzier: find groups of observations that behave similarly, find features that behave similarly, find linear combinations of features with the most variation, or build generative models (e.g., transformers).

  • It is difficult to know how well you are doing.

  • Can be useful in exploratory data analysis (EDA) or as a pre-processing step for supervised learning.
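As one concrete way to "find groups that behave similarly", here is a minimal sketch of k-means clustering (Lloyd's algorithm), a standard method that these notes do not name explicitly. The caller supplies the initial centers; practical implementations such as scikit-learn's `KMeans` add smarter initialization and multiple restarts.

```python
import numpy as np

def kmeans(X, init_centers, n_iter=100):
    """Lloyd's algorithm: alternate cluster assignment and center updates."""
    X = np.asarray(X, dtype=float)
    centers = np.array(init_centers, dtype=float)
    for _ in range(n_iter):
        # Assign each row of X to its nearest center (squared Euclidean distance)
        d2 = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        labels = d2.argmin(axis=1)
        # Move each center to the mean of its points; keep it in place if its cluster is empty
        new_centers = np.array([
            X[labels == k].mean(axis=0) if np.any(labels == k) else centers[k]
            for k in range(len(centers))
        ])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    return labels, centers
```

No outcome is used anywhere: the groups come only from distances between predictor vectors, which is also why it is hard to say how "good" a clustering is.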

1.3.1 Example: gene expression

  • The NCI60 data set consists of 6,830 gene expression measurements for each of 64 cancer cell lines.
# NCI60 data and cancer labels
str(NCI60)
List of 2
 $ data: num [1:64, 1:6830] 0.3 0.68 0.94 0.28 0.485 ...
  ..- attr(*, "dimnames")=List of 2
  .. ..$ : chr [1:64] "V1" "V2" "V3" "V4" ...
  .. ..$ : chr [1:6830] "1" "2" "3" "4" ...
 $ labs: chr [1:64] "CNS" "CNS" "CNS" "RENAL" ...
# Cancer type of each cell line
table(NCI60$labs)

     BREAST         CNS       COLON K562A-repro K562B-repro    LEUKEMIA 
          7           5           7           1           1           6 
MCF7A-repro MCF7D-repro    MELANOMA       NSCLC     OVARIAN    PROSTATE 
          1           1           8           9           6           2 
      RENAL     UNKNOWN 
          9           1 
# Apply PCA using the prcomp function
# Center and scale each gene to unit variance, since
# PCA is sensitive to the scale of the variables
prcomp(NCI60$data, scale. = TRUE, center = TRUE, retx = TRUE)$x %>%
  as_tibble() %>%
  add_column(cancer_type = NCI60$labs) %>%
  # Plot PC2 vs PC1
  ggplot() + 
  geom_point(mapping = aes(x = PC1, y = PC2, color = cancer_type)) +
  labs(title = "Gene expression profiles cluster according to cancer types")

# Import NCI60 data
nci60_data = pd.read_csv('../data/NCI60_data.csv')
nci60_labs = pd.read_csv('../data/NCI60_labs.csv')
nci60_data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 64 entries, 0 to 63
Columns: 6830 entries, 1 to 6830
dtypes: float64(6830)
memory usage: 3.3 MB
from sklearn.decomposition import PCA
from sklearn.preprocessing import scale

# Obtain the first 2 principal components
nci60_tr = scale(nci60_data, with_mean = True, with_std = True)
nci60_pc = pd.DataFrame(
  PCA(n_components = 2).fit(nci60_tr).transform(nci60_tr),
  columns = ['PC1', 'PC2']
  )
nci60_pc['PC2'] *= -1  # for easier comparison with R
nci60_pc['cancer_type'] = nci60_labs
nci60_pc
          PC1        PC2 cancer_type
0  -19.838047   3.556383         CNS
1  -23.089213   6.443042         CNS
2  -27.456106   2.465472         CNS
3  -42.816781  -9.767787       RENAL
4  -55.418540  -5.201385      BREAST
..        ...        ...         ...
59 -17.996220 -47.240956    MELANOMA
60  -4.415503 -42.309758    MELANOMA
61 -22.966978 -36.101335    MELANOMA
62 -19.176017 -50.399104    MELANOMA
63 -13.232869 -35.124010    MELANOMA

[64 rows x 3 columns]
# Plot PC2 vs PC1
sns.relplot(
  kind = 'scatter', 
  data = nci60_pc, 
  x = 'PC1',
  y = 'PC2',
  hue = 'cancer_type',
  height = 10
  )
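Under the hood, PCA of standardized data is a singular value decomposition. The NumPy sketch below (hypothetical helper `pca_scores`) centers and scales each column, then returns the leading left singular vectors times the singular values as the principal component scores. Each component is determined only up to sign, which is why the Python code above flips `PC2` to match R.

```python
import numpy as np

def pca_scores(X, n_components=2):
    """Principal component scores of the column-standardized data via SVD."""
    X = np.asarray(X, dtype=float)
    # Center each column and divide by its sample standard deviation, like prcomp(scale. = TRUE)
    Z = (X - X.mean(axis=0)) / X.std(axis=0, ddof=1)
    # Z = U diag(s) V^T; the columns of V are the loading vectors
    U, s, Vt = np.linalg.svd(Z, full_matrices=False)
    # Scores = Z V = U diag(s); each column is defined only up to sign
    return U[:, :n_components] * s[:n_components]
```

This sketch divides by the sample standard deviation (`ddof=1`), as R does, whereas `sklearn.preprocessing.scale` divides by the population standard deviation (`ddof=0`), so its scores are larger by a factor of \(\sqrt{n/(n-1)}\). A gene with zero variance would cause a division by zero here.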

1.3.2 Example: mapping people from their genomes

  • The genetic makeup of \(n\) individuals can be represented by a matrix as in Equation 1, where \(x_{ij} \in \{0, 1, 2\}\) is the \(j\)-th genetic marker of the \(i\)-th individual.

    Is it possible to visualize the geographic relationships among these individuals?

  • The following picture is from the article "Genes mirror geography within Europe" by Novembre et al. (2008), published in Nature.

1.3.3 Ancestry estimation

Figure 7: Unsupervised discovery of ancestry-informative markers and genetic admixture proportions. Paper.

1.4 No easy answer

In modern applications, the line between supervised and unsupervised learning is blurred.

1.4.1 Example: the Netflix prize

Figure 8: The Netflix challenge.

  • The competition started in October 2006. The training data are ratings from 480,189 Netflix customers \(\times\) 17,770 movies, each rating between 1 and 5.

  • The training data are very sparse: about 98% of the customer-movie entries are missing.

  • The objective is to predict the ratings for a set of 1 million customer-movie pairs that are missing from the training data.

  • Netflix's in-house algorithm achieved a root mean squared error (RMSE) of 0.953. The first team to achieve a 10% improvement would win one million dollars.

  • Is this a supervised or unsupervised problem?

    • We can treat rating as outcome and user-movie combinations as predictors. Then it is a supervised learning problem.

    • Or we can treat it as a matrix factorization or low-rank approximation problem. Then it is more of an unsupervised learning problem, similar to PCA.
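The matrix factorization view can be sketched with alternating least squares (ALS): approximate the ratings matrix by \(UV^T\) and alternately solve ridge regressions for the customer factors \(U\) and the movie factors \(V\), using only the observed entries. This is one illustrative approach with our own choices of rank, penalty `lam`, and iteration count, not the method of any particular prize team, and its dense Python loops would not scale to the real 480,189 \(\times\) 17,770 matrix. The last line computes the prize target: a 10% improvement on an RMSE of 0.953 means reaching about 0.858.

```python
import numpy as np

def rmse(pred, truth):
    """Root mean squared error between two arrays of ratings."""
    return float(np.sqrt(np.mean((pred - truth) ** 2)))

def als_complete(R, mask, rank=2, lam=0.1, n_iter=50, seed=0):
    """Fit R ~ U V^T on the observed entries (mask == True) by alternating ridge regressions."""
    rng = np.random.default_rng(seed)
    n_users, n_items = R.shape
    U = rng.normal(scale=0.1, size=(n_users, rank))
    V = rng.normal(scale=0.1, size=(n_items, rank))
    # lam > 0 keeps every system solvable, even for a customer or movie with no ratings
    ridge = lam * np.eye(rank)
    for _ in range(n_iter):
        # Update each customer's factor using only the movies that customer rated
        for i in range(n_users):
            Vi = V[mask[i]]
            U[i] = np.linalg.solve(Vi.T @ Vi + ridge, Vi.T @ R[i, mask[i]])
        # Update each movie's factor using only the customers who rated it
        for j in range(n_items):
            Uj = U[mask[:, j]]
            V[j] = np.linalg.solve(Uj.T @ Uj + ridge, Uj.T @ R[mask[:, j], j])
    # Filled-in matrix: predictions for every customer-movie pair, observed or not
    return U @ V.T

# Prize target: 10% below Netflix's RMSE of 0.953
target_rmse = 0.9 * 0.953
```

Missing ratings are read off the returned matrix at the unobserved positions, which is exactly the prediction task of the prize.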

1.4.2 Example: large language models (LLMs)

Modern large language models, such as ChatGPT, combine supervised learning and reinforcement learning.

1.5 Statistical learning vs machine learning

  • Machine learning arose as a subfield of Artificial Intelligence.

  • Statistical learning arose as a subfield of Statistics.

  • There is much overlap. Both fields focus on supervised and unsupervised problems.

    • Machine learning has a greater emphasis on large scale applications and prediction accuracy.

    • Statistical learning emphasizes models and their interpretability, as well as precision and uncertainty.

  • But the distinction has become more and more blurred, and there is a great deal of “cross-fertilization”.

  • Machine learning has the upper hand in marketing!

1.6 A Brief History of Statistical Learning

Image source: https://people.idsia.ch/~juergen/deep-learning-history.html

  • 1676, chain rule by Leibniz.

  • 1805, least squares / linear regression / shallow learning by Gauss.

  • 1936, linear discriminant analysis for classification by Fisher.

  • 1940s, logistic regression.

  • Early 1970s, generalized linear models (GLMs).

  • Mid-1980s, classification and regression trees.

  • 1980s, generalized additive models (GAMs).

  • 1980s, neural networks gained popularity.

  • 1990s, support vector machines.

  • 2010s, deep learning.

1.7 Commonly used learning methods

  • Regression problems: linear regression (possibly with regularization and nonlinear features), linear mixed models, generalized additive models, KNN regression, regression trees, random forests, boosting, BART, neural networks.

  • Classification problems: logistic regression (possibly with regularization and nonlinear features), discriminant analysis (LDA, QDA, NB), KNN classifier, classification trees, random forests, SVMs, boosting, neural networks.

  • Unsupervised learning: clustering, PCA, CCA, ICA, neural networks (autoencoders).