Regresja logistyczna: straty i regularizacja

Modele regresji logistycznej są trenowane w taki sam sposób jak modele regresji liniowej, z 2 kluczowymi różnicami:

W sekcjach poniżej omówimy te 2 kwestie bardziej szczegółowo.

Logarytmiczna funkcja straty

W module regresji liniowej jako funkcji utraty użyto błędu kwadratowego (zwanego też błędem L2). Funkcja straty kwadratowej sprawdza się w przypadku modelu liniowego, w którym tempo zmian wartości wyjściowych jest stałe. Na przykład w przypadku modelu liniowego $y' = b + 3x_1$ za każdym razem, gdy zwiększysz wartość wejściową $x_1$ o 1, wartość wyjściowa $y'$ wzrośnie o 3.

Szybkość zmian w przypadku modelu regresji logistycznej nie jest jednak stała. Jak widać w artykule Obliczanie prawdopodobieństwa, krzywa sigmoid ma kształt litery S, a nie jest liniowa. Gdy wartość logitów ($z$) jest bliższa 0, niewielkie wzrosty $z$ powodują znacznie większe zmiany $y$ niż w przypadku, gdy $z$ jest dużą liczbą dodatnią lub ujemną. W tabeli poniżej przedstawiono wyniki funkcji sigmoidalnej dla wartości wejściowych od 5 do 10, a także odpowiednią precyzję wymaganą do uchwycenia różnic w wynikach.

wprowadzanie danych dane wyjściowe logistyczne, wymagane cyfry precyzji,
5 0,993 3
6 0,997 3
7 0,999 3
8 0,9997 4
9 0,9999 4
10 0,99998 5

Jeśli do obliczania błędów funkcji sigmoidalnej używasz błędu kwadratowego, w miarę zbliżania się wyniku do wartości 01 potrzebujesz więcej pamięci, aby zachować precyzję niezbędną do śledzenia tych wartości.

Zamiast tego funkcja straty w przypadku regresji logistycznej to logarytmiczna funkcja straty. Równanie funkcji straty logarytmicznej zwraca logarytm wielkości zmiany, a nie tylko odległość od danych do prognozy. Funkcja Log Loss jest obliczana w ten sposób:

$\text{Log Loss} = -\frac{1}{N}\sum_{i=1}^{N} y_i\log(y_i') + (1 - y_i)\log(1 - y_i')$

gdzie:

  • \(N\) to liczba oznaczonych przykładów w zbiorze danych.
  • \(i\) to indeks przykładu w zbiorze danych (np. \((x_3, y_3)\) to trzeci przykład w zbiorze danych)
  • \(y_i\) to etykieta \(i\)-tego przykładu. Ponieważ jest to regresja logistyczna, \(y_i\) musi mieć wartość 0 lub 1.
  • \(y_i'\) to prognoza modelu dla \(i\)-tego przykładu (wartość z przedziału od 0 do 1) na podstawie zestawu cech w  \(x_i\).

Regularyzacja w regresji logistycznej

Regularyzacja, czyli mechanizm karania złożoności modelu podczas trenowania, jest niezwykle ważna w modelowaniu regresji logistycznej. Bez regularyzacji asymptotyczny charakter regresji logistycznej powodowałby, że w przypadku modeli z dużą liczbą cech funkcja straty zbliżałaby się do 0. W związku z tym większość modeli regresji logistycznej korzysta z jednej z tych 2 strategii, aby zmniejszyć złożoność modelu:

  • Regularyzacja L2
  • Wczesne zatrzymanie: ograniczenie liczby kroków trenowania, aby zatrzymać trenowanie, gdy strata nadal maleje.