티스토리 뷰

(논문의 의도를 가져오되, 개인적인 의견이 담길 수도 있습니다.)

Off-Policy Deep Reinforcement Learning without Exploration - Fujimoto et al, ICML 2019 (논문, 코드)

요약

이 논문에서는 이미 모아져있는 고정된 dataset 상에서 강화학습 에이전트를 학습할 수 있는 알고리즘을 소개한다. 보통 강화학습은 exploration을 통해서 insight를 얻어내고, 이에 대한 경험으로 성능을 추출하는 형태로 되어 있지만, 고정된 dataset으로부터 학습하게 되면 exploration을 할 수 없기 때문에 성능을 얻어낼 요소가 부족하다. 이런 종류의 알고리즘을 Offline RL 혹은 Batch RL이라고 표현하고, 사실 이 알고리즘은 behavior policy와 target policy가 일치하지 않는 환경에 놓여있기 때문에 Off-policy 성격을 띄고 있다. 그래서 논문을 통해서 off-policy로 인해 발생하는 extrapolation error에 대해서 설명하고, 이를 극복할 수 있는 알고리즘인 BCQ (Batch-Constrained deep Q-Learning)에 대해서 소개하고자 한다.

전문

강화학습의 한계 중 하나는 data efficiency가 낮다는 점이고, 특히 데이터를 뽑는데 비용이 많이 들거나 시간이 많이 걸리거나, 혹은 위험한 task에서는 이를 극복해야 한다. 그래서 이를 대처하는 방안 중 하나가 바로 off-policy 기법인데, behavior policy와 target policy 가 다른 상태에서 behavior policy가 조심스럽게 학습된다면 앞에서 말한 한계를 어느정도 극복할 수 있고, 이런 전제가 있다면 imitation learning도 좋은 성능을 뽑을 수 있다. 하지만 대부분의 imitation learning 알고리즘은 suboptimal trajectory 상황에 놓이면 성능이 얻기 힘들고, 보통 이를 극복하기 위해서는 환경과의 interaction이 요구된다. 대신 batch RL을 사용하면 환경과의 interaction 없이도 고정된 dataset에서 학습할 수 있는 환경을 마련할 수 있게 된다. 

batch RL 기법은 experience replay dataset(Lin, 1992)을 활용한 것처럼 환경과의 interaction 없이도 에이전트를 학습시킬 수 있는데, 문제는 off-policy의 특성상 학습하는 데이터의 분포와 학습된 에이전트가 실제 환경에서 얻게될 데이터의 분포가 상관관계가 떨어질 수 있게 된다는 것이다. 이런 문제를 extrapolation error라고 표현했는데, 이 현상은 학습때 경험하지 못한 state-action pair가 나타날 경우 비현실적인 값으로 잘못 추정하는 것을 말한다. 당연한 이야기이겠지만 extrapolation error는 policy에 의해서 얻은 data의 분포와 batch에 포함되어 있는 data의 분포가 맞지 않음으로써 발생하고, 결국 batch에 포함되어 있지 않은 action을 선택하는 policy에 대한 value function을 학습할 수가 없게 된다.

이를 극복하기 위해서 policy에 의한 state-action visitation과 batch에 포함되어 있는 state-action pair간의 불일치를 최소화하면서 reward를 극대화할 수 있는, batch에 제약을 둔 알고리즘 (Batch-Constrained deep Q-Learning)를 소개했다. 이 알고리즘에서는 상태를 조건으로 둔 generative model (Variational AutoEncoder - VAE)을 사용해서 이전에 경험했던 action만 생성할 수 있도록 했다. generative model은 Q-network와 결함되어 있어서, batch안에 있는 data와 유사한 action 중 가장 큰 value를 지닌 값만 뽑을 수 있게 되어 있다. 

Extrapolation Error

논문에서는 Extrapolation Error가 발생할 수 있는 요인을 크게 세가지 요소로 잡고 있다.

  • Data의 부재
    : 만약 state-action pair \((s, a)\) 존재하지 않을 경우, 우리가 학습하고자 하는 \(\pi\)를 모사할 수 있는 충분한 데이터\((s', \pi(s'))\)없이는 이에 대한 \(Q_{\theta}(s', \pi(s'))\)의 추정치도 안좋게 되고 결국 approximation error가 발생하게 된다.
  • Model의 bias
    : off-policy Q-learning을 batch \(\mathcal{B}\)에서 수행할 경우, bellman operator \(\mathcal{T}^{\pi}\)은 \(\mathcal{B}\)에서 샘플링한 transition tuple \((s, a, r, s')\)에서 \(s'\)에 대한 기대값을 취한 것으로 근사된다.
    $$ \mathcal{T}^{\pi}Q(s, a) \approx \mathbb{E}_{s' \sim \mathcal{B}}[r + \gamma Q(s', \pi(s'))] $$
    그런데 만약 stochastic MDP일 경우, state-action visitation이 무한대로 이뤄지지 않는다면 실제 환경의 MDP보다는 \(\mathcal{B}\) 에 bias된 transition dynamics를 추정하게 된다.
  • Training Mismatch
    : 충분한 데이터가 있다고 하더라도 dataset에서는 transition이 uniform하게 샘플링되며, 이 때는 batch안의 데이터의 likelihood를 가중치로 사용한 loss로 사용한다.
    $$ \approx \sum_{(s, a, r, s') \in \mathcal{B}} \Vert r + \gamma Q_{\theta'}(s', \pi(s')) - Q_{\theta}(s, a) \Vert^2 $$
    만약 batch안의 data의 분포가 현재 policy의 data 분포를 따르지 않게 된다면, current policy에 의해서 선택된 action의 value function은 이와 같은 training mismatch에 의해서 매우 안좋게 추정되게 된다.

Batch-Constrained RL

DQN이나 DDPG같은 off-policy Deep RL 알고리즘들은 estimate의 정확성에 대한 고려를 하지 않으면 학습된 value estimate에 의해서 action을 선택하게 되면서 extrapolation error가 발생하게 되고, 결과적으로 out-of-distribution action이 비정상적으로 높은 값으로 extrapolate하는 현상이 나타난다. 그런데 만야 데이터가 가용한 영역에서는 off-policy agent도 비교적 정확하게 value를 추정할 수 있는데, 이때 간단한 아이디어를 적용해볼 수 있다. 바로 batch에서 유사한 state-action visitation만 추출할 수 있도록 policy를 생성해 extrapolation error를 없애자는 것이다. 이런 정책을 batch에 제약을 두었다(batch-constrained)고 표현했다. 주어진 batch에서 off-policy learning을 최적화하기 위해서, batch-constrained policy는 다음 세가지 목표를 가지고 action을 선택하도록 학습된다.

  1. batch안의 data에서 선택된 action간의 거리가 최소화되도록
  2. 익숙한 데이터가 관찰될 수 있도록 state를 유도함
  3. value function을 최대화하도록

이중 두번째와 세번째는 첫번째 목표가 만족되지 않는다면 달성할 수 없기 때문에, 첫번째 목표에 대한 중요성을 강조했다. 그리고 논문에서는 state-conditioned generative model을 통해서 batch상에서 유사한 action만 생성할 수 있도록 만들었다. 이 generative model은 Q-network와 맞물려서 이상적으로 생성된 action들로부터 가장 높은 value를 가지는 action을 선택할 수 있도록 했다. 결과적으로 Q-network을 학습하고, value update가 이뤄지는 동안 estimate의 최소값을 취하는 것이다. 이 과정을 통해서 덜 익숙한 state에 대해서는 penalty를 부여하고 data가 가지는 action을 선택할 수 있도록 policy를 유도하게 된다.

Batch-Constrained Deep RL

BCQ는 앞에서 언급한 것처럼 batch-constrained 개념을 generative model을 통해서 적용했다. 어떤 특정 state에서 BCQ는 batch안에서 가장 높은 유사도를 지니는 action들의 후보군을 뽑고, 학습된 Q-network를 통해서 가장 높은 value를 가지는 action을 선택한다. 여기에다가 동일 저자가 발표한 Clipped Double Q-learning (Fujimoto et al, 2018)을 적용해서 희귀하거나 이전에 경험하지 않은 사태에 대한 value를 penalize할 수 있도록 value estimate에 대한 bias를 부여했다. 결과적으로 BCQ는 batch안의 data와 유사한 state-action visitation을 가지는 policy를 학습하게 된다.

그리고 batch-constrained 개념을 유지하기 위해서 주어진 상태 \(s\)에 대한 어떤 가정을 정의했는데, 이 가정이란 batch \(\mathcal{B}\)안의 state-action pair와 \((s, a)\) 사이의 유사도는 학습된 state-conditioned marginal likelihood \(P_{\mathcal{B}}^{\mathcal{G}}(a \vert s)\)을 통해서 모델링할 수 있다는 것이다. 이 경우, \(P_{\mathcal{B}}^{\mathcal{G}}(a \vert s)\)을 최대화할 수 있도록 된 policy는 거리가 멀거나 경험하지 못한 state-action pair로부터 발생하는 extrapolation error를, 주어진 상태에 대해 batch내에서 가장 가까운 action만 선택하도록 함으로써 줄일 수 있게 되는 것이다. MuJoCo와 같이 고차원의 continuous 환경에서는 \(P_{\mathcal{B}}^{\mathcal{G}}(a \vert s)\)를 추정하기 어렵기 때문에, 대신 batch에 대한 parametric generative model \(G_{w}(s)\)를 학습해서 \(\arg\max P_{\mathcal{B}}^{\mathcal{G}}(a \vert s)\)라는 어느정도 합리적인 근사값으로부터 action을 선택할 수 있게 했다.

참고로 generative model로는 conditional VAE를 사용해서 내제되어 있는 latent space를 변환하여 분포를 모델했다. 이렇게 만든 generative model \(G_w\)와 value function \(Q_{\theta}\)를 엮어서 \(G_w\)로부터는 \(n\)개만큼의 action을 샘플링하고, 이에 대한 \(Q_{\theta}\)가 가장 높은 action을 뽑게끔 하는 policy로 활용할 수 있다. 대신 경험한 action에 대한 다양성을 높이기 위해서 perturbation model \(\xi_{\phi}(s, a \Phi)\)을 사용하는데, 이 모델은 action \(a\)를 \([-\Phi, \Phi]\) 범위 내에서 조절된 결과를 출력으로 내보낸다. 이를 통해서 generative model에 여러번 샘플링할 필요없이 제약된 범위 내에서의 action에 접근할 수 있게 된다. 이를 통해 아래와 같은 policy \(\pi\)를 뽑을 수 있다.

$$ \begin{align} \pi(s) =& \arg \max_{a_i +\xi_{\phi}(s, a_i, \Phi)} Q_{\theta}(s, a_i + \xi_{\phi}(s, a_i, \Phi)), \\ & \{a_i \sim G_w(s)\}_{i=1}^n \end{align} $$

여기서 \(n\)과 \(\Phi\)를 어떤값으로 선택하느냐에 따라서 imitation learning과 RL로 나눌 수 있다. 만약 \(\Phi = 0\)이고, \(n=1\)일 경우 policy는 behavioral cloning (BC)가 되고, \(\Phi \rightarrow a_{\max} - a_{\min}\) 이 될 경우 이 알고리즘은 Q-learning이 되어서 전체 action space에 대해 value function을 최대화하는 방향으로 policy가 형성된다.

perturbation model \(\xi_{\phi}\)는 \(a \sim G_w(s)\)을 샘플링 해서 policy gradient algorithm을 통해 \(Q_{\theta}(s, a)\)를 최대화하도록 학습된다.

$$ \phi \leftarrow \arg\max_{\phi} \sum_{(s, a) \in \mathcal{B}} Q_{\theta}(S, a + \xi_{\phi} (s, a, \Phi))$$

미래 state에 대한 불확실성에 대해서 penalty를 부여하기 위해서 앞에서 언급한 것처럼 Clipped Double Q-learning 을 사용했는데, 이 기법에서는 두 개의 Q network \( \{Q_{\theta_1}, Q_{\theta_2} \} \) 사이의 최소값으로 값을 추정하는 기법이다. 원래 이 방법은 overestimation bias를 측정하는 방법으로 쓰이긴 했지만, 이렇게 최소값을 취해서 불확실성 영역에 대한 high variance를 가지는 추정치에 대해 penalty를 줄 수 있고, policy가 batch에 들어있는 state로 이끌 수 있는 action을 선호하도록 유도할 수 있다. 조금 더 상세하게 설명하면, 두개의 값에 대한 convex combination을 취하는데, 최소값에 조금 더 큰 가중치를 부여하게끔 해서 두 Q network에 사용할 learning target을 만드는 것이다.

$$ r + \gamma \max_{a_i} [ \lambda \min_{j=1, 2} Q_{\theta'_j}(s', a_i) + (1 - \lambda) \max_{j=1, 2} Q_{\theta'_j}(s', a_i)] $$

여기서 \(a_i\)는 generative model에서 샘플링된 perturbed action을 나타낸다. 만약 여기에서 \(\lambda = 1\)로 두면, 이 방식은 clipped Double Q-learning이 된다.이렇게 최소값에 가중치를 가한 것을 사용해 본연의 greedy policy update에 비해서 overestimation bias를 덜 주게끔 제한된 update를 수행하고, 미래의 불확실성에 대한 중요도를 \(\lambda\)를 통해서 조절할 수 있게 해줬다.

이 모든 과정을 합쳐서 Batch-Constrained deep Q-learning (BCQ)를 구성했고, 전체적으로는 4개의 parameterized network으로 구성되어 있다.

  • generative model \(G_w(s)\)
  • perturbation model \(\xi_{\phi}(s, a)\)
  • two Q-network \(Q_{\theta_1}(s, a), Q_{\theta_2}(s, a)\)

그리고 이에 대한 알고리즘은 다음과 같다.

Batch-Constrained deep Q-learning (BCQ)

개인 의견

Offline RL 자체가 힘든 부분을 잘 설명한 것 같다. 아무래도 dataset이 고정되어 있다보니 exploration이 어렵고, 실제 데이터의 분포와 dataset의 분포가 다른 문제를 extrapolation error로 표현했다, 그래서 Q-network을 붙인 VAE를 학습시켜서 extrapolation error를 야기할 수 있는 action을 배제하게끔 샘플링해서 해결하고자 했다. 또한 저자가 이전에 발표한 clipped double Q-learning을 통해서 overestimation bias를 줄인 부분도 성능 개선에 도움이 된 듯 하다. 실험은 OpenAI MuJoCo 환경에서 했고, baseline을 DDPG와 DQN으로 둔 실험에서 BCQ가 좋은 성능을 보여주는 것으로 보인다.


추가(22.04.22): 사실 VAE를 쓴 이유는 continuous action space를 표현하기 위한 용도이고, 알고리즘에서도 action을 여러개 샘플링해서 perturbation model로 noise를 준 이유도 사실 DDPG에서의 Orstein-Uhlenbeck Process처럼 action에 대한 임의성을 부여하기 위한 용도이다. 저자의 github을 살펴보면 BCQ를 continuous_BCQ와 discrete_BCQ로 나눠서 구현했는데, 아마 논문의 target은 continuous_BCQ인 것 같다.

댓글