Модели логистической регрессии обучаются с использованием того же процесса, что и модели линейной регрессии , с двумя ключевыми отличиями:
- Модели логистической регрессии используют логарифм потерь в качестве функции потерь вместо квадрата потерь .
- Применение регуляризации имеет решающее значение для предотвращения переобучения .
В следующих разделах эти два соображения обсуждаются более подробно.
Логарифм потерь
В модуле «Линейная регрессия» в качестве функции потерь вы использовали квадратичную функцию потерь (также называемую функцией потерь L2 ). Квадратичная функция потерь хорошо подходит для линейной модели с постоянной скоростью изменения выходных значений. Например, в линейной модели $y' = b + 3x_1$ каждый раз, когда входное значение $x_1$ увеличивается на 1, выходное значение $y'$ увеличивается на 3.
Однако скорость изменения модели логистической регрессии непостоянна . Как вы видели в разделе «Расчёт вероятности» , сигмоидальная кривая имеет s-образную, а не линейную форму. Когда значение логарифма шансов ($z$) близко к 0, небольшое увеличение $z$ приводит к гораздо более значительным изменениям $y$, чем когда $z$ — большое положительное или отрицательное число. В следующей таблице представлены выходные данные сигмоидальной функции для входных значений от 5 до 10, а также соответствующая точность, необходимая для учёта различий в результатах.
вход | логистический вывод | требуемые цифры точности |
---|---|---|
5 | 0,993 | 3 |
6 | 0,997 | 3 |
7 | 0,999 | 3 |
8 | 0,9997 | 4 |
9 | 0,9999 | 4 |
10 | 0,99998 | 5 |
Если вы используете квадратичные потери для вычисления ошибок сигмоидальной функции, то по мере того, как выходные данные становятся все ближе к 0
и 1
, вам понадобится больше памяти для сохранения точности, необходимой для отслеживания этих значений.
Вместо этого функция потерь для логистической регрессии — логарифм потерь . Уравнение логарифма потерь возвращает логарифм величины изменения, а не просто расстояние между данными и прогнозом. Логарифм потерь рассчитывается следующим образом:
$\text{Логарифм потерь} = -\frac{1}{N}\sum_{i=1}^{N} y_i\log(y_i') + (1 - y_i)\log(1 - y_i')$