I modelli di regressione logistica vengono addestrati utilizzando lo stesso processo dei modelli di regressione lineare, con due distinzioni chiave:
- I modelli di regressione logistica utilizzano la perdita logaritmica come funzione di perdita anziché la perdita quadratica.
- L'applicazione della regolarizzazione è fondamentale per prevenire l' overfitting.
Le sezioni seguenti illustrano queste due considerazioni in modo più approfondito.
Perdita logaritmica
Nel modulo Regressione lineare, hai utilizzato la perdita quadratica (chiamata anche perdita L2) come funzione di perdita. La perdita quadratica funziona bene per un modello lineare in cui il tasso di variazione dei valori di output è costante. Ad esempio, dato il modello lineare $y' = b + 3x_1$, ogni volta che incrementi il valore di input $x_1$ di 1, il valore di output $y'$ aumenta di 3.
Tuttavia, il tasso di variazione di un modello di regressione logistica non è costante. Come hai visto in Calcolare una probabilità, la sigmoid ha una forma a S anziché lineare. Quando il valore di log-odds ($z$) è più vicino a 0, piccoli aumenti di $z$ comportano variazioni molto maggiori di $y$ rispetto a quando $z$ è un numero positivo o negativo elevato. La tabella seguente mostra l'output della funzione sigmoide per i valori di input da 5 a 10, nonché la precisione corrispondente necessaria per acquisire le differenze nei risultati.
| immissione | output logistico | cifre di precisione richieste |
|---|---|---|
| 5 | 0,993 | 3 |
| 6 | 0,997 | 3 |
| 7 | 0,999 | 3 |
| 8 | 0,9997 | 4 |
| 9 | 0,9999 | 4 |
| 10 | 0,99998 | 5 |
Se hai utilizzato la perdita quadratica per calcolare gli errori per la funzione sigmoide, man mano che l'output si avvicinava sempre più a 0 e 1, avresti bisogno di più memoria per mantenere la precisione necessaria per monitorare questi valori.
Invece, la funzione di perdita per la regressione logistica è la perdita logaritmica. L'equazione della perdita logaritmica restituisce il logaritmo della magnitudo della variazione, anziché solo la distanza dai dati alla previsione. La perdita logaritmica viene calcolata come segue:
$\text{Perdita logaritmica} = -\frac{1}{N}\sum_{i=1}^{N} [y_i\log(y_i') + (1 - y_i)\log(1 - y_i')]$
dove:
- \(N\) è il numero di esempi etichettati nel set di dati
- \(i\) è l'indice di un esempio nel set di dati (ad es. \((x_3, y_3)\) è il terzo esempio nel set di dati)
- \(y_i\) è l'etichetta per l' \(i\)esempio. Poiché si tratta di una regressione logistica,deve essere 0 o 1. \(y_i\)
- \(y_i'\) è la previsione del modello per l' \(i\)esempio (un valore compreso tra 0 e 1), dato l'insieme di funzionalità in \(x_i\).
Regolarizzazione nella regressione logistica
La **regolarizzazione** , un meccanismo per penalizzare la complessità del modello durante l'addestramento, è estremamente importante nella modellazione della regressione logistica. Senza la regolarizzazione, la natura asintotica della regressione logistica continuerebbe a portare la perdita verso 0 nei casi in cui il modello ha un numero elevato di funzionalità. Di conseguenza, la maggior parte dei modelli di regressione logistica utilizza una delle due strategie seguenti per ridurre la complessità del modello:
- Regolarizzazione L2
- Interruzione anticipata: limitare il numero di passaggi di addestramento per interrompere l'addestramento mentre la perdita è ancora in diminuzione.