blog/content/posts/2023-10-26-vae의-손실-함수.md

46 lines
3.0 KiB
Markdown
Raw Normal View History

---
title: VAE의 손실 함수
date: 2022-12-15T11:21:46.836Z
slug: VAE의-손실-함수
---
VAE의 손실 함수는 Variational Lower Bound(VLB)라고도 불리는데, 이 함수는 VAE에서 사용되는 정보량을 측정하는데 사용됩니다. 이 함수는 두 부분으로 나뉘어져 있으며, 이를 분리하여 계산하는 것이 일반적입니다. 이 때, reparametrization trick이란 기법을 사용하면 VAE의 손실 함수의 그래디언트를 계산하는 데 도움이 될 수 있습니다.
먼저, VAE의 손실 함수는 아래와 같이 정의됩니다.
$$L = \mathbb{E}{q(z|x)}\[\log p(x|z)] - D{KL}(q(z|x)||p(z))$$
여기서 $p(x|z)$는 인코더의 출력을 디코더의 입력으로 넣었을 때 생성되는 값을 의미하고, $p(z)$는 잠재 변수의 확률 분포를 의미합니다. $q(z|x)$는 인코더의 출력을 잠재 변수의 확률 분포로 추정한 것을 의미하며, $D_{KL}$은 클로저-라이브러리 발산을 의미합니다.
이를 분리하여 계산하면 아래와 같이 나뉩니다.
$$L = \mathbb{E}{q(z|x)}\[\log p(x|z)] - D{KL}(q(z|x)||p(z))$$
$$= \mathbb{E}{q(z|x)}\[\log p(x|z)] - \mathbb{E}{q(z|x)}\[\log \frac{q(z|x)}{p(z)}]$$
이때, reparametrization trick을 이용하면 손실 함수의 그래디언트를 쉽게 계산할 수 있습니다. 이 기법은 잠재 변수의 확률 분포를 정의할 때, 정규 분포의 모수를 함수로 표현하는 것을 의미합니다. 이렇게 하면 잠재 변수의 확률 분포를 명시적으로 정의할 수 있어 그래디언트를 계산하기 쉬워집니다.
예를 들어, $q(z|x)$를 아래와 같이 정의할 수 있습니다.
$$q(z|x) = \mathcal{N}(\mu(x), \sigma^2(x))$$
여기서 $\mu(x)$와 $\sigma^2(x)$는 인코더의 출력을 이용해 계산된 정규 분포의 모수입니다. 이렇게 정의된 $q(z|x)$를 이용하면 손실 함수의 그래디언트를 아래와 같이 계산할 수 있습니다.
$$\frac{\partial L}{\partial \theta} = \frac{\partial}{\partial \theta}\mathbb{E}{q(z|x)}\[\log p(x|z)] - \frac{\partial}{\partial \theta}\mathbb{E}{q(z|x)}\[\log \frac{q(z|x)}{p(z)}]$$
$$= \mathbb{E}{q(z|x)}\left\[\frac{\partial}{\partial \theta}\log p(x|z)\right] - \mathbb{E}{q(z|x)}\left\[\frac{\partial}{\partial \theta}\log \frac{q(z|x)}{p(z)}\right]$$
이때, 위 식의 첫 번째 항을 전개하면 다음과 같습니다.
$$\mathbb{E}_{q(z|x)}\left\[\frac{\partial}{\partial \theta}\log p(x|z)\right] = \int q(z|x)\frac{\partial}{\partial \theta}\log p(x|z) dz$$
여기서 $\theta$는 VAE에서 사용되는 모든 매개변수를 의미합니다. 따라서 이 식을 이용하면 VAE의 손실 함수의 그래디언트를 계산할 수 있습니다.
예를 들어, 손실 함수가 $L = \mathbb{E}{q(z|x)}\[\log p(x|z)] - D{KL}(q(z|x)||p(z))$인 경우에는 첫 번째 항의 그래디언트는 다음과 같이 계산할 수 있습니다.
$$\frac{\partial}{\partial \theta}\mathbb{E}{q(z|x)}\[\log p(x|z)] = \mathbb{E}{q(z|x)}\left\[\frac{\partial}{\partial \theta}\log p(x|z)\right]$$