The DQN algorithm opened the door to deep reinforcement learning, but as a pioneering work it has its own problems and plenty of room for improvement. A large number of improved algorithms have appeared since. This chapter introduces two of the best-known ones: Double DQN and Dueling DQN. Both are very simple to implement, requiring only small modifications on top of DQN, yet they can improve DQN's performance to a certain extent. Readers who want more detailed DQN improvements can consult the Rainbow paper and the works it cites.
The vanilla DQN algorithm usually overestimates $Q$ values. The TD target that traditional DQN optimizes is

$$ r + \gamma \max_{a'} Q_{\omega^-}(s', a') $$

where $\max_{a'} Q_{\omega^-}(s', a')$ is computed by the target network (with parameters $\omega^-$). This term can equivalently be written as

$$ Q_{\omega^-}\Big(s', \arg\max_{a'} Q_{\omega^-}(s', a')\Big) $$

In other words, the max operation does two things with the same network: it first selects the action with the highest estimated value in state $s'$, and then evaluates the value of that action. Since neural-network estimates of $Q$ inevitably carry positive or negative noise, taking the maximum over them systematically picks up positive errors. For example, if the true values of all actions in $s'$ are 0, the correct target is $r$, but some estimates will be positive by chance, so the max is positive and the target overestimates $r$; bootstrapping then propagates this bias to earlier states, and it accumulates over updates. For tasks with large action spaces, the overestimation can become severe enough to keep DQN from working effectively.

To address this problem, the Double DQN algorithm proposes estimating $\max_{a'} Q_*(s', a')$ with two independently trained networks. Concretely, $\max_{a'} Q_{\omega^-}(s', a')$ is replaced by $Q_{\omega^-}(s', \arg\max_{a'} Q_{\omega}(s', a'))$: one network $Q_{\omega}$ selects the highest-valued action, while the other network $Q_{\omega^-}$ evaluates that action. Even if one network badly overestimates some action, the value actually used for it comes from the other network and is unlikely to be overestimated as much.

In the traditional DQN algorithm there are already two $Q$ networks, the training network and the target network, but only the target network is used when computing $\max_{a'} Q_{\omega^-}(s', a')$. We can therefore let the training network play the role of Double DQN's first network (action selection) and the target network play the role of the second network (action evaluation); this is the main idea of Double DQN. Since the training network's parameters are denoted $\omega$ and the target network's parameters $\omega^-$, consistent with the notation above, the optimization target of Double DQN can be written directly as

$$ r + \gamma Q_{\omega^-}\Big(s', \arg\max_{a'} Q_{\omega}(s', a')\Big) $$
Clearly, the only difference between DQN and Double DQN lies in how the action is chosen when computing the $Q$ value of the next state $s'$:

- the optimization target of DQN can be written as $r + \gamma Q_{\omega^-}(s', \arg\max_{a'} Q_{\omega^-}(s', a'))$, where the action is selected by the target network $Q_{\omega^-}$;
- the optimization target of Double DQN is $r + \gamma Q_{\omega^-}(s', \arg\max_{a'} Q_{\omega}(s', a'))$, where the action is selected by the training network $Q_{\omega}$.

Double DQN can therefore be implemented directly on top of the DQN code with very few modifications.
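As a preview of the full implementation later in this section, here is a minimal, self-contained sketch of the one place where the two targets differ. The tiny linear networks and random tensors are toy stand-ins (not the book's `Qnet`); only the target computation matters:

```python
import torch

# Toy stand-ins: 4 transitions, 3 state features, 11 discrete actions.
q_net = torch.nn.Linear(3, 11)         # plays the role of the training network Q_w
target_q_net = torch.nn.Linear(3, 11)  # plays the role of the target network Q_w-
next_states = torch.randn(4, 3)
rewards = torch.randn(4, 1)
dones = torch.zeros(4, 1)
gamma = 0.98

# Vanilla DQN: the target network both selects and evaluates the action.
dqn_next_q = target_q_net(next_states).max(1)[0].view(-1, 1)

# Double DQN: the training network selects the action,
# and the target network evaluates that action.
max_action = q_net(next_states).max(1)[1].view(-1, 1)
double_next_q = target_q_net(next_states).gather(1, max_action)

td_target = rewards + gamma * double_next_q * (1 - dones)
```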
The environment used in this section is the inverted pendulum (Pendulum), which contains a pendulum starting at a random angle, as shown in Figure 8-1. The state consists of the cosine and sine of the pendulum angle and the angular velocity; the action is the torque applied to the pendulum (see Table 8-1 and Table 8-2). At every step the agent receives a reward of $-(\theta^2 + 0.1\dot\theta^2 + 0.001a^2)$, which is 0 when the pendulum stands perfectly upright and motionless and negative otherwise. The environment has no terminal state; an episode ends automatically after 200 steps.
Table 8-1 State space of the Pendulum environment

Index | Name | Minimum | Maximum |
---|---|---|---|
0 | $\cos\theta$ | -1.0 | 1.0 |
1 | $\sin\theta$ | -1.0 | 1.0 |
2 | $\dot\theta$ (angular velocity) | -8.0 | 8.0 |
Table 8-2 Action space of the Pendulum environment

Index | Action | Minimum | Maximum |
---|---|---|---|
0 | Torque | -2.0 | 2.0 |
The torque is a continuous value in the range $[-2, 2]$. Since DQN can only handle discrete action spaces, we cannot apply it to the pendulum environment directly. The environment is, however, convenient for checking DQN's overestimation of $Q$ values: the largest achievable value in the pendulum environment is 0 (obtained when the pendulum stays upright), so any $Q$ estimate above 0 indicates overestimation. To apply DQN, we discretize the action space. The code below splits the continuous action into 11 discrete actions, with actions $0, 1, \dots, 10$ corresponding to torques $-2.0, -1.6, \dots, 1.6, 2.0$.
```python
import random
import gym
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import rl_utils
from tqdm import tqdm


class Qnet(torch.nn.Module):
    ''' Q-network with a single hidden layer '''
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(Qnet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)
```
Next, we make small modifications to the DQN code to implement Double DQN.
```python
class DQN:
    ''' DQN algorithm, including Double DQN '''
    def __init__(self,
                 state_dim,
                 hidden_dim,
                 action_dim,
                 learning_rate,
                 gamma,
                 epsilon,
                 target_update,
                 device,
                 dqn_type='VanillaDQN'):
        self.action_dim = action_dim
        self.q_net = Qnet(state_dim, hidden_dim, self.action_dim).to(device)
        self.target_q_net = Qnet(state_dim, hidden_dim,
                                 self.action_dim).to(device)
        self.optimizer = torch.optim.Adam(self.q_net.parameters(),
                                          lr=learning_rate)
        self.gamma = gamma
        self.epsilon = epsilon
        self.target_update = target_update
        self.count = 0
        self.dqn_type = dqn_type
        self.device = device

    def take_action(self, state):
        if np.random.random() < self.epsilon:
            action = np.random.randint(self.action_dim)
        else:
            state = torch.tensor([state], dtype=torch.float).to(self.device)
            action = self.q_net(state).argmax().item()
        return action

    def max_q_value(self, state):
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        return self.q_net(state).max().item()

    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'],
                              dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(
            self.device)
        rewards = torch.tensor(transition_dict['rewards'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'],
                                   dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'],
                             dtype=torch.float).view(-1, 1).to(self.device)

        q_values = self.q_net(states).gather(1, actions)  # Q values of the taken actions
        # Maximum Q value of the next state
        if self.dqn_type == 'DoubleDQN':  # this is where DQN and Double DQN differ
            max_action = self.q_net(next_states).max(1)[1].view(-1, 1)
            max_next_q_values = self.target_q_net(next_states).gather(
                1, max_action)
        else:  # vanilla DQN case
            max_next_q_values = self.target_q_net(next_states).max(1)[0].view(
                -1, 1)
        q_targets = rewards + self.gamma * max_next_q_values * (1 - dones)  # TD target
        dqn_loss = torch.mean(F.mse_loss(q_values, q_targets))  # mean squared error loss
        self.optimizer.zero_grad()  # gradients accumulate by default in PyTorch, so clear them first
        dqn_loss.backward()  # backpropagate to compute gradients
        self.optimizer.step()

        if self.count % self.target_update == 0:
            self.target_q_net.load_state_dict(
                self.q_net.state_dict())  # update the target network
        self.count += 1
```
Next we set the hyperparameters and implement the action-discretization helper, which maps a discrete action index back to a continuous torque for the pendulum environment.
```python
lr = 1e-2
num_episodes = 200
hidden_dim = 128
gamma = 0.98
epsilon = 0.01
target_update = 50
buffer_size = 5000
minimal_size = 1000
batch_size = 64
device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
    "cpu")

env_name = 'Pendulum-v0'
env = gym.make(env_name)
state_dim = env.observation_space.shape[0]
action_dim = 11  # discretize the continuous action into 11 discrete actions


def dis_to_con(discrete_action, env, action_dim):  # map a discrete action back to a continuous one
    action_lowbound = env.action_space.low[0]  # minimum of the continuous action
    action_upbound = env.action_space.high[0]  # maximum of the continuous action
    return action_lowbound + (discrete_action /
                              (action_dim - 1)) * (action_upbound -
                                                   action_lowbound)
```
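As a quick sanity check (not part of the book's code; it assumes `env`, `action_dim`, and `dis_to_con` defined above are in scope), we can print the torques that the 11 discrete actions map to:

```python
# The 11 discrete actions should map to torques -2.0, -1.6, ..., 1.6, 2.0.
print([round(float(dis_to_con(a, env, action_dim)), 1) for a in range(action_dim)])
```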
Next we want to compare the training behaviour of DQN and Double DQN. For convenient reuse, we wrap the training loop of the DQN algorithm into a function. The training process records the maximum $Q$ value at each state encountered, so that after training we can visualize the degree of overestimation and compare DQN with Double DQN.
```python
def train_DQN(agent, env, num_episodes, replay_buffer, minimal_size,
              batch_size):
    return_list = []
    max_q_value_list = []
    max_q_value = 0
    for i in range(10):
        with tqdm(total=int(num_episodes / 10),
                  desc='Iteration %d' % i) as pbar:
            for i_episode in range(int(num_episodes / 10)):
                episode_return = 0
                state = env.reset()
                done = False
                while not done:
                    action = agent.take_action(state)
                    max_q_value = agent.max_q_value(
                        state) * 0.005 + max_q_value * 0.995  # smoothed estimate
                    max_q_value_list.append(max_q_value)  # record the max Q value of each state
                    action_continuous = dis_to_con(action, env,
                                                   agent.action_dim)
                    next_state, reward, done, _ = env.step([action_continuous])
                    replay_buffer.add(state, action, reward, next_state, done)
                    state = next_state
                    episode_return += reward
                    if replay_buffer.size() > minimal_size:
                        b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(
                            batch_size)
                        transition_dict = {
                            'states': b_s,
                            'actions': b_a,
                            'next_states': b_ns,
                            'rewards': b_r,
                            'dones': b_d
                        }
                        agent.update(transition_dict)
                return_list.append(episode_return)
                if (i_episode + 1) % 10 == 0:
                    pbar.set_postfix({
                        'episode':
                        '%d' % (num_episodes / 10 * i + i_episode + 1),
                        'return':
                        '%.3f' % np.mean(return_list[-10:])
                    })
                pbar.update(1)
    return return_list, max_q_value_list
```
Everything is ready. We first train DQN and plot the maximum $Q$ values it produces during learning.
```python
random.seed(0)
np.random.seed(0)
env.seed(0)
torch.manual_seed(0)
replay_buffer = rl_utils.ReplayBuffer(buffer_size)
agent = DQN(state_dim, hidden_dim, action_dim, lr, gamma, epsilon,
            target_update, device)
return_list, max_q_value_list = train_DQN(agent, env, num_episodes,
                                          replay_buffer, minimal_size,
                                          batch_size)

episodes_list = list(range(len(return_list)))
mv_return = rl_utils.moving_average(return_list, 5)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('DQN on {}'.format(env_name))
plt.show()

frames_list = list(range(len(max_q_value_list)))
plt.plot(frames_list, max_q_value_list)
plt.axhline(0, c='orange', ls='--')
plt.axhline(10, c='red', ls='--')
plt.xlabel('Frames')
plt.ylabel('Q value')
plt.title('DQN on {}'.format(env_name))
plt.show()
```
```
Iteration 0: 100%|██████████| 20/20 [00:02<00:00, 7.14it/s, episode=20, return=-1018.764]
Iteration 1: 100%|██████████| 20/20 [00:03<00:00, 5.73it/s, episode=40, return=-463.311]
Iteration 2: 100%|██████████| 20/20 [00:03<00:00, 5.53it/s, episode=60, return=-184.817]
Iteration 3: 100%|██████████| 20/20 [00:03<00:00, 5.55it/s, episode=80, return=-317.366]
Iteration 4: 100%|██████████| 20/20 [00:03<00:00, 5.67it/s, episode=100, return=-208.929]
Iteration 5: 100%|██████████| 20/20 [00:03<00:00, 5.59it/s, episode=120, return=-182.659]
Iteration 6: 100%|██████████| 20/20 [00:03<00:00, 5.25it/s, episode=140, return=-275.938]
Iteration 7: 100%|██████████| 20/20 [00:03<00:00, 5.65it/s, episode=160, return=-209.702]
Iteration 8: 100%|██████████| 20/20 [00:03<00:00, 5.73it/s, episode=180, return=-246.861]
Iteration 9: 100%|██████████| 20/20 [00:03<00:00, 5.77it/s, episode=200, return=-293.374]
```
From the results we can see that DQN obtains a reasonable return on the pendulum environment, with the final expected return around -200. However, quite a few of the recorded $Q$ values exceed 0, and some even exceed 10; this is exactly the overestimation problem of DQN. Let us now see whether Double DQN can alleviate it.
```python
random.seed(0)
np.random.seed(0)
env.seed(0)
torch.manual_seed(0)
replay_buffer = rl_utils.ReplayBuffer(buffer_size)
agent = DQN(state_dim, hidden_dim, action_dim, lr, gamma, epsilon,
            target_update, device, 'DoubleDQN')
return_list, max_q_value_list = train_DQN(agent, env, num_episodes,
                                          replay_buffer, minimal_size,
                                          batch_size)

episodes_list = list(range(len(return_list)))
mv_return = rl_utils.moving_average(return_list, 5)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('Double DQN on {}'.format(env_name))
plt.show()

frames_list = list(range(len(max_q_value_list)))
plt.plot(frames_list, max_q_value_list)
plt.axhline(0, c='orange', ls='--')
plt.axhline(10, c='red', ls='--')
plt.xlabel('Frames')
plt.ylabel('Q value')
plt.title('Double DQN on {}'.format(env_name))
plt.show()
```
```
Iteration 0: 100%|██████████| 20/20 [00:03<00:00, 6.60it/s, episode=20, return=-818.719]
Iteration 1: 100%|██████████| 20/20 [00:03<00:00, 5.43it/s, episode=40, return=-391.392]
Iteration 2: 100%|██████████| 20/20 [00:03<00:00, 5.29it/s, episode=60, return=-216.078]
Iteration 3: 100%|██████████| 20/20 [00:03<00:00, 5.52it/s, episode=80, return=-438.220]
Iteration 4: 100%|██████████| 20/20 [00:03<00:00, 5.42it/s, episode=100, return=-162.128]
Iteration 5: 100%|██████████| 20/20 [00:03<00:00, 5.50it/s, episode=120, return=-389.088]
Iteration 6: 100%|██████████| 20/20 [00:03<00:00, 5.44it/s, episode=140, return=-273.700]
Iteration 7: 100%|██████████| 20/20 [00:03<00:00, 5.23it/s, episode=160, return=-221.605]
Iteration 8: 100%|██████████| 20/20 [00:04<00:00, 4.91it/s, episode=180, return=-262.134]
Iteration 9: 100%|██████████| 20/20 [00:03<00:00, 5.34it/s, episode=200, return=-278.752]
```
We can see that, compared with vanilla DQN, Double DQN rarely produces $Q$ values above 0, indicating that the overestimation problem has been substantially alleviated.
Dueling DQN is another improvement over DQN. It makes only a small change to the network on top of traditional DQN, yet can substantially improve performance. In reinforcement learning, the advantage function is defined as the state-action value function minus the state value function, $A(s, a) = Q(s, a) - V(s)$. In a given state, the expected advantage over the actions taken by the policy is 0, because the expected action value is exactly the state value. Based on this, Dueling DQN models the $Q$ network as

$$ Q_{\eta, \alpha, \beta}(s, a) = V_{\eta, \alpha}(s) + A_{\eta, \beta}(s, a) $$
where $V_{\eta, \alpha}(s)$ is the state value function and $A_{\eta, \beta}(s, a)$ is the advantage of taking each action in that state, capturing the differences between actions; $\eta$ denotes the network parameters shared by the two functions (usually the first few feature-extraction layers), while $\alpha$ and $\beta$ are the parameters of the value branch and the advantage branch, respectively. Under this modelling, the network no longer outputs $Q$ values directly; instead, its last layers split into two branches that output the state value and the advantages, whose sum gives the $Q$ value. The network structure of Dueling DQN is shown in Figure 8-2.
The benefit of modelling the state value and the advantage separately is that in some situations the agent only cares about the value of the state and not about the differences between actions; separating the two lets the agent handle states that are only weakly affected by its actions. In the driving game shown in Figure 8-3, the regions the agent attends to are highlighted in orange (see also colour plate 4). When there is no car ahead, the agent's actions make little difference, so it focuses on the state value; when there is a car ahead and the agent needs to overtake, it starts to attend to the differences in the advantages of the actions.
The formula $Q_{\eta, \alpha, \beta}(s, a) = V_{\eta, \alpha}(s) + A_{\eta, \beta}(s, a)$ used in Dueling DQN is not identifiable: for the same $Q$ values, adding an arbitrary constant $C$ to $V$ and subtracting $C$ from all the $A$ values leaves $Q$ unchanged, which makes training unstable. To resolve this, Dueling DQN forces the advantage of the best action to be 0:

$$ Q_{\eta, \alpha, \beta}(s, a) = V_{\eta, \alpha}(s) + A_{\eta, \beta}(s, a) - \max_{a'} A_{\eta, \beta}(s, a') $$

In this case $V(s) = \max_a Q(s, a)$, which makes the decomposition unique. In practice, the maximization can also be replaced by the mean:

$$ Q_{\eta, \alpha, \beta}(s, a) = V_{\eta, \alpha}(s) + A_{\eta, \beta}(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a'} A_{\eta, \beta}(s, a') $$

In this case $V(s) = \frac{1}{|\mathcal{A}|} \sum_{a'} Q(s, a')$. The code below adopts the mean version; although it no longer matches the Bellman optimality equation exactly, it is more stable in practice.
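A tiny numerical illustration (toy values, not from the book) of why the naive sum is not identifiable and how mean-centering pins $V$ down to the per-state average of $Q$:

```python
import torch

# Toy value and advantages for a single state with 3 actions.
V = torch.tensor([[1.0]])
A = torch.tensor([[0.6, -0.2, -0.1]])

# Without centering, (V, A) and (V + C, A - C) produce exactly the same Q.
Q_naive_1 = V + A
Q_naive_2 = (V + 2.0) + (A - 2.0)
print(torch.allclose(Q_naive_1, Q_naive_2))  # True: the decomposition is not unique

# With mean-centering (as in the VAnet code below), V equals the mean of Q over actions.
Q = V + A - A.mean(1, keepdim=True)
print(torch.allclose(Q.mean(1, keepdim=True), V))  # True: V(s) = mean_a Q(s, a)
```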
Some readers may ask: why does Dueling DQN work better than DQN? Part of the reason is that Dueling DQN learns the state value function more efficiently. In every update the value function $V$ is updated, which affects the $Q$ values of all actions in that state, whereas vanilla DQN only updates the $Q$ value of the action that was taken and leaves the others untouched. Dueling DQN therefore learns the state value function more frequently and more accurately.
Dueling DQN differs from DQN only in its network architecture, so most of the code can be reused. We define VAnet, a composite network with a state-value branch and an advantage branch.
```python
class VAnet(torch.nn.Module):
    ''' A-network and V-network with a single shared hidden layer '''
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(VAnet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)  # shared layers
        self.fc_A = torch.nn.Linear(hidden_dim, action_dim)
        self.fc_V = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x):
        A = self.fc_A(F.relu(self.fc1(x)))
        V = self.fc_V(F.relu(self.fc1(x)))
        Q = V + A - A.mean(1).view(-1, 1)  # Q is computed from V and A
        return Q


class DQN:
    ''' DQN algorithm, including Double DQN and Dueling DQN '''
    def __init__(self,
                 state_dim,
                 hidden_dim,
                 action_dim,
                 learning_rate,
                 gamma,
                 epsilon,
                 target_update,
                 device,
                 dqn_type='VanillaDQN'):
        self.action_dim = action_dim
        if dqn_type == 'DuelingDQN':  # Dueling DQN uses a different network architecture
            self.q_net = VAnet(state_dim, hidden_dim,
                               self.action_dim).to(device)
            self.target_q_net = VAnet(state_dim, hidden_dim,
                                      self.action_dim).to(device)
        else:
            self.q_net = Qnet(state_dim, hidden_dim,
                              self.action_dim).to(device)
            self.target_q_net = Qnet(state_dim, hidden_dim,
                                     self.action_dim).to(device)
        self.optimizer = torch.optim.Adam(self.q_net.parameters(),
                                          lr=learning_rate)
        self.gamma = gamma
        self.epsilon = epsilon
        self.target_update = target_update
        self.count = 0
        self.dqn_type = dqn_type
        self.device = device

    def take_action(self, state):
        if np.random.random() < self.epsilon:
            action = np.random.randint(self.action_dim)
        else:
            state = torch.tensor([state], dtype=torch.float).to(self.device)
            action = self.q_net(state).argmax().item()
        return action

    def max_q_value(self, state):
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        return self.q_net(state).max().item()

    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'],
                              dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(
            self.device)
        rewards = torch.tensor(transition_dict['rewards'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'],
                                   dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'],
                             dtype=torch.float).view(-1, 1).to(self.device)

        q_values = self.q_net(states).gather(1, actions)
        if self.dqn_type == 'DoubleDQN':
            max_action = self.q_net(next_states).max(1)[1].view(-1, 1)
            max_next_q_values = self.target_q_net(next_states).gather(
                1, max_action)
        else:
            max_next_q_values = self.target_q_net(next_states).max(1)[0].view(
                -1, 1)
        q_targets = rewards + self.gamma * max_next_q_values * (1 - dones)
        dqn_loss = torch.mean(F.mse_loss(q_values, q_targets))
        self.optimizer.zero_grad()
        dqn_loss.backward()
        self.optimizer.step()

        if self.count % self.target_update == 0:
            self.target_q_net.load_state_dict(self.q_net.state_dict())
        self.count += 1


random.seed(0)
np.random.seed(0)
env.seed(0)
torch.manual_seed(0)
replay_buffer = rl_utils.ReplayBuffer(buffer_size)
agent = DQN(state_dim, hidden_dim, action_dim, lr, gamma, epsilon,
            target_update, device, 'DuelingDQN')
return_list, max_q_value_list = train_DQN(agent, env, num_episodes,
                                          replay_buffer, minimal_size,
                                          batch_size)

episodes_list = list(range(len(return_list)))
mv_return = rl_utils.moving_average(return_list, 5)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('Dueling DQN on {}'.format(env_name))
plt.show()

frames_list = list(range(len(max_q_value_list)))
plt.plot(frames_list, max_q_value_list)
plt.axhline(0, c='orange', ls='--')
plt.axhline(10, c='red', ls='--')
plt.xlabel('Frames')
plt.ylabel('Q value')
plt.title('Dueling DQN on {}'.format(env_name))
plt.show()
```
```
Iteration 0: 100%|██████████| 20/20 [00:10<00:00, 1.87it/s, episode=20, return=-708.652]
Iteration 1: 100%|██████████| 20/20 [00:15<00:00, 1.28it/s, episode=40, return=-229.557]
Iteration 2: 100%|██████████| 20/20 [00:15<00:00, 1.32it/s, episode=60, return=-184.607]
Iteration 3: 100%|██████████| 20/20 [00:13<00:00, 1.50it/s, episode=80, return=-200.323]
Iteration 4: 100%|██████████| 20/20 [00:13<00:00, 1.51it/s, episode=100, return=-213.811]
Iteration 5: 100%|██████████| 20/20 [00:13<00:00, 1.53it/s, episode=120, return=-181.165]
Iteration 6: 100%|██████████| 20/20 [00:14<00:00, 1.35it/s, episode=140, return=-222.040]
Iteration 7: 100%|██████████| 20/20 [00:14<00:00, 1.35it/s, episode=160, return=-173.313]
Iteration 8: 100%|██████████| 20/20 [00:12<00:00, 1.62it/s, episode=180, return=-236.372]
Iteration 9: 100%|██████████| 20/20 [00:12<00:00, 1.57it/s, episode=200, return=-230.058]
```
From the results we can see that, compared with traditional DQN, Dueling DQN learns more stably when there are many actions to choose from, and also achieves a larger maximum return. From the principle behind Dueling DQN, its advantage over DQN becomes more pronounced as the action space grows. We previously discretized the action space into 11 actions; readers can increase the number of discrete actions (for example to 15 or 25) and repeat the comparison experiment.
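A minimal sketch of such a comparison loop (not from the book; it assumes the `env`, `DQN`, `train_DQN`, `rl_utils`, and hyperparameters defined above are still in scope, and the action counts are just examples):

```python
# Rerun the three variants under coarser and finer action discretizations.
for n_actions in [11, 15, 25]:  # example discretization levels
    for dqn_type in ['VanillaDQN', 'DoubleDQN', 'DuelingDQN']:
        random.seed(0)
        np.random.seed(0)
        env.seed(0)
        torch.manual_seed(0)
        replay_buffer = rl_utils.ReplayBuffer(buffer_size)
        agent = DQN(state_dim, hidden_dim, n_actions, lr, gamma, epsilon,
                    target_update, device, dqn_type)
        returns, _ = train_DQN(agent, env, num_episodes, replay_buffer,
                               minimal_size, batch_size)
        print(dqn_type, n_actions,
              'mean return of last 10 episodes: %.3f' % np.mean(returns[-10:]))
```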
On top of traditional DQN, there are two variants that are very easy to implement: Double DQN and Dueling DQN. Double DQN addresses DQN's overestimation of $Q$ values, while Dueling DQN learns the differences between actions well and is especially effective in environments with large action spaces.
We can carry out a simplified quantitative analysis of the overestimation of $Q$ values. Suppose that in state $s$ all actions have the same true value, $Q_*(s, a) = V_*(s)$ for every $a$ (a simplification made for the sake of the analysis; in practice the values of different actions usually differ), and that the estimation errors $Q_{\omega^-}(s, a) - V_*(s)$ are independent and uniformly distributed on $[-1, 1]$. If the action space has size $m$, then for any state $s$,

$$ \mathbb{E}\Big[\max_a Q_{\omega^-}(s, a) - \max_a Q_*(s, a)\Big] = \frac{m-1}{m+1} $$

That is, the larger the action space $m$, the more severe the overestimation.

Proof. Denote the estimation error by $\epsilon_a = Q_{\omega^-}(s, a) - \max_{a'} Q_*(s, a')$. Since the errors are independent across actions, we have

$$ P\Big(\max_a \epsilon_a \le x\Big) = \prod_{a=1}^{m} P(\epsilon_a \le x) $$

where $P(\epsilon_a \le x)$ is the cumulative distribution function (CDF) of $\epsilon_a$, which can be written explicitly as

$$ P(\epsilon_a \le x) = \begin{cases} 0 & \text{if } x \le -1 \\ \dfrac{1+x}{2} & \text{if } x \in (-1, 1) \\ 1 & \text{if } x \ge 1 \end{cases} $$

Therefore we obtain the CDF of $\max_a \epsilon_a$: for $x \in (-1, 1)$,

$$ P\Big(\max_a \epsilon_a \le x\Big) = \prod_{a=1}^{m} P(\epsilon_a \le x) = \left(\frac{1+x}{2}\right)^m $$

Finally we get

$$ \mathbb{E}\Big[\max_a \epsilon_a\Big] = \int_{-1}^{1} x \,\mathrm{d}P\Big(\max_a \epsilon_a \le x\Big) = \int_{-1}^{1} x \cdot \frac{m}{2}\left(\frac{1+x}{2}\right)^{m-1} \mathrm{d}x = \frac{m-1}{m+1} $$

Although this analysis simplifies the real setting, it still correctly characterizes some properties of the overestimation in DQN: the overestimation grows with the number of actions $m$; in other words, the more actions there are to choose from, the more severe the overestimation.
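As a quick Monte Carlo check of the $(m-1)/(m+1)$ result (not part of the original text; it only requires NumPy):

```python
import numpy as np

# Empirically verify E[max_a eps_a] = (m - 1) / (m + 1) for i.i.d. Uniform(-1, 1) errors.
rng = np.random.default_rng(0)
n_samples = 100_000
for m in [2, 5, 11, 25]:  # example action-space sizes
    eps = rng.uniform(-1.0, 1.0, size=(n_samples, m))
    print('m=%d: empirical=%.3f, theoretical=%.3f' %
          (m, eps.max(axis=1).mean(), (m - 1) / (m + 1)))
```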
[1] HASSELT V H, GUEZ A, SILVER D. Deep reinforcement learning with double q-learning [C]// Proceedings of the AAAI conference on artificial intelligence. 2016, 30(1).
[2] WANG Z, SCHAUL T, HESSEL M, et al. Dueling network architectures for deep reinforcement learning [C]// International conference on machine learning, PMLR, 2016: 1995-2003.
[3] HESSEL M, MODAYIL J, HASSELT V H, et al. Rainbow: Combining improvements in deep reinforcement learning [C]// Thirty-second AAAI conference on artificial intelligence, 2018.