Earlier chapters mentioned that on-policy algorithms have relatively low sample efficiency, so we usually prefer off-policy algorithms. However, although DDPG is an off-policy algorithm, its training is quite unstable, its convergence is poor, it is sensitive to hyperparameters, and it struggles to adapt to different complex environments. In 2018, a more stable off-policy algorithm, Soft Actor-Critic (SAC), was proposed. SAC's predecessor is Soft Q-learning, and both belong to the family of maximum entropy reinforcement learning. Soft Q-learning does not maintain an explicit policy function; instead it uses the Boltzmann distribution induced by the Q function as its policy, which is difficult to sample from in continuous action spaces. SAC resolves this by learning an explicit stochastic policy (Actor), and it is currently one of the most effective model-free reinforcement learning algorithms.
Entropy measures the randomness of a random variable. Specifically, if $X$ is a random variable with probability density function $p$, its entropy $H$ is defined as
$$H(X) = \mathbb{E}_{x \sim p}[-\log p(x)]$$
In reinforcement learning, we can use $H(\pi(\cdot|s))$ to measure how random the policy $\pi$ is at state $s$.
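As a quick toy illustration, the snippet below uses torch.distributions to compare the entropy of an almost deterministic and a near-uniform action distribution, and of two Gaussians with different standard deviations; the specific probabilities and standard deviations are arbitrary values chosen only for this example.

import torch
from torch.distributions import Categorical, Normal

# Toy distributions chosen for illustration only
peaked = Categorical(probs=torch.tensor([0.98, 0.01, 0.01]))      # almost deterministic
uniform = Categorical(probs=torch.tensor([1 / 3, 1 / 3, 1 / 3]))  # maximally random
print(peaked.entropy().item(), uniform.entropy().item())  # the uniform one has higher entropy

# For a Gaussian policy, the entropy grows with the standard deviation
print(Normal(0.0, 0.1).entropy().item(), Normal(0.0, 1.0).entropy().item())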
The idea of maximum entropy RL is that, in addition to maximizing the cumulative reward, the policy should also be made more random. An entropy regularization term is therefore added to the reinforcement learning objective, which becomes
$$\pi^* = \arg\max_{\pi} \mathbb{E}_{\pi}\left[\sum_t r(s_t, a_t) + \alpha H(\pi(\cdot|s_t))\right]$$
where $\alpha$ is a regularization coefficient that controls how much weight the entropy term carries, i.e., how important randomness is relative to reward.
Entropy regularization increases the amount of exploration the algorithm performs: the larger $\alpha$ is, the more exploratory the policy. This helps accelerate subsequent policy learning and makes it less likely that the policy gets stuck in a poor local optimum.
In the maximum entropy RL framework, since the objective has changed, several other definitions change accordingly. First, consider the Soft Bellman equation:
$$Q(s_t, a_t) = r(s_t, a_t) + \gamma \, \mathbb{E}_{s_{t+1}}\left[V(s_{t+1})\right]$$
where the state value function is written as
$$V(s_t) = \mathbb{E}_{a_t \sim \pi}\left[Q(s_t, a_t) - \alpha \log \pi(a_t|s_t)\right] = \mathbb{E}_{a_t \sim \pi}\left[Q(s_t, a_t)\right] + \alpha H(\pi(\cdot|s_t))$$
Based on this Soft Bellman equation, Soft policy evaluation converges to the Soft Q function of the policy $\pi$ when the state and action spaces are finite. The policy can then be improved with the following Soft policy improvement step:
$$\pi_{\text{new}} = \arg\min_{\pi'} D_{KL}\left(\pi'(\cdot|s), \; \frac{\exp\!\left(\frac{1}{\alpha} Q^{\pi_{\text{old}}}(s, \cdot)\right)}{Z^{\pi_{\text{old}}}(s)}\right)$$
where $Z^{\pi_{\text{old}}}(s)$ is a normalization factor.
By repeatedly alternating Soft policy evaluation and Soft policy improvement, the policy eventually converges to the optimal policy of the maximum entropy objective. However, this Soft policy iteration only applies to the tabular setting, i.e., when the state and action spaces are finite. In continuous spaces, we need to approximate this iteration with a parameterized Q function $Q_\omega(s, a)$ and a parameterized policy $\pi_\theta(a|s)$.
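Before moving on to function approximation, the tabular case can be made concrete with a minimal sketch that repeatedly applies the soft Bellman backup on a randomly generated toy MDP; the transition probabilities, rewards, and the chosen values of alpha and gamma are all made-up illustrative quantities.

import numpy as np

np.random.seed(0)
n_states, n_actions, gamma, alpha = 3, 2, 0.9, 0.5
P = np.random.dirichlet(np.ones(n_states), size=(n_states, n_actions))  # P[s, a, s']
R = np.random.randn(n_states, n_actions)                                # r(s, a)
Q = np.zeros((n_states, n_actions))

for _ in range(200):
    # Soft state value: V(s) = alpha * log sum_a exp(Q(s, a) / alpha),
    # which equals E_{a~pi}[Q(s, a)] + alpha * H(pi(.|s)) for the Boltzmann policy below
    V = alpha * np.log(np.exp(Q / alpha).sum(axis=1))
    # Soft Bellman backup: Q(s, a) = r(s, a) + gamma * E_{s'}[V(s')]
    Q = R + gamma * P @ V

# The maximum entropy optimal policy is a Boltzmann distribution over Q / alpha
pi = np.exp(Q / alpha) / np.exp(Q / alpha).sum(axis=1, keepdims=True)
print(pi)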
In the SAC algorithm, we model two action value functions $Q_{\omega_1}$, $Q_{\omega_2}$ and one policy $\pi_\theta$. Following the idea of Double DQN, SAC maintains two Q networks and always uses the smaller of the two Q values, which mitigates over-estimation of the Q values. The loss of either Q function is
$$L_Q(\omega) = \mathbb{E}_{(s_t, a_t, r_t, s_{t+1}) \sim R}\left[\frac{1}{2}\left(Q_\omega(s_t, a_t) - \left(r_t + \gamma V_{\omega^-}(s_{t+1})\right)\right)^2\right]$$
with
$$V_{\omega^-}(s_{t+1}) = \mathbb{E}_{a_{t+1} \sim \pi_\theta}\left[\min_{j=1,2} Q_{\omega_j^-}(s_{t+1}, a_{t+1}) - \alpha \log \pi_\theta(a_{t+1}|s_{t+1})\right]$$
Here $R$ is the replay buffer of data collected by past policies, since SAC is an off-policy algorithm, and $\omega_1^-$, $\omega_2^-$ are the parameters of two target Q networks, one for each Q network; they are used to stabilize training and are updated in the same way as the target networks in DDPG.
The policy loss is derived from the KL divergence in the Soft policy improvement step and simplifies to
$$L_\pi(\theta) = \mathbb{E}_{s_t \sim R,\, a_t \sim \pi_\theta}\left[\alpha \log \pi_\theta(a_t|s_t) - Q_\omega(s_t, a_t)\right]$$
which can be understood as maximizing the value function $V$, since $V(s_t) = \mathbb{E}_{a_t \sim \pi}\left[Q(s_t, a_t) - \alpha \log \pi(a_t|s_t)\right]$.
For environments with continuous action spaces, SAC's policy outputs the mean and standard deviation of a Gaussian distribution, but sampling an action from that Gaussian is not differentiable. We therefore use the reparameterization trick: first sample a noise variable $\epsilon_t$ from a unit Gaussian $\mathcal{N}$, then multiply it by the standard deviation and add the mean. This can be regarded as sampling from the policy's Gaussian while keeping the gradient with respect to the policy parameters well defined. We write the resulting action as $a_t = f_\theta(\epsilon_t; s_t)$. Taking both Q functions into account, the policy loss is rewritten as
$$L_\pi(\theta) = \mathbb{E}_{s_t \sim R,\, \epsilon_t \sim \mathcal{N}}\left[\alpha \log \pi_\theta\big(f_\theta(\epsilon_t; s_t)\,|\,s_t\big) - \min_{j=1,2} Q_{\omega_j}\big(s_t, f_\theta(\epsilon_t; s_t)\big)\right]$$
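As a minimal sketch of the trick, the snippet below contrasts an ordinary sample, through which no gradient can flow, with a reparameterized sample computed as mean plus standard deviation times unit-Gaussian noise; the mean and standard deviation are hypothetical stand-ins for the policy network's outputs, and the manual version is equivalent to PyTorch's dist.rsample() used in the implementation later in this chapter.

import torch
from torch.distributions import Normal

mu = torch.tensor([0.5], requires_grad=True)    # hypothetical policy output (mean)
std = torch.tensor([0.2], requires_grad=True)   # hypothetical policy output (std)

# sample() is not differentiable with respect to mu and std
a_plain = Normal(mu, std).sample()

# Reparameterized sampling: draw noise from a unit Gaussian, then scale and shift,
# so that gradients can flow back into mu and std (equivalent to Normal(mu, std).rsample())
eps = Normal(torch.zeros_like(mu), torch.ones_like(mu)).sample()
a_reparam = mu + std * eps
a_reparam.sum().backward()
print(mu.grad, std.grad)  # both gradients are now defined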
In SAC, the choice of the entropy regularization coefficient matters a great deal. Different states call for different amounts of entropy: in a state where the best action is uncertain, the entropy should be larger, while in a state where the best action is fairly clear, it can be smaller. To adjust the entropy term automatically, SAC rewrites the RL objective as a constrained optimization problem:
$$\max_{\pi} \; \mathbb{E}_{\pi}\left[\sum_t r(s_t, a_t)\right] \quad \text{s.t.} \quad \mathbb{E}_{(s_t, a_t) \sim \rho_\pi}\left[-\log \pi_t(a_t|s_t)\right] \ge \mathcal{H}_0$$
That is, maximize the expected return while constraining the expected entropy to be at least $\mathcal{H}_0$. After some mathematical manipulation, this yields the following loss for $\alpha$:
$$L(\alpha) = \mathbb{E}_{s_t \sim R,\, a_t \sim \pi(\cdot|s_t)}\left[-\alpha \log \pi(a_t|s_t) - \alpha \mathcal{H}_0\right]$$
In other words, when the policy's entropy falls below the target $\mathcal{H}_0$, minimizing $L(\alpha)$ increases $\alpha$, which in turn increases the weight of the entropy term when minimizing the policy loss $L_\pi(\theta)$; when the policy's entropy is above $\mathcal{H}_0$, minimizing $L(\alpha)$ decreases $\alpha$, so that policy training focuses more on improving the value.
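This behaviour can be checked with a small standalone sketch; the entropy estimate and the target entropy below are hypothetical numbers, and, as in the full implementation later in this chapter, we optimize log(alpha) rather than alpha itself for numerical stability.

import torch

log_alpha = torch.zeros(1, requires_grad=True)            # optimize log(alpha) for stability
alpha_optimizer = torch.optim.Adam([log_alpha], lr=3e-4)

entropy = torch.tensor([0.3])   # hypothetical estimate of the policy entropy on a batch
target_entropy = 1.0            # hypothetical target entropy H_0

# L(alpha) = alpha * (entropy - H_0): when entropy < H_0 the gradient increases alpha,
# strengthening the entropy bonus; when entropy > H_0 it decreases alpha
alpha_loss = (log_alpha.exp() * (entropy - target_entropy).detach()).mean()
alpha_optimizer.zero_grad()
alpha_loss.backward()
alpha_optimizer.step()
print(log_alpha.exp())  # alpha has moved slightly upward from its initial value of 1.0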
This completes the overall idea of the SAC algorithm; its concrete procedure is as follows:

- Initialize the Critic networks $Q_{\omega_1}(s, a)$, $Q_{\omega_2}(s, a)$ and the Actor network $\pi_\theta(s)$ with random parameters $\omega_1$, $\omega_2$, $\theta$
- Copy the parameters into the target networks: $\omega_1^- \leftarrow \omega_1$, $\omega_2^- \leftarrow \omega_2$
- Initialize the replay buffer $R$
- for episode $e = 1 \to E$ do
  - Observe the initial state $s_1$
  - for time step $t = 1 \to T$ do
    - Sample an action $a_t = \pi_\theta(s_t)$ from the current policy
    - Execute $a_t$, receive reward $r_t$, and observe the next state $s_{t+1}$
    - Store $(s_t, a_t, r_t, s_{t+1})$ in the replay buffer $R$
    - for training step $k = 1 \to K$ do
      - Sample $N$ transitions $\{(s_i, a_i, r_i, s_{i+1})\}_{i=1,\dots,N}$ from $R$
      - For each transition, sample $a_{i+1} \sim \pi_\theta(\cdot|s_{i+1})$ and compute the target $y_i = r_i + \gamma\left(\min_{j=1,2} Q_{\omega_j^-}(s_{i+1}, a_{i+1}) - \alpha \log \pi_\theta(a_{i+1}|s_{i+1})\right)$
      - Update both Critics by minimizing $\frac{1}{N}\sum_i \left(y_i - Q_{\omega_j}(s_i, a_i)\right)^2$ for $j = 1, 2$
      - Sample reparameterized actions $\tilde{a}_i = f_\theta(\epsilon_i; s_i)$ and update the Actor by minimizing $L_\pi(\theta)$
      - Update the entropy coefficient $\alpha$ by minimizing $L(\alpha)$
      - Softly update the target networks: $\omega_j^- \leftarrow \tau \omega_j + (1 - \tau)\,\omega_j^-$ for $j = 1, 2$
  - end for
- end for
Let us now look at the code implementation of SAC. We first run experiments on the Pendulum (inverted pendulum) environment, and then try to apply SAC to the CartPole environment, which has discrete actions.
First, we import the required libraries.
import random
import gym
import numpy as np
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch.distributions import Normal
import matplotlib.pyplot as plt
import rl_utils
Next, we define the policy network and the value network. Since the environment has continuous actions, the policy network outputs the mean and standard deviation of a Gaussian distribution to represent the action distribution, while the value network takes the concatenation of the state and the action as input and outputs a single scalar action value.
class PolicyNetContinuous(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim, action_bound):
        super(PolicyNetContinuous, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc_mu = torch.nn.Linear(hidden_dim, action_dim)
        self.fc_std = torch.nn.Linear(hidden_dim, action_dim)
        self.action_bound = action_bound

    def forward(self, x):
        x = F.relu(self.fc1(x))
        mu = self.fc_mu(x)
        std = F.softplus(self.fc_std(x))
        dist = Normal(mu, std)
        normal_sample = dist.rsample()  # rsample() performs reparameterized sampling
        log_prob = dist.log_prob(normal_sample)
        action = torch.tanh(normal_sample)
        # Correct the log-probability density for the tanh squashing:
        # log(1 - tanh(u)^2), where action = tanh(u)
        log_prob = log_prob - torch.log(1 - action.pow(2) + 1e-7)
        action = action * self.action_bound
        return action, log_prob


class QValueNetContinuous(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(QValueNetContinuous, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.fc_out = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x, a):
        cat = torch.cat([x, a], dim=1)
        x = F.relu(self.fc1(cat))
        x = F.relu(self.fc2(x))
        return self.fc_out(x)
Now let us look at the main code of the SAC algorithm. As discussed in Section 14.4, SAC uses two Critic networks $Q_{\omega_1}$ and $Q_{\omega_2}$ and takes the minimum of their outputs to make the Actor's training more stable, and each Critic has its own target network that is updated with soft updates, just as in DDPG.
class SACContinuous:
    ''' SAC algorithm for continuous actions '''
    def __init__(self, state_dim, hidden_dim, action_dim, action_bound,
                 actor_lr, critic_lr, alpha_lr, target_entropy, tau, gamma,
                 device):
        self.actor = PolicyNetContinuous(state_dim, hidden_dim, action_dim,
                                         action_bound).to(device)  # policy network
        self.critic_1 = QValueNetContinuous(state_dim, hidden_dim,
                                            action_dim).to(device)  # first Q network
        self.critic_2 = QValueNetContinuous(state_dim, hidden_dim,
                                            action_dim).to(device)  # second Q network
        self.target_critic_1 = QValueNetContinuous(state_dim, hidden_dim,
                                                   action_dim).to(device)  # first target Q network
        self.target_critic_2 = QValueNetContinuous(state_dim, hidden_dim,
                                                   action_dim).to(device)  # second target Q network
        # Initialize the target Q networks with the same parameters as the Q networks
        self.target_critic_1.load_state_dict(self.critic_1.state_dict())
        self.target_critic_2.load_state_dict(self.critic_2.state_dict())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=actor_lr)
        self.critic_1_optimizer = torch.optim.Adam(self.critic_1.parameters(),
                                                   lr=critic_lr)
        self.critic_2_optimizer = torch.optim.Adam(self.critic_2.parameters(),
                                                   lr=critic_lr)
        # Optimizing log(alpha) instead of alpha makes training more stable
        self.log_alpha = torch.tensor(np.log(0.01), dtype=torch.float)
        self.log_alpha.requires_grad = True  # allow gradients w.r.t. alpha
        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha],
                                                    lr=alpha_lr)
        self.target_entropy = target_entropy  # target entropy
        self.gamma = gamma
        self.tau = tau
        self.device = device

    def take_action(self, state):
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        action = self.actor(state)[0]
        return [action.item()]

    def calc_target(self, rewards, next_states, dones):  # compute the target Q value
        next_actions, log_prob = self.actor(next_states)
        entropy = -log_prob
        q1_value = self.target_critic_1(next_states, next_actions)
        q2_value = self.target_critic_2(next_states, next_actions)
        next_value = torch.min(q1_value,
                               q2_value) + self.log_alpha.exp() * entropy
        td_target = rewards + self.gamma * next_value * (1 - dones)
        return td_target

    def soft_update(self, net, target_net):
        for param_target, param in zip(target_net.parameters(),
                                       net.parameters()):
            param_target.data.copy_(param_target.data * (1.0 - self.tau) +
                                    param.data * self.tau)

    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'],
                              dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition_dict['rewards'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'],
                                   dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'],
                             dtype=torch.float).view(-1, 1).to(self.device)
        # As in earlier chapters, reshape the Pendulum rewards to ease training
        rewards = (rewards + 8.0) / 8.0

        # Update the two Q networks
        td_target = self.calc_target(rewards, next_states, dones)
        critic_1_loss = torch.mean(
            F.mse_loss(self.critic_1(states, actions), td_target.detach()))
        critic_2_loss = torch.mean(
            F.mse_loss(self.critic_2(states, actions), td_target.detach()))
        self.critic_1_optimizer.zero_grad()
        critic_1_loss.backward()
        self.critic_1_optimizer.step()
        self.critic_2_optimizer.zero_grad()
        critic_2_loss.backward()
        self.critic_2_optimizer.step()

        # Update the policy network
        new_actions, log_prob = self.actor(states)
        entropy = -log_prob
        q1_value = self.critic_1(states, new_actions)
        q2_value = self.critic_2(states, new_actions)
        actor_loss = torch.mean(-self.log_alpha.exp() * entropy -
                                torch.min(q1_value, q2_value))
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # Update alpha
        alpha_loss = torch.mean(
            (entropy - self.target_entropy).detach() * self.log_alpha.exp())
        self.log_alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.log_alpha_optimizer.step()

        self.soft_update(self.critic_1, self.target_critic_1)
        self.soft_update(self.critic_2, self.target_critic_2)
Let's now try the SAC algorithm on the Pendulum environment!
env_name = 'Pendulum-v0'
env = gym.make(env_name)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
action_bound = env.action_space.high[0]  # maximum action value
random.seed(0)
np.random.seed(0)
env.seed(0)
torch.manual_seed(0)

actor_lr = 3e-4
critic_lr = 3e-3
alpha_lr = 3e-4
num_episodes = 100
hidden_dim = 128
gamma = 0.99
tau = 0.005  # soft update coefficient
buffer_size = 100000
minimal_size = 1000
batch_size = 64
target_entropy = -env.action_space.shape[0]
device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
    "cpu")

replay_buffer = rl_utils.ReplayBuffer(buffer_size)
agent = SACContinuous(state_dim, hidden_dim, action_dim, action_bound,
                      actor_lr, critic_lr, alpha_lr, target_entropy, tau,
                      gamma, device)

return_list = rl_utils.train_off_policy_agent(env, agent, num_episodes,
                                              replay_buffer, minimal_size,
                                              batch_size)
Iteration 0:   0%|          | 0/10 [00:00<?, ?it/s]
/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:27: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:201.)
Iteration 0: 100%|██████████| 10/10 [00:09<00:00, 1.03it/s, episode=10, return=-1534.655]
Iteration 1: 100%|██████████| 10/10 [00:18<00:00, 1.83s/it, episode=20, return=-1085.715]
Iteration 2: 100%|██████████| 10/10 [00:15<00:00, 1.60s/it, episode=30, return=-364.507]
Iteration 3: 100%|██████████| 10/10 [00:13<00:00, 1.37s/it, episode=40, return=-222.485]
Iteration 4: 100%|██████████| 10/10 [00:13<00:00, 1.36s/it, episode=50, return=-157.978]
Iteration 5: 100%|██████████| 10/10 [00:13<00:00, 1.37s/it, episode=60, return=-166.056]
Iteration 6: 100%|██████████| 10/10 [00:13<00:00, 1.38s/it, episode=70, return=-143.147]
Iteration 7: 100%|██████████| 10/10 [00:13<00:00, 1.37s/it, episode=80, return=-127.939]
Iteration 8: 100%|██████████| 10/10 [00:14<00:00, 1.42s/it, episode=90, return=-180.905]
Iteration 9: 100%|██████████| 10/10 [00:14<00:00, 1.41s/it, episode=100, return=-171.265]
episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('SAC on {}'.format(env_name))
plt.show()

mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('SAC on {}'.format(env_name))
plt.show()
We can see that SAC performs very well on the Pendulum environment. SAC was originally proposed for environments with continuous actions, so a natural question is: can SAC also handle environments with discrete actions? The answer is yes, but some corresponding modifications are needed. First, the architectures of the policy network and the value network change as follows:
class PolicyNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=1)


class QValueNet(torch.nn.Module):
    ''' Q network with a single hidden layer '''
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(QValueNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)
The policy network now outputs a discrete action distribution, so when learning the value network there is no need to sample the next action: the value of the next state can be computed directly as an expectation over the policy's output probabilities. Likewise, the entropy can be computed exactly from the probabilities instead of being estimated from sampled actions. The value network simply takes the state as input and outputs one Q value per action.
class SAC:
    ''' SAC algorithm for discrete actions '''
    def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
                 alpha_lr, target_entropy, tau, gamma, device):
        # policy network
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        # first Q network
        self.critic_1 = QValueNet(state_dim, hidden_dim, action_dim).to(device)
        # second Q network
        self.critic_2 = QValueNet(state_dim, hidden_dim, action_dim).to(device)
        self.target_critic_1 = QValueNet(state_dim, hidden_dim,
                                         action_dim).to(device)  # first target Q network
        self.target_critic_2 = QValueNet(state_dim, hidden_dim,
                                         action_dim).to(device)  # second target Q network
        # Initialize the target Q networks with the same parameters as the Q networks
        self.target_critic_1.load_state_dict(self.critic_1.state_dict())
        self.target_critic_2.load_state_dict(self.critic_2.state_dict())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=actor_lr)
        self.critic_1_optimizer = torch.optim.Adam(self.critic_1.parameters(),
                                                   lr=critic_lr)
        self.critic_2_optimizer = torch.optim.Adam(self.critic_2.parameters(),
                                                   lr=critic_lr)
        # Optimizing log(alpha) instead of alpha makes training more stable
        self.log_alpha = torch.tensor(np.log(0.01), dtype=torch.float)
        self.log_alpha.requires_grad = True  # allow gradients w.r.t. alpha
        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha],
                                                    lr=alpha_lr)
        self.target_entropy = target_entropy  # target entropy
        self.gamma = gamma
        self.tau = tau
        self.device = device

    def take_action(self, state):
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        probs = self.actor(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()

    # Compute the target Q value, taking the expectation directly with the policy's output probabilities
    def calc_target(self, rewards, next_states, dones):
        next_probs = self.actor(next_states)
        next_log_probs = torch.log(next_probs + 1e-8)
        entropy = -torch.sum(next_probs * next_log_probs, dim=1, keepdim=True)
        q1_value = self.target_critic_1(next_states)
        q2_value = self.target_critic_2(next_states)
        min_qvalue = torch.sum(next_probs * torch.min(q1_value, q2_value),
                               dim=1,
                               keepdim=True)
        next_value = min_qvalue + self.log_alpha.exp() * entropy
        td_target = rewards + self.gamma * next_value * (1 - dones)
        return td_target

    def soft_update(self, net, target_net):
        for param_target, param in zip(target_net.parameters(),
                                       net.parameters()):
            param_target.data.copy_(param_target.data * (1.0 - self.tau) +
                                    param.data * self.tau)

    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'],
                              dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(
            self.device)  # actions are no longer floats
        rewards = torch.tensor(transition_dict['rewards'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'],
                                   dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'],
                             dtype=torch.float).view(-1, 1).to(self.device)

        # Update the two Q networks
        td_target = self.calc_target(rewards, next_states, dones)
        critic_1_q_values = self.critic_1(states).gather(1, actions)
        critic_1_loss = torch.mean(
            F.mse_loss(critic_1_q_values, td_target.detach()))
        critic_2_q_values = self.critic_2(states).gather(1, actions)
        critic_2_loss = torch.mean(
            F.mse_loss(critic_2_q_values, td_target.detach()))
        self.critic_1_optimizer.zero_grad()
        critic_1_loss.backward()
        self.critic_1_optimizer.step()
        self.critic_2_optimizer.zero_grad()
        critic_2_loss.backward()
        self.critic_2_optimizer.step()

        # Update the policy network
        probs = self.actor(states)
        log_probs = torch.log(probs + 1e-8)
        # Compute the entropy directly from the probabilities
        entropy = -torch.sum(probs * log_probs, dim=1, keepdim=True)
        q1_value = self.critic_1(states)
        q2_value = self.critic_2(states)
        min_qvalue = torch.sum(probs * torch.min(q1_value, q2_value),
                               dim=1,
                               keepdim=True)  # expectation taken directly with the probabilities
        actor_loss = torch.mean(-self.log_alpha.exp() * entropy - min_qvalue)
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # Update alpha
        alpha_loss = torch.mean(
            (entropy - self.target_entropy).detach() * self.log_alpha.exp())
        self.log_alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.log_alpha_optimizer.step()

        self.soft_update(self.critic_1, self.target_critic_1)
        self.soft_update(self.critic_2, self.target_critic_2)
actor_lr = 1e-3
critic_lr = 1e-2
alpha_lr = 1e-2
num_episodes = 200
hidden_dim = 128
gamma = 0.98
tau = 0.005  # soft update coefficient
buffer_size = 10000
minimal_size = 500
batch_size = 64
target_entropy = -1
device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
    "cpu")

env_name = 'CartPole-v0'
env = gym.make(env_name)
random.seed(0)
np.random.seed(0)
env.seed(0)
torch.manual_seed(0)
replay_buffer = rl_utils.ReplayBuffer(buffer_size)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = SAC(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, alpha_lr,
            target_entropy, tau, gamma, device)

return_list = rl_utils.train_off_policy_agent(env, agent, num_episodes,
                                              replay_buffer, minimal_size,
                                              batch_size)
Iteration 0: 100%|██████████| 20/20 [00:00<00:00, 148.74it/s, episode=20, return=19.700]
Iteration 1: 100%|██████████| 20/20 [00:00<00:00, 28.35it/s, episode=40, return=10.600]
Iteration 2: 100%|██████████| 20/20 [00:00<00:00, 24.96it/s, episode=60, return=10.000]
Iteration 3: 100%|██████████| 20/20 [00:00<00:00, 24.87it/s, episode=80, return=9.800]
Iteration 4: 100%|██████████| 20/20 [00:00<00:00, 26.33it/s, episode=100, return=9.100]
Iteration 5: 100%|██████████| 20/20 [00:00<00:00, 26.30it/s, episode=120, return=9.500]
Iteration 6: 100%|██████████| 20/20 [00:09<00:00, 2.19it/s, episode=140, return=178.400]
Iteration 7: 100%|██████████| 20/20 [00:15<00:00, 1.30it/s, episode=160, return=200.000]
Iteration 8: 100%|██████████| 20/20 [00:15<00:00, 1.30it/s, episode=180, return=200.000]
Iteration 9: 100%|██████████| 20/20 [00:15<00:00, 1.29it/s, episode=200, return=197.600]
episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('SAC on {}'.format(env_name))
plt.show()

mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('SAC on {}'.format(env_name))
plt.show()
We can see that SAC converges very well on the discrete-action CartPole environment, and its return curve is quite stable, which shows that SAC also balances exploration and exploitation nicely in discrete-action environments.
This chapter first explained maximum entropy reinforcement learning, which balances exploration and exploitation by controlling the entropy of the actions the policy takes, helping readers deepen their understanding of the exploration-exploitation trade-off. It then presented the SAC algorithm, analyzing the principles behind it and its concrete procedure, and finally implemented SAC on the continuous Pendulum environment and the discrete CartPole environment. Thanks to its solid theoretical foundation and excellent empirical performance, SAC has become one of the most popular deep reinforcement learning algorithms, and much new research builds on it: the model-based algorithm MBPO introduced in Chapter 17 and the offline RL algorithm CQL introduced in Chapter 18 both use SAC as a basic building block.
[1] HAARNOJA T, ZHOU A, ABBEEL P, et al. Soft actor-critic: Off-policy maximum entropy deep reinforcement learning with a stochastic actor [C] // International conference on machine learning, PMLR, 2018: 1861-1870.
[2] HAARNOJA T, ZHOU A, HARTIKAINEN K, et al. Soft actor-critic algorithms and applications [J]. 2018.
[3] HAARNOJA T, TANG H, ABBEEL P, et al. Reinforcement learning with deep energy-based policies [C] // International conference on machine learning, PMLR, 2017: 1352-1361.
[4] SCHULMAN J, CHEN X, ABBEEL P. Equivalence between policy gradients and soft q-learning [J]. 2017.