13  Introduction to Vector Models and Word Embeddings

This is a quick walk-through tutorial on using vector models. This lab only uses word2vec, which is old but has the advantage of being simple and easy to train. It’s possible to get embeddings from modern large language models, but this requires installing a lot more software (and is much easier in Python).

library(tidyverse)
── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
✔ dplyr     1.1.4     ✔ readr     2.1.5
✔ forcats   1.0.0     ✔ stringr   1.5.2
✔ ggplot2   4.0.0     ✔ tibble    3.3.0
✔ lubridate 1.9.4     ✔ tidyr     1.3.1
✔ purrr     1.1.0     
── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
✖ dplyr::filter() masks stats::filter()
✖ dplyr::lag()    masks stats::lag()
ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(gt)

library(Rtsne)
library(word2vec)
library(cmu.textstat)
library(quanteda.extras)

13.1 Basics

The word2vec package implements the original word2vec training algorithm. Let’s run it on our corpus of Shakespeare plays, extracting specifically the dialogue.

dialogue <- from_play(shakespeare_corpus, "dialogue")

model <- word2vec(txt_clean_word2vec(dialogue$text))

Now we can get embeddings of words using predict():

predict(model, c("dog", "cat"),
        type = "embedding")
          [,1]        [,2]      [,3]      [,4]     [,5]      [,6]       [,7]
dog -0.5098643  0.53995699 -1.993794 0.6061975 1.410652 -2.649742 -1.3353217
cat  1.2188730 -0.01844489 -1.276233 1.3571930 1.724212 -1.161104 -0.8368803
          [,8]        [,9]     [,10]     [,11]      [,12]     [,13]     [,14]
dog -0.6944243  0.09920318 0.1054260 2.3284216  0.5488601 0.6050486 0.8371658
cat -0.5321332 -0.49536505 0.4908282 0.2522839 -0.3601401 0.1683761 0.8911011
        [,15]     [,16]      [,17]      [,18]      [,19]      [,20]    [,21]
dog 0.4661073 2.2059679 -0.1652774 -0.4280483 -0.2417707 0.08867547 1.772947
cat 1.3904195 0.6284767 -0.6533063  1.7856417 -1.2417134 0.40575781 1.753314
        [,22]     [,23]      [,24]      [,25]     [,26]      [,27]      [,28]
dog 0.6164772 0.3357099  0.4201811 -0.4059037 0.1537911  0.7708796  0.3270563
cat 0.3174724 0.7973865 -0.5650601 -1.8752923 0.4176239 -0.6458973 -1.7958031
        [,29]     [,30]      [,31]      [,32]    [,33]      [,34]     [,35]
dog 0.1674979 1.5566996 -0.2702732 -1.7267250 1.400280 -0.1310500 0.4482224
cat 0.5578048 0.5685212  0.3117453 -0.4207658 2.249357  0.5000364 0.5453086
         [,36]     [,37]    [,38]        [,39]      [,40]      [,41]     [,42]
dog -0.2651205 0.8030650 1.263019 -0.007785887 -0.3386081 -0.7048467 0.5294110
cat -1.0539590 0.5069932 2.146533 -0.116481327  0.3608209  0.4059952 0.4047288
         [,43]      [,44]     [,45]       [,46]       [,47]      [,48]
dog -0.6498324 -0.5800671 0.9896889  0.06131002  0.67012358 -1.7186267
cat -0.7977300  0.6647797 0.4395330 -0.37249580 -0.09999175 -0.2669219
        [,49]     [,50]
dog 0.4203046 0.8646228
cat 1.7030052 1.3296981

We can see it defaulted to 50-dimensional embeddings. We could use more dimensions, as large language models typically do, but using those extra dimensions effectively requires an enormous training corpus.
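
If you want a different dimensionality, word2vec() takes a dim argument, along with other training options such as the number of iterations; see ?word2vec for the full list. For example, a sketch (not run here; training takes longer with more dimensions):

# Not run: retrain with 100-dimensional embeddings and more training passes
model_100 <- word2vec(txt_clean_word2vec(dialogue$text),
                      dim = 100, iter = 10)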

We can also get the entire embedding matrix. Using it, we can measure the cosine similarity between words, for instance to find the 10 most similar words to “cat” and “dog”:

embeds <- as.matrix(model)

word2vec_similarity(embeds[c("cat", "dog"), ], embeds, top_n = 10, 
                    type = "cosine")
   term1     term2 similarity rank
1    dog       dog  1.0000000    1
2    dog     devil  0.7710401    2
3    dog      fool  0.7596142    3
4    dog     rogue  0.7479254    4
5    dog     slave  0.7454997    5
6    dog    rascal  0.7430351    6
7    dog    coward  0.7369859    7
8    dog     knave  0.7257963    8
9    dog       fox  0.7231458    9
10   dog      lion  0.7170312   10
11   cat       cat  1.0000000    1
12   cat loathsome  0.7944970    2
13   cat     adder  0.7931900    3
14   cat   crooked  0.7910794    4
15   cat     snake  0.7764300    5
16   cat      hoop  0.7742539    6
17   cat     agony  0.7705589    7
18   cat     river  0.7700824    8
19   cat    accent  0.7666846    9
20   cat     patch  0.7622123   10
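
Under the hood, cosine similarity is just the dot product of two vectors divided by the product of their norms. We can check one pair by hand using the embedding matrix we already extracted:

# Cosine similarity by hand: dot product divided by the product of the norms
dog_vec <- embeds["dog", ]
cat_vec <- embeds["cat", ]
sum(dog_vec * cat_vec) / (sqrt(sum(dog_vec^2)) * sqrt(sum(cat_vec^2)))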

For a more substantive comparison, let’s look at the most similar words to “queen” and “prince” to see how the embeddings capture semantic meaning:

word2vec_similarity(embeds[c("queen", "prince"), ], embeds, top_n = 10, 
                    type = "cosine")
    term1      term2 similarity rank
1   queen      queen  1.0000000    1
2   queen  katharine  0.7844194    2
3   queen     sister  0.7806268    3
4   queen   daughter  0.7560874    4
5   queen       lady  0.7448757    5
6   queen   princess  0.7186794    6
7   queen      widow  0.7118343    7
8   queen   margaret  0.7017737    8
9   queen    kinsman  0.6886260    9
10  queen      nurse  0.6884649   10
11 prince     prince  1.0000000    1
12 prince  gentleman  0.7810150    2
13 prince     knight  0.7567070    3
14 prince       duke  0.7507963    4
15 prince        son  0.7489808    5
16 prince gloucester  0.7360526    6
17 prince    warwick  0.7359573    7
18 prince       king  0.7339212    8
19 prince      gaunt  0.7320203    9
20 prince    proteus  0.7317890   10
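
Word embeddings are also famous for analogy arithmetic (king − man + woman ≈ queen). The word2vec package can look up the nearest words to an arbitrary vector, so we can try this ourselves. This is only a sketch, following the usage shown in the package documentation: it assumes predict() accepts a numeric vector as newdata and that all three words survived the model's minimum-count threshold. On a corpus this small, don't expect textbook results.

# Vector arithmetic: king - man + woman, then find the nearest words
wv <- predict(model, c("king", "man", "woman"), type = "embedding")
analogy <- wv["king", ] - wv["man", ] + wv["woman", ]
predict(model, newdata = analogy, type = "nearest", top_n = 5)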

Shakespeare is fun, but to explore the embeddings of other texts, we should probably create word embeddings based on more modern English. Let’s use the Brown corpus. Training the embeddings is easy if you have the corpus.

Don’t run this code chunk! Instead, skip to the next chunk where we read in the saved embeddings, which you can download from Canvas and save in the same directory as this lab.

library(quanteda.extras)

brown <- readtext_lite(list.files("../data/brown_corpus/", pattern = ".*.txt", full.names = TRUE))
brown_model <- word2vec(txt_clean_word2vec(brown$text))
write.word2vec(brown_model, "brown.bin")

Here we load the saved pre-made embeddings:

brown_model <- read.word2vec("brown.bin")
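
As a quick sanity check that the saved model loaded correctly, we can ask for the nearest neighbors of a couple of words (assuming, reasonably, that these words appear in the Brown corpus vocabulary):

# Sanity check: nearest neighbors from the Brown-trained model
predict(brown_model, c("government", "science"), type = "nearest", top_n = 5)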

13.2 Example from the MICUSP data

Let’s try embedding the essays in MICUSP. The doc2vec() function calculates the average embedding for all words in a document, so we can get a single embedding for the whole document. First we preprocess the text:

df <- micusp_mini |>
  mutate(text = preprocess_text(text, remove_numbers = TRUE))

Now let’s get the embeddings of each essay:

doc_embeddings <- doc2vec(brown_model, df)
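
To get an intuition for what doc2vec() is doing, we could average the word vectors ourselves for a single essay. This is only a rough sketch: it assumes the preprocessed text is space-separated, that out-of-vocabulary words come back as NA rows, and it ignores the package's own scaling, so the numbers won't match doc_embeddings exactly.

# Words in the first essay, looked up in the Brown model
words <- unique(strsplit(df$text[1], " ")[[1]])
word_vecs <- predict(brown_model, words, type = "embedding")

# Average over the in-vocabulary words (out-of-vocabulary rows are NA)
manual_vec <- colMeans(word_vecs, na.rm = TRUE)
head(manual_vec)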

We can also use word2vec_similarity() to find the closest documents to a chosen document, using cosine similarity. Let's wrap this in a small helper function that returns the documents sorted from most to least similar:

closest_docs <- function(doc_id, embeddings) {
  # Similarity of every document to the chosen document
  sims <- word2vec_similarity(embeddings, embeddings[doc_id, , drop = FALSE])
  
  # Return a data frame sorted from most to least similar
  data.frame(
    doc_id = rownames(sims),
    similarity = unname(sims[, 1])
  ) |> arrange(desc(similarity))
}

For a biology essay:

closest_docs("BIO.G0.02.1", doc_embeddings) |> gt()
doc_id similarity
BIO.G0.02.1 1.0000000
PSY.G1.01.1 0.9951459
PSY.G0.09.1 0.9948925
LIN.G0.06.1 0.9948847
LIN.G0.08.1 0.9944884
NUR.G3.05.1 0.9938122
HIS.G1.04.1 0.9937205
POL.G0.34.1 0.9928873
LIN.G1.06.1 0.9927297
IOE.G1.05.1 0.9922452
EDU.G1.01.1 0.9920583
SOC.G0.02.1 0.9919623
ECO.G0.08.1 0.9916758
POL.G0.06.1 0.9914741
LIN.G0.13.1 0.9914252
IOE.G1.03.1 0.9913340
HIS.G1.07.1 0.9912008
NUR.G0.15.1 0.9911703
BIO.G3.02.1 0.9911289
NUR.G3.06.1 0.9909121
POL.G0.29.1 0.9908648
SOC.G1.08.1 0.9907378
HIS.G1.01.1 0.9906191
ECO.G2.08.1 0.9905454
POL.G0.32.1 0.9904380
NUR.G3.03.1 0.9902657
POL.G0.18.1 0.9900011
POL.G1.02.1 0.9898275
SOC.G0.07.1 0.9897349
BIO.G0.29.1 0.9895331
HIS.G2.04.1 0.9895292
HIS.G3.02.1 0.9894605
LIN.G3.02.1 0.9894507
CLS.G0.06.1 0.9893541
EDU.G1.07.1 0.9893107
CLS.G1.01.1 0.9892079
SOC.G0.01.1 0.9891555
LIN.G0.02.1 0.9891551
POL.G0.14.1 0.9890789
LIN.G1.01.1 0.9890224
NUR.G0.11.1 0.9888998
PSY.G0.21.1 0.9888743
PHI.G0.14.1 0.9888590
PHY.G2.04.1 0.9887810
EDU.G3.01.1 0.9887530
BIO.G2.02.1 0.9887437
NUR.G0.01.1 0.9887309
BIO.G0.06.1 0.9886712
NRE.G0.06.1 0.9886337
HIS.G1.02.1 0.9884310
NUR.G0.13.1 0.9883829
HIS.G1.06.1 0.9882821
BIO.G0.03.1 0.9882575
PSY.G1.12.1 0.9881184
PHY.G2.03.1 0.9879849
BIO.G0.25.1 0.9878469
IOE.G1.06.1 0.9876692
CLS.G2.01.1 0.9875180
ECO.G1.03.1 0.9873205
CEE.G0.03.1 0.9872937
IOE.G1.02.1 0.9871181
CLS.G1.04.1 0.9870877
CLS.G1.02.1 0.9870618
SOC.G3.08.1 0.9868915
HIS.G0.02.1 0.9867959
CLS.G2.02.1 0.9866598
ECO.G2.07.1 0.9866329
EDU.G3.04.1 0.9866258
PHY.G2.05.1 0.9864709
PSY.G2.09.1 0.9864278
PHY.G2.01.1 0.9863083
SOC.G2.03.1 0.9862617
PSY.G1.13.1 0.9862500
NUR.G1.04.1 0.9862418
PSY.G0.38.1 0.9862093
ECO.G0.02.1 0.9861420
CEE.G1.09.1 0.9860989
PHY.G2.06.1 0.9860406
NRE.G3.02.1 0.9856553
HIS.G0.03.1 0.9854763
CEE.G2.01.1 0.9852465
POL.G0.11.1 0.9852131
PHY.G3.03.1 0.9851176
ECO.G2.05.1 0.9850903
ENG.G0.24.1 0.9850879
POL.G0.47.1 0.9850758
ENG.G2.01.1 0.9850103
PSY.G2.10.1 0.9849998
HIS.G1.05.1 0.9849282
NRE.G2.05.1 0.9846413
PHI.G0.03.1 0.9846278
EDU.G1.12.1 0.9845703
IOE.G0.06.1 0.9845488
NRE.G1.26.1 0.9845165
POL.G0.36.1 0.9844026
MEC.G0.07.1 0.9843858
PHY.G3.02.1 0.9843844
ECO.G2.02.1 0.9843700
NRE.G2.07.1 0.9843318
PHI.G0.13.1 0.9842364
ENG.G2.02.1 0.9842017
SOC.G1.09.1 0.9841123
PSY.G3.03.1 0.9833004
MEC.G0.08.1 0.9829831
SOC.G3.07.1 0.9829470
CEE.G0.04.1 0.9829450
PHI.G0.16.1 0.9828866
BIO.G2.03.1 0.9828725
ECO.G2.03.1 0.9825524
PHI.G2.02.1 0.9825170
IOE.G0.07.1 0.9824385
MEC.G3.02.1 0.9823809
PHI.G0.05.1 0.9822476
PHY.G2.07.1 0.9819858
IOE.G0.08.1 0.9816702
BIO.G0.21.1 0.9815446
PHI.G2.01.1 0.9813735
ENG.G0.16.1 0.9813024
MEC.G2.04.1 0.9812961
CEE.G3.02.1 0.9810268
LIN.G0.12.1 0.9804746
CLS.G0.04.1 0.9803861
LIN.G0.10.1 0.9801645
CEE.G3.01.1 0.9799189
NUR.G1.06.1 0.9797544
ECO.G2.04.1 0.9797010
IOE.G1.01.1 0.9795800
MEC.G0.06.1 0.9795071
MEC.G0.12.1 0.9793496
EDU.G0.14.1 0.9793218
SOC.G1.01.1 0.9792962
PHY.G1.02.1 0.9787151
CLS.G0.03.1 0.9785898
PHI.G0.11.1 0.9781966
IOE.G0.09.1 0.9777886
PHY.G0.03.1 0.9777664
CEE.G1.02.1 0.9777442
NRE.G1.18.1 0.9776972
EDU.G0.05.1 0.9775301
LIN.G0.11.1 0.9773118
MEC.G0.05.1 0.9772308
PSY.G0.35.1 0.9767839
CEE.G1.03.1 0.9762433
ENG.G0.38.1 0.9760813
NRE.G0.03.1 0.9759861
CEE.G3.03.1 0.9758444
CEE.G1.04.1 0.9755810
PHI.G0.10.1 0.9746650
NRE.G2.08.1 0.9746343
BIO.G0.12.1 0.9739422
ENG.G0.58.1 0.9731019
IOE.G3.01.1 0.9727766
MEC.G0.04.1 0.9716000
MEC.G1.02.1 0.9712169
NRE.G0.08.1 0.9708104
ENG.G0.50.1 0.9706776
ECO.G0.07.1 0.9702517
EDU.G0.02.1 0.9694454
NRE.G0.09.1 0.9672718
ENG.G0.21.1 0.9644772
CLS.G0.02.1 0.9644207
MEC.G1.05.1 0.9636580
CLS.G0.05.1 0.9635091
EDU.G0.13.1 0.9628293
SOC.G0.13.1 0.9588756
ENG.G2.04.1 0.9576720
EDU.G0.15.1 0.9572651
NUR.G2.01.1 0.9570471
ENG.G0.55.1 0.9486978
PHI.G0.08.1 0.9364008

For an English essay:

closest_docs("ENG.G2.04.1", doc_embeddings) |> gt()
doc_id similarity
ENG.G2.04.1 1.0000000
ENG.G0.21.1 0.9938062
ENG.G0.50.1 0.9898749
ENG.G0.38.1 0.9889000
ENG.G0.55.1 0.9887549
ENG.G0.58.1 0.9875708
CLS.G0.04.1 0.9864097
CLS.G0.03.1 0.9861072
PHI.G0.10.1 0.9849536
CLS.G0.05.1 0.9837379
CLS.G2.02.1 0.9830468
ENG.G2.01.1 0.9825920
PSY.G0.35.1 0.9821215
ENG.G0.16.1 0.9797931
SOC.G0.13.1 0.9791657
CLS.G2.01.1 0.9791412
ENG.G0.24.1 0.9791363
CLS.G0.02.1 0.9786253
PHI.G0.03.1 0.9780630
HIS.G1.05.1 0.9767803
CLS.G1.01.1 0.9759044
LIN.G0.11.1 0.9740923
CLS.G1.02.1 0.9736200
HIS.G1.02.1 0.9733014
CLS.G0.06.1 0.9727637
PHI.G2.02.1 0.9727364
EDU.G0.13.1 0.9727179
PHI.G2.01.1 0.9721851
HIS.G1.06.1 0.9713326
HIS.G1.04.1 0.9713113
PHI.G0.13.1 0.9712344
LIN.G0.02.1 0.9703281
POL.G0.47.1 0.9703115
CLS.G1.04.1 0.9698493
IOE.G1.03.1 0.9690642
LIN.G0.13.1 0.9689053
PHI.G0.05.1 0.9680860
PHI.G0.11.1 0.9679450
POL.G0.32.1 0.9670790
BIO.G0.29.1 0.9659656
EDU.G0.15.1 0.9657514
ENG.G2.02.1 0.9653313
PHI.G0.14.1 0.9649760
PSY.G0.09.1 0.9647026
PSY.G2.09.1 0.9635009
EDU.G0.05.1 0.9623973
POL.G0.36.1 0.9621209
HIS.G2.04.1 0.9620071
SOC.G0.02.1 0.9613501
HIS.G1.07.1 0.9610684
NUR.G3.05.1 0.9609887
PHI.G0.08.1 0.9604159
NUR.G1.04.1 0.9601002
NUR.G2.01.1 0.9598143
NUR.G0.11.1 0.9597986
PSY.G3.03.1 0.9588860
LIN.G0.10.1 0.9584930
POL.G0.11.1 0.9580384
NUR.G0.15.1 0.9580304
EDU.G0.02.1 0.9580185
LIN.G0.08.1 0.9579684
BIO.G0.02.1 0.9576720
HIS.G1.01.1 0.9560261
HIS.G3.02.1 0.9555855
EDU.G0.14.1 0.9554924
SOC.G0.01.1 0.9551780
POL.G0.29.1 0.9551237
LIN.G1.06.1 0.9549761
NUR.G0.01.1 0.9537730
HIS.G0.03.1 0.9532848
HIS.G0.02.1 0.9530458
LIN.G0.06.1 0.9530351
BIO.G0.25.1 0.9529937
BIO.G2.02.1 0.9528757
PHI.G0.16.1 0.9523912
IOE.G1.02.1 0.9521001
LIN.G1.01.1 0.9520022
PSY.G2.10.1 0.9515881
POL.G0.06.1 0.9515375
SOC.G3.08.1 0.9513748
SOC.G1.09.1 0.9511586
ECO.G2.08.1 0.9506328
PSY.G1.01.1 0.9505923
IOE.G1.05.1 0.9494972
NRE.G0.06.1 0.9494472
PSY.G1.12.1 0.9493038
POL.G1.02.1 0.9489140
EDU.G3.01.1 0.9487330
LIN.G3.02.1 0.9483996
SOC.G2.03.1 0.9479657
SOC.G1.08.1 0.9475024
EDU.G1.07.1 0.9465865
NUR.G3.06.1 0.9464080
POL.G0.34.1 0.9462346
EDU.G1.12.1 0.9459404
SOC.G1.01.1 0.9458061
EDU.G3.04.1 0.9456685
PSY.G0.21.1 0.9456172
MEC.G0.07.1 0.9455271
EDU.G1.01.1 0.9452302
ECO.G2.07.1 0.9448987
MEC.G0.08.1 0.9446099
PSY.G0.38.1 0.9445340
IOE.G1.01.1 0.9443292
PHY.G2.03.1 0.9441623
NRE.G3.02.1 0.9439761
IOE.G1.06.1 0.9439362
POL.G0.14.1 0.9438557
ECO.G0.02.1 0.9435879
PHY.G2.06.1 0.9435840
PHY.G2.05.1 0.9434910
BIO.G0.03.1 0.9427155
MEC.G0.06.1 0.9418119
PHY.G2.01.1 0.9416451
BIO.G0.06.1 0.9412876
POL.G0.18.1 0.9411500
PSY.G1.13.1 0.9400539
NUR.G0.13.1 0.9391824
ECO.G0.08.1 0.9388519
PHY.G2.04.1 0.9385574
NRE.G2.05.1 0.9384653
NUR.G3.03.1 0.9376985
SOC.G0.07.1 0.9365294
ECO.G2.02.1 0.9364563
CEE.G0.04.1 0.9356244
PHY.G3.03.1 0.9355105
IOE.G0.06.1 0.9353402
NRE.G1.26.1 0.9345932
MEC.G2.04.1 0.9343930
CEE.G1.02.1 0.9318789
BIO.G3.02.1 0.9313039
CEE.G2.01.1 0.9311977
CEE.G0.03.1 0.9305848
CEE.G3.02.1 0.9303575
MEC.G3.02.1 0.9303525
SOC.G3.07.1 0.9300634
ECO.G2.05.1 0.9299913
LIN.G0.12.1 0.9295496
NRE.G2.08.1 0.9294500
CEE.G1.09.1 0.9290903
ECO.G1.03.1 0.9290700
NRE.G0.03.1 0.9289563
NRE.G1.18.1 0.9285390
NUR.G1.06.1 0.9284931
PHY.G2.07.1 0.9279413
NRE.G2.07.1 0.9250516
IOE.G0.07.1 0.9243273
IOE.G0.08.1 0.9242550
PHY.G3.02.1 0.9222510
CEE.G3.01.1 0.9213190
MEC.G0.12.1 0.9209858
IOE.G0.09.1 0.9209774
ECO.G2.03.1 0.9205263
MEC.G1.02.1 0.9203339
MEC.G0.05.1 0.9174214
PHY.G0.03.1 0.9167812
PHY.G1.02.1 0.9161647
BIO.G0.21.1 0.9155640
CEE.G1.03.1 0.9151729
ECO.G2.04.1 0.9146912
MEC.G0.04.1 0.9142510
BIO.G2.03.1 0.9133021
NRE.G0.09.1 0.9123970
NRE.G0.08.1 0.9104201
BIO.G0.12.1 0.9090798
CEE.G1.04.1 0.9065789
ECO.G0.07.1 0.9064843
CEE.G3.03.1 0.9055792
MEC.G1.05.1 0.8934154
IOE.G3.01.1 0.8928683

13.3 Plotting

We can reduce the 50-dimensional embeddings down to 2 dimensions. A simple way to do this is PCA, but PCA can only form linear combinations of the original variables. Instead, we'll try t-SNE, which does nonlinear dimension reduction and can capture more complex structure.

doc_tsne <- Rtsne(doc_embeddings, check_duplicates = FALSE, pca = FALSE,
                  perplexity = 5, theta = 0.5, dims = 2)

doc_tsne <- data.frame(
  doc_id = rownames(doc_embeddings),
  subject = substr(rownames(doc_embeddings), 1, 3),
  X1 = doc_tsne$Y[, 1],
  X2 = doc_tsne$Y[, 2]
)
library(ggrepel)

ggplot(doc_tsne, aes(x = X1, y = X2, color = subject)) +
  geom_point() +
  # geom_text_repel(aes(label = doc_id)) +
  theme_bw()

Embeddings of MICUSP Mini corpus.

Do you notice patterns? Do subjects group together? In two dimensions any separation will be less stark than in 50 dimensions, but it is still impressive that the embeddings capture this structure.
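
If you're curious how the linear alternative mentioned above compares, PCA only takes a couple of lines. This is a sketch using base R's prcomp(), plotting the first two principal components of the same document embeddings:

# PCA alternative: project the document embeddings onto the first two PCs
doc_pca <- prcomp(doc_embeddings)

pca_df <- data.frame(
  doc_id = rownames(doc_embeddings),
  subject = substr(rownames(doc_embeddings), 1, 3),
  X1 = doc_pca$x[, 1],
  X2 = doc_pca$x[, 2]
)

ggplot(pca_df, aes(x = X1, y = X2, color = subject)) +
  geom_point() +
  theme_bw()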

13.4 Multi-class classification

We can also use the embeddings in a classification task. Here we have 50 variables, 170 documents, and 17 classes. That's not a lot of data, but it illustrates how we'd approach classification with a much larger dataset.

We’ll create a random forest, following our previous lab on random forests. This will take a moment to run:

library(caret)
Loading required package: lattice

Attaching package: 'caret'
The following object is masked from 'package:purrr':

    lift
trainset <- data.frame(doc_embeddings) |> 
  mutate(subject = substr(rownames(doc_embeddings), 1, 3))

train(subject ~ ., data = trainset, 
      method = "rf", trControl = trainControl(savePredictions = "final"))
Random Forest 

170 samples
 50 predictor
 17 classes: 'BIO', 'CEE', 'CLS', 'ECO', 'EDU', 'ENG', 'HIS', 'IOE', 'LIN', 'MEC', 'NRE', 'NUR', 'PHI', 'PHY', 'POL', 'PSY', 'SOC' 

No pre-processing
Resampling: Bootstrapped (25 reps) 
Summary of sample sizes: 170, 170, 170, 170, 170, 170, ... 
Resampling results across tuning parameters:

  mtry  Accuracy   Kappa    
   2    0.3766522  0.3413517
  26    0.3482943  0.3111559
  50    0.3424216  0.3048238

Accuracy was used to select the optimal model using the largest value.
The final value used for the model was mtry = 2.

Review the accuracy results reported. They are not very high, but how do they compare to simple random guessing?
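
As a point of comparison, random guessing among 17 classes would be right about 1/17 of the time (roughly 6%), and always guessing the most common subject gives a slightly better baseline, which we can compute directly:

# Baseline accuracy: always predict the most frequent subject
max(table(trainset$subject)) / nrow(trainset)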

13.5 Example from the Federalist Papers

Let’s use the embeddings of the Federalist Papers to try to study the authorship problem. First we clean the data and get the embeddings:

fed_txt <- federalist_papers |>
  left_join(federalist_meta) |>
  filter(author_id %in% c("Hamilton", "Madison", "Disputed")) |>
  mutate(text = preprocess_text(text, remove_numbers = TRUE))

fed_embeddings <- doc2vec(brown_model, fed_txt)

For kicks, let’s do t-SNE again and see if the authors are separated in the embeddings:

fed_tsne <- Rtsne(fed_embeddings, check_duplicates = FALSE, pca = FALSE,
                  perplexity = 5, theta = 0.5, dims = 2)

fed_tsne <- data.frame(
  doc_id = rownames(fed_embeddings),
  author = fed_txt$author_id,
  X1 = fed_tsne$Y[, 1],
  X2 = fed_tsne$Y[, 2]
)

ggplot(fed_tsne, aes(x = X1, y = X2, color = author)) +
  geom_point() +
  theme_bw()

Of course, the plot is reducing 50-dimensional space down to 2, so maybe they'd be easier to separate in 50 dimensions. We can test that by fitting a model. First, let's split the papers into a training set (known authorship) and a test set (the disputed papers):

train <- fed_txt |>
  filter(author_id %in% c("Hamilton", "Madison"))

  # Alternatively, balance the classes by sampling an equal number of papers
  # per author:
  # group_by(author_id) %>%
  # sample_n(14) %>%
  # ungroup()

test <- fed_txt |>
  filter(author_id == "Disputed")
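
Before fitting, it's worth checking how many papers each author contributes to the training set; the classes are unbalanced (Hamilton wrote far more of the undisputed papers than Madison), which is what the commented-out sampling code above was meant to address:

# How many known papers per author?
table(train$author_id)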

We’ll use lasso, like we did in the original lab:

library(glmnet)

train_embeddings <- fed_embeddings[train$doc_id, ]
train_labels <- train$author_id

cv_fit <- cv.glmnet(train_embeddings, train_labels, family = "binomial")

We could print out the non-zero coefficients, but the embedding dimensions are not interpretable, so it's not very meaningful to know that the coefficient on dimension 27 is zero while the one on dimension 28 is not.
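
If you do want to inspect them, the coefficients at the cross-validated lambda can be pulled out with standard glmnet tools; the nonzero rows are the embedding dimensions the lasso kept:

# Coefficients at the lambda chosen by cross-validation
coef(cv_fit, s = "lambda.min")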

13.5.1 Create a matrix from the test set and predict author

test_embeddings <- fed_embeddings[test$doc_id, ]

lasso_prob <- predict(cv_fit, newx = test_embeddings, type = "response")

13.5.2 Check results

data.frame(Paper = rownames(lasso_prob), Prob = lasso_prob[, 1]) |>
  mutate(Author = ifelse(Prob > 0.5, "Madison", "Hamilton")) |>
  gt() |>
  fmt_number(columns = "Prob",
             decimals = 2)
Paper Prob Author
FEDERALIST_49 0.51 Madison
FEDERALIST_50 0.92 Madison
FEDERALIST_51 0.87 Madison
FEDERALIST_52 0.75 Madison
FEDERALIST_53 0.87 Madison
FEDERALIST_54 0.68 Madison
FEDERALIST_55 0.86 Madison
FEDERALIST_56 0.75 Madison
FEDERALIST_57 0.50 Madison
FEDERALIST_58 0.86 Madison
FEDERALIST_62 0.44 Hamilton
FEDERALIST_63 0.69 Madison

Of course, we don’t expect this to match perfectly. Embeddings capture the subject and meaning of the papers, not their writing style, so we’re really using the topics to classify them, not stylometry.

Compare this to your results from the original Federalist Papers lab.