Modele regresji logistycznej są trenowane w taki sam sposób jak modele regresji liniowej, ale z 2 kluczowymi różnicami:
- Modele regresji logistycznej używają logarytmicznej funkcji straty zamiast straty średniokwadratowej.
- Stosowanie regularyzacji jest niezbędne, aby zapobiec przetrenowaniu.
W kolejnych sekcjach omówimy te 2 kwestie bardziej szczegółowo.
Logarytmiczna funkcja straty
W module Regresja liniowa jako funkcji straty używaliśmy straty średniokwadratowej (zwanej też stratą L2). Strata średniokwadratowa dobrze sprawdza się w przypadku modelu liniowego, w którym tempo zmian wartości wyjściowych jest stałe. Na przykład w modelu liniowym $y' = b + 3x_1$ za każdym razem, gdy zwiększysz wartość wejściową $x_1$ o 1, wartość wyjściowa $y'$ wzrośnie o 3.
Jednak tempo zmian modelu regresji logistycznej nie jest stałe. Jak widzisz w sekcji Obliczanie prawdopodobieństwa, krzywa sigmoid ma kształt litery S a nie jest liniowa. Gdy wartość logarytmu szans ($z$) jest bliższa 0, niewielkie wzrosty $z$ powodują znacznie większe zmiany $y$ niż wtedy, gdy $z$ jest dużą liczbą dodatnią lub ujemną. W tabeli poniżej znajdziesz dane wyjściowe funkcji sigmoidowej dla wartości wejściowych od 5 do 10 oraz wymaganą precyzję, aby uchwycić różnice w wynikach.
| dane wejściowe | dane wyjściowe regresji logistycznej | wymagana precyzja |
|---|---|---|
| 5 | 0,993 | 3 |
| 6 | 0,997 | 3 |
| 7 | 0,999 | 3 |
| 8 | 0,9997 | 4 |
| 9 | 0,9999 | 4 |
| 10 | 0,99998 | 5 |
Jeśli do obliczania błędów funkcji sigmoidowej używasz straty średniokwadratowej, to gdy dane wyjściowe zbliżają się do 0 i 1, będziesz potrzebować więcej pamięci, aby zachować precyzję niezbędną do śledzenia tych wartości.
Zamiast tego funkcja straty w regresji logistycznej to logarytmiczna funkcja straty. Równanie logarytmicznej funkcji straty zwraca logarytm wielkości zmiany, a nie tylko odległość od danych do prognozy. Logarytmiczna funkcja straty jest obliczana w ten sposób:
$\text{Log Loss} = -\frac{1}{N}\sum_{i=1}^{N} [y_i\log(y_i') + (1 - y_i)\log(1 - y_i')]$
gdzie:
- \(N\) to liczba oznaczonych przykładów w zbiorze danych,
- \(i\) to indeks przykładu w zbiorze danych (np. \((x_3, y_3)\) to trzeci przykład w zbiorze danych),
- \(y_i\) to etykieta przykładu \(i\). Ponieważ jest to regresja logistyczna,musi mieć wartość 0 lub 1. \(y_i\)
- \(y_i'\) to prognoza modelu dla \(i\)przykładu (wartość między 0 a 1) na podstawie zestawu cech w \(x_i\).
Regularyzacja w regresji logistycznej
Regularyzacja, czyli mechanizm karania złożoności modelu podczas trenowania, jest niezwykle ważna w modelowaniu regresji logistycznej. Bez regularyzacji asymptotyczny charakter regresji logistycznej powodowałby, że strata zbliżałaby się do 0 w przypadkach, gdy model ma dużą liczbę cech. W związku z tym większość modeli regresji logistycznej używa jednej z tych 2 strategii, aby zmniejszyć złożoność modelu:
- regularyzacja L2
- wczesne zatrzymanie: ograniczenie liczby kroków trenowania, aby zatrzymać trenowanie, gdy strata nadal się zmniejsza.