Perdita funzioni

I GAN tentano di replicare una distribuzione di probabilità. Pertanto, devono utilizzare funzioni di perdita che riflettono la distanza tra la distribuzione dei dati generati dal GAN e la distribuzione dei dati reali.

Come si acquisisce la differenza tra due distribuzioni nelle funzioni di perdita GAN? Questa domanda è un'area di ricerca attiva e sono stati proposti molti approcci. Analizzeremo qui due funzioni comuni di perdita di GAN, entrambe implementate in TF-GAN:

TF-GAN implementa anche molte altre funzioni di perdita.

Una o due funzioni di perdita?

Un GAN può avere due funzioni di perdita: una per l'addestramento del generatore e una per l'addestramento di discriminazione. Come possono funzionare due funzioni di perdita per riflettere una misura di distanza tra le distribuzioni di probabilità?

Negli schemi di perdita che vedremo qui, le perdite di generatori e discriminatori derivano da un'unica misura di distanza tra le distribuzioni di probabilità. In entrambi questi schemi, tuttavia, il generatore può influire su un solo termine nella misura della distanza, ovvero il termine che riflette la distribuzione dei dati falsi. Durante l'addestramento del generatore, tralasciamo l'altro termine, che riflette la distribuzione dei dati reali.

Le perdite del generatore e del discriminatore alla fine hanno un aspetto diverso, anche se derivano da un'unica formula.

Perdita minima

Nell'articolo che ha introdotto i GAN, il generatore cerca di ridurre al minimo la seguente funzione mentre il discriminatore cerca di massimizzarla:

$$E_x[log(D(x))] + E_z[log(1 - D(G(z)))]$$

In questa funzione:

  • D(x) è la stima del discriminatore relativa alla probabilità che l'istanza di dati reali x sia reale.
  • X è il valore previsto per tutte le istanze di dati reali.
  • G(z) è l'output del generatore quando viene dato il rumore z.
  • D(G(z)) è la stima del discriminatore relativa alla probabilità che un'istanza falsa sia reale.
  • Ez è il valore previsto per tutti gli input casuali nel generatore (in effetti, il valore previsto per tutte le istanze false generate G(z)).
  • La formula deriva dalla cross-entropy tra le distribuzioni reali e generate.

Il generatore non può influire direttamente sul termine log(D(x)) nella funzione, quindi per il generatore, ridurre al minimo la perdita equivale a ridurre al minimo log(1 - D(G(z))).

In TF-GAN, vedi minimax_discriminator_loss e minimax_generator_loss per un'implementazione di questa funzione di perdita.

Perdita minima modificata

L'articolo originale della GAN rileva che la funzione di perdita di minimax precedente può far sì che il GAN si blocchi nelle prime fasi dell'addestramento del GAN, quando il lavoro del discriminatore è molto facile. Il documento suggerisce quindi di modificare la perdita generatore in modo che il generatore tenti di massimizzare log D(G(z)).

In TF-GAN, consulta modified_generator_loss per un'implementazione di questa modifica.

Perdita di Wasserstein

Per impostazione predefinita, TF-GAN utilizza la perdita di Wasserstein.

Questa funzione di perdita dipende da una modifica dello schema GAN (noto come "Wasserstein GAN" o "WGAN"), in cui il discriminatore non classifica effettivamente le istanze. Per ogni istanza restituisce un numero. Questo numero non deve essere inferiore a uno o maggiore di 0, quindi non possiamo utilizzare 0,5 come soglia per decidere se un'istanza è reale o falsa. L'addestramento da Discriminatore tenta solo di aumentare l'output per le istanze reali rispetto a quelle false.

Poiché non può davvero discriminare il vero e il falso, il discriminatore WGAN è in realtà chiamato "discriminatore"; invece di un "discriminatore". Questa distinzione ha un'importanza teorica, ma, per scopi pratici, possiamo considerarla un riconoscimento che gli input per le funzioni di perdita non devono essere probabilità.

Le stesse funzioni di perdita sono ingannevolmente semplici:

Perdita della critica: D(x) - D(G(z))

Il discriminatore cerca di massimizzare questa funzione. In altre parole, cerca di massimizzare la differenza tra il suo output su istanze reali e quello su istanze false.

Perdita del generatore: D(G(z))

Il generatore cerca di massimizzare questa funzione. In altre parole, cerca di massimizzare l'output del discriminatore per le sue false istanze.

In queste funzioni:

  • D(x) è l'output della critica per un'istanza reale.
  • G(z) è l'output del generatore quando viene dato il rumore z.
  • D(G(z)) è l'output della critica per un'istanza falsa.
  • L'output della critica D non deve essere compreso tra 1 e 0.
  • Le formule derivano dalla distanza del movimento di Earth tra le distribuzioni reali e generate.

In TF-GAN, vedi wasserstein_generator_loss e wasserstein_discriminator_loss per le implementazioni.

Requisiti

La giustificazione teorica per il GAN di Wasserstein (o WGAN) richiede che i pesi all'interno del GAN vengano ritagliati in modo che rimangano entro un intervallo limitato.

Vantaggi

I GAN di Wasserstein sono meno vulnerabili a essere bloccati rispetto ai GAN basati su minimax ed evita problemi con gradienti sfumati. La distanza dello spostamento della terra ha anche il vantaggio di essere una vera metrica: una misura della distanza in uno spazio di distribuzioni di probabilità. L'entropia incrociata non è una metrica in questo senso.