티스토리 뷰

Study/AI

[RL] Linear TD

생각많은 소심남 2019. 11. 12. 22:00

(해당 포스트는 Coursera의 Prediction and Control with Function Approximation의 강의 요약본입니다)

 이전 포스트들을 통해서 설명하고자 했던 것은 기존의 Q-table과 같은 Tabular 방식이 아닌, Value를 하나의 Function, 즉 Value Function으로 근사하는 방법이 존재하고, 이때 이 근사된 Value Function과 실제 Value Function과의 오차를 줄일 수 있는 방법으로 Gradient Descent를 적용할 수 있다는 것이었다. 그래서 Function Approximation을 Monte Carlo에 적용한 Gradient MC과 제한적이기는 하나, Gradient를 TD Learning에 적용한 Semi-Gradient TD에 대해서 살펴보았다.

 우선 Semi Gradient TD에서 TD update가 되는 부분을 다시 살펴보자. 일단 TD Error를 다음과 같이 \(\delta_t\) 정의했을때,

$$ \delta_t \doteq R_{t+1} + \gamma \hat{v}(S_{t+1}, \mathbf{w}) - \hat{v}(S_t, \mathbf{w}) $$

근사된 Value Function과 실제 Value Function간의 오차를 줄이는 방향으로 다음과 같이 weight이 update된다.

$$ \mathbf{w} \leftarrow \mathbf{w} + \alpha \delta_t \nabla \hat{v}(S_t, \mathbf{w}) $$

그런데 만약 우리가 Function Approximation 중 Linear Function Approximation, 즉 weight과 feature vector간의 inner product된 형태라고 가정을 하면, 근사된 value function의 gradient는 feature vector 그 자신이 된다.

$$ \hat{v}(S_t, \mathbf{w}) \doteq \mathbf{w}^T \mathbf{x}(S_t) \\ \nabla \hat{v}(S_t, \mathbf{w}) = \mathbf{x}(S_t) $$

이 관계를 위의 TD update 식에 넣으면 다음과 같은 관계로 정의할 수 있게 된다.

$$ \mathbf{w} \leftarrow \mathbf{w} + \alpha \delta_t \mathbf{x}(S_t) $$

 여기서 한가지 유념할것은 결국 value error를 줄이는데 feature vector \(\mathbf{x}(s_t)\)가 직접적으로 연관이 있다는 것이다. 그 말은 다시 말해 처음 value function을 근사할때, feature vector만 잘 설계한다면, 얼마든지 error를 줄이는 방향으로 weight이 update된다는 것이다.

 그런데 여기까지 왔을때, 앞에서 다룬 Tabular TD 방식과 지금 다루는 Linear Function Approximation을 한 TD (Linear TD)와 무슨 차이가 있을까 궁금할 수 있다. 사실 feature vector를 아래와 같이

$$ \mathbf{x}(s_i) = \begin{bmatrix} 0 \\ 0 \\ \dots \\ 0 \\ \color{blue}{1} \\ 0 \\ \dots \\ 0 \end{bmatrix} $$

 특정 state에서만 1을 갖게 하는 vector로 정의를 하게되면, 근사화된 Value function은 말그대로 1을 가진 state에서의 weight만 가지게 되므로 \(\hat{v}(s_i, \mathbf{w}) = \color{red}{w_i} \) 라는 관계를 얻을 수 있다. 그래서 그냥 feature vector도 weight도 각 feature 별로 가지는 값처럼 볼 수 있기 때문에 이것도 일종의 table이라고 가정해볼 수 있다. 이를 가지고 TD update를 다시 정리해보면 이렇게 된다.

$$ \begin{align} \mathbf{w} & \leftarrow \mathbf{w} + \alpha [ R_{t+1} + \gamma \hat{v}(S_{t+1}, \mathbf{w}) - \hat{v}(S_t, \mathbf{w})] \color{red}{\mathbf{x}(S_t)} \\ \color{red}{w_i} & \leftarrow \color{red}{w_i} + \alpha [ R_{t+1} + \gamma \hat{v}(S_{t+1}, \mathbf{w}) - \hat{v}(S_t, \mathbf{w})] \color{blue}{1} \end{align} $$

 그럼 TD update를 아래와 같은 형식으로 작성해보겠다. 여기서 \(\mathbf{x}(S_t) \approx \mathbf{x}_t \)라고 해보고자 한다.

$$ \begin{align} w_{t+1} & \doteq w_{t} + \alpha [ R_{t+1} + \gamma \hat{v}(S_{t+1}, \mathbf{w}_t) - \hat{v}(S_t, \mathbf{x}_t)] \mathbf{x}_t \\ & = \mathbf{w}_t + \alpha [R_{t+1} + \gamma \mathbf{w}^T_t \mathbf{x}_{t+1} - \mathbf{w}^T_t \mathbf{x}_t] \mathbf{x}_t \\ & = \mathbf{w}_t + \alpha [ \color{red}{ R_{t+1} \mathbf{x}_t} - \color{blue}{\mathbf{x}_t(\mathbf{x}_t - \gamma \mathbf{x}_{t+1})^T} \mathbf{w}_t \end{align} $$

 참고로 세번째 식은 \(\mathbf{x}_t\)를 괄호안에 넣고 transpose 행렬끼리 묶은 것이다. 

이제 TD update가 어떤 특정값으로 수렴하기 위해서는 적어도 weight의 변화가 적어야 할 것이다. 극단적으로 생각하면 weight의 변화량 (\(\Delta \mathbf{w}_t\))이 0이면 좋겠다. 일단 data가 여러개이므로 전체 weight 변화량의 기대값을 구하면

$$ \mathbb{E}[\Delta \mathbf{w}_t] = \mathbf{w}_{t+1} - \mathbf{w}_t = \alpha (\mathbf{b} - \mathbf{A} \mathbf{w}_t) $$

가 될 것이고, 위의 식과 하나씩 맞춰보면,

$$ \mathbf{b} \doteq \mathbb{E}[\color{red}{R_{t+1} \mathbf{x}_t}] \quad \mathbf{A} \doteq \mathbb{E}[\color{blue}{\mathbf{x}_t (\mathbf{x}_t - \gamma \mathbf{x}_{t+1})^T}] $$

로 정리할 수 있다. 이 값이 0이 되면 weight의 변화량이 없는 것이므로 수렴 조건이 되겠다. 위식을 보면 딱 0이 되는 weight의 조건이 하나가 나오게 되는데, 이때의 weight를 TD Fixed Point (\(\mathbf{w}_{TD}\)) 라고 표현한다. 만약 A행렬이 invertible(역행렬을 구할 수 있는 조건)이라면,

$$ \mathbf{w}_{TD} = \mathbf{A}^{-1} b $$

라는 조건을 만족하면, 해당 weight가 위의 linear function approximation의 해가 될 것이다.

그럼 이 TD fixed point에서의 value error와 실제 weight의 value error 중 가장 작은값간의 관계는 어떻게 될까? 책에서는 다음과 같이 표현하고 있다.

$$ \overline{VE}(\mathbf{w}_{TD}) \leq \frac{1}{1-\gamma} \min_{\mathbf{w}} \overline{VE}(\mathbf{w}) $$

 이 관계가 맞다고 가정하면 \(\gamma\)가 1에 가까울수록 둘간의 오차는 커질 것이고, 0이라면 오차가 작아지게 된다. 결국 오차의 크기에 따라 TD fixed point와 minimum value error의 해간의 대소관계가 결정되는 것이다. 또한 이렇게도 생각해볼 수 있다. 만약 Value Function approximation을 엄청 잘해서 어디서 계산하던 Value Error가 0이라면, 이때는 discount factor는 이런 error를 계산하는데 고려되지 않을 것이다.

댓글