Modele regresji logistycznej są trenowane w taki sam sposób jak modele regresji liniowej, z 2 kluczowymi różnicami:
- Modele regresji logistycznej używają logarytmicznej funkcji straty zamiast kwadratowej funkcji straty.
- Stosowanie regularyzacji jest kluczowe, aby zapobiec przetrenowaniu.
W sekcjach poniżej omówimy te 2 kwestie bardziej szczegółowo.
Logarytmiczna funkcja straty
W module regresji liniowej jako funkcji utraty użyto błędu kwadratowego (zwanego też błędem L2). Funkcja straty kwadratowej sprawdza się w przypadku modelu liniowego, w którym tempo zmian wartości wyjściowych jest stałe. Na przykład w przypadku modelu liniowego $y' = b + 3x_1$ za każdym razem, gdy zwiększysz wartość wejściową $x_1$ o 1, wartość wyjściowa $y'$ wzrośnie o 3.
Szybkość zmian w przypadku modelu regresji logistycznej nie jest jednak stała. Jak widać w artykule Obliczanie prawdopodobieństwa, krzywa sigmoid ma kształt litery S, a nie jest liniowa. Gdy wartość logitów ($z$) jest bliższa 0, niewielkie wzrosty $z$ powodują znacznie większe zmiany $y$ niż w przypadku, gdy $z$ jest dużą liczbą dodatnią lub ujemną. W tabeli poniżej przedstawiono wyniki funkcji sigmoidalnej dla wartości wejściowych od 5 do 10, a także odpowiednią precyzję wymaganą do uchwycenia różnic w wynikach.
wprowadzanie danych | dane wyjściowe logistyczne, | wymagane cyfry precyzji, |
---|---|---|
5 | 0,993 | 3 |
6 | 0,997 | 3 |
7 | 0,999 | 3 |
8 | 0,9997 | 4 |
9 | 0,9999 | 4 |
10 | 0,99998 | 5 |
Jeśli do obliczania błędów funkcji sigmoidalnej używasz błędu kwadratowego, w miarę zbliżania się wyniku do wartości 0
i 1
potrzebujesz więcej pamięci, aby zachować precyzję niezbędną do śledzenia tych wartości.
Zamiast tego funkcja straty w przypadku regresji logistycznej to logarytmiczna funkcja straty. Równanie funkcji straty logarytmicznej zwraca logarytm wielkości zmiany, a nie tylko odległość od danych do prognozy. Funkcja Log Loss jest obliczana w ten sposób:
$\text{Log Loss} = -\frac{1}{N}\sum_{i=1}^{N} y_i\log(y_i') + (1 - y_i)\log(1 - y_i')$
gdzie:
- \(N\) to liczba oznaczonych przykładów w zbiorze danych.
- \(i\) to indeks przykładu w zbiorze danych (np. \((x_3, y_3)\) to trzeci przykład w zbiorze danych)
- \(y_i\) to etykieta \(i\)-tego przykładu. Ponieważ jest to regresja logistyczna, \(y_i\) musi mieć wartość 0 lub 1.
- \(y_i'\) to prognoza modelu dla \(i\)-tego przykładu (wartość z przedziału od 0 do 1) na podstawie zestawu cech w \(x_i\).
Regularyzacja w regresji logistycznej
Regularyzacja, czyli mechanizm karania złożoności modelu podczas trenowania, jest niezwykle ważna w modelowaniu regresji logistycznej. Bez regularyzacji asymptotyczny charakter regresji logistycznej powodowałby, że w przypadku modeli z dużą liczbą cech funkcja straty zbliżałaby się do 0. W związku z tym większość modeli regresji logistycznej korzysta z jednej z tych 2 strategii, aby zmniejszyć złożoność modelu:
- Regularyzacja L2
- Wczesne zatrzymanie: ograniczenie liczby kroków trenowania, aby zatrzymać trenowanie, gdy strata nadal maleje.