The Story of Joon
Fast Fourier Transform 본문
고속 푸리에 변환(Fast Fourier Transform, FFT)은 convolution을 $O(N\log N)$에 구할 때 활용된다. 이 포스트에서는 코드 자체보다도 FFT 알고리즘의 원리를 알아보는 것이 목적이다. 코드만 보고싶다면 맨 아래로 내려가면 된다.
푸리에 변환은 수학에서 매우 중요한 개념이며 공학 분야에서도 중요하게 다루어진다. 일반적으로 푸리에 변환은 함수를 함수로 보내는 변환인데, PS에서 활용되는 푸리에 변환은 수열을 수열로 보내는 이산 푸리에 변환(Discrete Fourier Transform, DFT)이다.
주기가 $N$인 수열 $\lbrace a_j\rbrace_{j=0}^{N-1}$의 DFT $\lbrace A_n\rbrace_{n=0}^{N-1}$은 다음과 같이 정의된다. $e^{i\theta}=\cos\theta+i\sin\theta$로 정의됨에 유의하자.
\[A_n = \sum_{j=0}^{N-1} e^{-2\pi i j n / N}a_j\]
(일반적인) 푸리에변환에는 좋은 성질이 있는데, 역변환을 구하기 쉽다는 것이다. $\lbrace A_n\rbrace_{n=0}^{N-1}$의 이산 푸리에 역변환(Inverse Discrete Fourier Transform)은 다음과 같이 계산한다. 증명은 직접 대입해보면 알 수 있으므로 여기서는 생략한다.
\[a_n = \sum_{j=0}^{N-1} \frac{1}{N}e^{2\pi i j n / N}A_j\]
즉 DFT를 구하는 것과 IDFT를 구하는 것은 방법이 거의 같다고도 말할 수 있다. 여기서 $e^{2\pi i}=1$이라는 사실을 이용하면 위 식을 다음과 같이 적을 수도 있다.
\[a_n = \sum_{j=0}^{N-1}\frac{1}{N}e^{-2\pi i j n / N}A_{N-j}\]
첫 번째 항에 나타나는 $A_N$은, 수열의 주기가 $N$이므로 $A_0$와 같다. 즉 어떤 수열의 IDFT를 구하기 위해 2번째 항(0-based로 index 1)부터 마지막 항까지를 뒤집어준 뒤 DFT를 취하고, 각 항을 $N$으로 나눠줄 수도 있다.
또 푸리에변환의 좋은 성질 중 하나는 convolution을 곱으로 바꿔준다는 것이다. 수학적 표현으로는 다음과 같다.
\[\mathcal{F}\lbrace \mathbf{a} * \mathbf{b}\rbrace = \mathcal{F}\lbrace \mathbf{a}\rbrace\mathcal{F}\lbrace \mathbf{b}\rbrace\]
주기가 $N$인 수열 $\lbrace a_j\rbrace_{j=0}^{N-1}$와 $\lbrace b_j\rbrace_{j=0}^{N-1}$의 convolution $\lbrace c_n\rbrace_{n=0}^{N-1}$은 다음과 같이 정의된다.
\[c_n = \sum_{j=0}^{N-1}a_j b_{n - j}\]
정의를 보면, $b$의 첨자가 음수가 될 수도 있다는 사실이 이상할 수도 있는데, DFT에서 모든 수열은 주기가 $N$이다. 따라서 $b_{-k}$는 $b_{N-k}$와 같다. convolution에 DFT를 취하면 곱이 된다는 사실 역시 원 식에 대입하면 확인할 수 있다. 증명은 생략한다.
여기서 눈여겨볼 점은, $\lbrace c_n\rbrace_{n=0}^{N-1}$을 직접 계산하면 $O(N^2)$의 시간이 걸리지만, DFT를 빠르게 구할 수 있다면 두 수열에 DFT를 취한 뒤 각 원소를 곱하고 IDFT를 취하면 convolution을 얻을 수 있으므로 $O(N^2)$보다 빠른 시간에 원하는 결과를 얻을 수 있다. 고속 푸리에 변환(Fast Fourier Transform)은 DFT를 $O(N\log N)$에 구하는 알고리즘이다.
FFT에는 여러 종류가 있지만, 가장 잘 알려져 있고 구현하기 편한 것은 Cooley-Tukey Algorithm이다. 이 방법은 $N$이 2의 거듭제곱인 경우에만 사용할 수 있지만, 임의의 크기라도 2의 거듭제곱으로 늘이면 되기 때문에 그다지 큰 걱정은 아니다.
먼저 $\lbrace a_j\rbrace_{j=0}^{N-1}$의 푸리에 변환 식을 조금 변형해 보자.
\[\begin{aligned}A_n &= \sum_{j=0}^{N-1} e^{-2\pi i j n / N}a_j \\ &= \sum_{j=0}^{N/2-1}e^{-2\pi i (2j) n / N}a_{2j} + \sum_{j=0}^{N/2-1}e^{-2\pi i (2j+1)n / N}a_{2j+1} \\ &= \sum_{j=0}^{N/2-1}e^{-2\pi i j n / (N/2)}a_{2j} + e^{-2\pi i n / N}\sum_{j=0}^{N/2-1}e^{-2\pi i j n / (N / 2)}a_{2j+1}\end{aligned}\]
우측에 익숙한 식이 두 개 보인다. 이들은 각각 짝수번째 항들의 DFT와 홀수번째 항들의 DFT이다. 이것들은 $n$ 대신 $n+N/2$이 들어가도 변하지 않는 값이다. $n$ 대신 $n+N/2$이 들어갔을 때 변하는 값은 $e^{-2\pi i n / N}$밖에 없고, 부호만 바뀐다는 것을 알 수 있다. 즉 $n<N/2$일 때 다음과 같이 쓸 수 있다.
\[A_{n+N/2} = \sum_{j=0}^{N/2-1}e^{-2\pi i j n / (N/2)}a_{2j} - e^{-2\pi i n / N}\sum_{j=0}^{N/2-1}e^{-2\pi i j n / (N / 2)}a_{2j+1}\]
여기서 중요한 점은, 짝수번째 항들의 DFT와 홀수번째 항들의 DFT를 알고 있다면, 전체 DFT를 계산하는 데 $O(N)$의 시간이면 충분하다는 것이다. 따라서 $O(N\log N)$에 DFT를 계산할 수 있다.
typedef complex<double> base;
void fft(vector<base> &a, vector<base> &A) {
int n = (int) a.size();
if (n == 1) {
A[0] = a[0];
return;
}
vector<base> even(n / 2), odd(n / 2), Even(n / 2), Odd(n / 2);
for (int i = 0; i < n / 2; i++) {
even[i] = a[2 * i];
odd[i] = a[2 * i + 1];
}
fft(even, Even);
fft(odd, Odd);
double th = -2.0 * M_PI / n;
base w = base(cos(th), sin(th));
base z = base(1);
for (int i = 0; i < n / 2; i++) {
A[i] = Even[i] + z * Odd[i];
A[i + n / 2] = Even[i] - z * Odd[i];
z *= w;
}
}
void ifft(vector<base> &A, vector<base> &a) {
reverse(++A.begin(), A.end());
fft(A, a);
int n = (int) a.size();
for (int i = 0; i < n; i++) {
a[i] /= n;
}
}
위 구현에서 한 가지 문제점이 있다면, 추가적인 벡터의 생성이 너무 많다는 것이다. 물론 위 구현으로도 문제들을 푸는 데 지장은 없지만, 이왕이면 좀더 효율적으로 FFT를 수행하는 방법을 찾아보자.
일단 $N=8$인 경우부터 살펴보도록 하자. 짝수번째 항들을 앞쪽으로 모으고 홀수번째 항들을 뒤쪽으로 모아보자.
여기서 앞 4개 항과 뒤 4개 항 각각에 FFT를 행해 보자.
이제 위 식에서와 같이 $E_i$와 $O_i$를 조합하면 $A_i$와 $A_{i+N}$를 만들 수 있고 위치도 같다. 따라서 $E_i$와 $O_i$를 모두 구했다면, $O(1)$의 추가 메모리로 DFT를 구하는 데 충분하다. 즉, inplace로 연산을 수행할 수 있다.
그렇다면 앞뒤 4개 항의 DFT를 구할 때는 어떻게 할까? 같은 방법으로 각각 안에서 홀수 항과 짝수항을 앞뒤로 민다.
이제 $\lbrace a_0, a_4\rbrace$ 두 항으로 이루어진 수열과 $\lbrace a_2, a_6\rbrace$ 두 항으로 이루어진 수열에 FFT를 수행하고 앞 문단의 방식과 같이 inplace로 조합하면, 수열 $\lbrace a_0, a_2, a_4, a_6\rbrace$의 DFT를 구할 수 있다. 그러고 나면 앞 문단의 그림으로 환원되므로, 같은 방식으로 진행해 $O(1)$보다 큰 메모리 없이 전체 배열의 FFT를 완료할 수 있다.
이런 방식을 재귀적으로 적용하면 임의의 2의 거듭제곱 $N$에 대하여 다음과 같은 결론을 얻을 수 있다: 만일 처음 배열을 적당히 잘 재배열해 놓는다면, 그 다음부터는 $O(1)$보다 큰 추가 메모리 없이 전체 배열의 FFT를 완료할 수 있다. 그렇다면 재배열을 어떻게 하면 될까? 재배열하는 과정에서 $a_k$의 움직임을 관찰해보자.
- $k$가 짝수라면, $a_k$는 앞 $N/2$ 크기의 배열로 들어가고, 홀수라면 뒤 $N/2$ 크기의 배열로 들어간다.
- 새로 들어간 배열에서 인덱스가 짝수라면 앞 $N/4$ 크기의 배열, 홀수라면 뒤 $N/4$ 크기의 배열로 들어간다.
- 새로 들어간 배열의 크기가 1이 될 때까지 반복.
이 때 새로 들어간 배열에서의 인덱스는, 이전 배열에서의 인덱스에서 2를 나눈 몫이라는 것을 쉽게 알 수 있다. 더 쉬운 이해를 위해 $N=8$일 때 $a_5$의 움직임을 관찰해 보자.
이러한 방식으로 모든 항을 이동시키고 나면 재배열이 완료된다. 이후로는 1칸씩, 2칸씩, 4칸씩 FFT 수행 후 inplace 조합 과정을 반복하면 $O(1)$보다 큰 추가 메모리 없이 전체 배열의 FFT를 완료할 수 있다.
typedef complex<double> base;
void fft(vector<base> &a, bool inv) {
int n = (int) a.size();
vector<base> b = a;
for (int i = 0; i < n; i++) {
int sz = n, shift = 0, idx = i;
while (sz > 1) {
if (idx & 1) shift += sz >> 1;
idx >>= 1;
sz >>= 1;
}
a[shift + idx] = b[i];
}
for (int i = 1; i < n; i <<= 1) {
double x = inv ? M_PI / i : -M_PI / i;
base w = {cos(x), sin(x)};
for (int j = 0; j < n; j += i << 1) {
base th = {1, 0};
for (int k = 0; k < i; k++) {
base tmp = a[i + j + k] * th;
a[i + j + k] = a[j + k] - tmp;
a[j + k] += tmp;
th *= w;
}
}
}
if (inv) {
for (int i = 0; i < n; i++) {
a[i] /= n;
}
}
}
이제 유일한 오점은 복사 배열을 하나 더 만들어야 한다는 것이다. 그런데 $a_k$의 움직임을 잘 관찰하면 중요한 성질을 하나 알 수 있는데, $N=2^n$일 때 재배열 후 $a_k$의 위치는 $k$를 $n$자리 이진법 전개를 했을 때 $k$의 bit reversal이라는 것이다. 예를 들어 $N=8$이고 $k=5$라면, 5의 세 자리 이진법 전개는 $101$이므로 bit reversal인 $101$, 즉 자기 자신인 5가 재배열 후 위치가 된다. 만일 $N=16$이고 $k=11$이면, $1011$의 bit reversal인 $1101$, 즉 13이 재배열 후 위치가 된다. 결국 $k$의 재배열 후 위치가 $l$이라면 $l$의 재배열 후 위치는 $k$가 되고, 복사 배열을 생성할 필요 없이 swap만으로 재배열을 완료할 수 있게 된다. 재배열 후 위치가 이러한 성질을 갖게 되는 증명은 어렵지 않으므로 독자 여러분에게 맡긴다.
즉, $O(1)$의 추가 메모리로 $O(N\log N)$ 시간에 비재귀로 FFT를 수행할 수 있게 되었다.
typedef complex<double> base;
void fft(vector<base> &a, bool inv) {
int n = (int) a.size();
for (int i = 1, j = 0; i < n; i++) {
int bit = n >> 1;
while (!((j ^= bit) & bit)) bit >>= 1;
if (i < j) swap(a[i], a[j]);
}
for (int i = 1; i < n; i <<= 1) {
double x = inv ? M_PI / i : -M_PI / i;
base w = {cos(x), sin(x)};
for (int j = 0; j < n; j += i << 1) {
base th = {1, 0};
for (int k = 0; k < i; k++) {
base tmp = a[i + j + k] * th;
a[i + j + k] = a[j + k] - tmp;
a[j + k] += tmp;
th *= w;
}
}
}
if (inv) {
for (int i = 0; i < n; i++) {
a[i] /= n;
}
}
}
마지막으로 PS에서 FFT가 어떻게 활용될 수 있는지 알아보자. 2의 거듭제곱이 아닐 수도 있는 임의의 $N$에 대해, $\lbrace a_j\rbrace_{j=0}^{N-1}$와 $\lbrace b_j\rbrace_{j=0}^{N-1}$에 대해 다음을 만족하는 $\lbrace c_n\rbrace_{n=0}^{N-1}$을 구하는 문제를 생각해 보자.
\[c_n = \sum_{j=0}^{n}a_jb_{n-j}\]
이 포스트의 앞부분에서 정의한 convolution과 거의 같지만, $n-j$가 음수인 경우의 값은 원하지 않는다는 차이가 있다. 이 때문에 우리는 $\lbrace a_j\rbrace_{j=0}^{N-1}$와 $\lbrace b_j\rbrace_{j=0}^{N-1}$의 크기를 2배 이상으로 늘인 후 새로 생긴 곳에 모두 0을 넣는다. 또한 지금까지 이야기한 FFT 알고리즘인 Cooley-Tukey Algorithm은 배열의 크기가 2의 거듭제곱일 때만 작동하므로, 우리는 배열의 크기를 원래 크기의 2배 이상인 2의 거듭제곱까지 늘이게 된다. 이후 이 포스트 앞부분에서와 같이 FFT를 한 후 각 항을 곱하고 IFFT를 수행하면, 수열 $\lbrace c_n\rbrace_{n=0}^{N-1}$의 모든 항을 $O(N\log N)$ 시간에 구할 수 있게 된다.
void multiply(vector<base> &a, vector<base> &b) {
int n = (int) max(a.size(), b.size());
int i = 0;
while ((1 << i) < (n << 1)) i++;
n = 1 << i;
a.resize(n);
b.resize(n);
fft(a, false);
fft(b, false);
for (int i = 0; i < n; i++) {
a[i] *= b[i];
}
fft(a, true);
}
연습문제
'Computer Science > 알고리즘' 카테고리의 다른 글
Linear Algebra in Problem Solving (1) (0) | 2022.08.31 |
---|---|
2015 ACM-ICPC 한국 예선 F - 파일 합치기 (2) | 2020.05.02 |
Link/Cut Tree (2) (7) | 2017.08.21 |
Link/Cut Tree (1) (2) | 2017.08.13 |
Splay Tree (0) | 2017.08.13 |