1、定义评估网络的输入输出
我们的评估网络输入是某个时刻的状态,输出是该状态下可以选择的每个动作的Q-eval。
我们的评估网络损失是该时刻下的目标Q值Q_target和当前Q值Q_eval的均方误差。
其中当前Q值Q_eval就是评估网络的输出,目标Q值需要根据Q-Learning机制求得,所以这里先定义为计算图的输入:
# 定义目标网络的输入输出 self.s = tf.placeholder(dtype="float", shape=[None, 80, 80, 4], name='s') # 网络的状态输入 self.q_target = tf.placeholder(dtype="float", shape=[None,self.n_actions], name='q_target')
我们的评估网络定义如下:
# 定义评估网络的输入输出 with tf.variable_scope('eval_net'): with tf.variable_scope('output'): self.q_eval = self._define_cnn_net(self.s,c_names=['eval_net_params', tf.GraphKeys.GLOBAL_VARIABLES]) with tf.variable_scope('loss'): self.cost = tf.reduce_mean(tf.squared_difference(self.q_target,self.q_eval))# 使用预测奖励值与当前位置的奖励平方差来获得损失值 tf.summary.scalar("loss", self.cost) # 使用TensorBoard监测该变量 with tf.variable_scope('train'): self.trainStep = tf.train.AdamOptimizer(self.learn_rate).minimize(self.cost)
2、定义目标网络的输入输出
因为我们的nature dqn算法 使用的是双网络机制,该机制降低了数据之间的关联,可以使算法对数据的Q学习更加健壮,所以我们需要定义一个目标网络.
该网络和评估网络结构一样,但是参数不需要反向传播更新,所以不需要定义损失值和训练操作:
我们的目标网络定义如下:
# 定义目标网络的输入输出 with tf.variable_scope('target_net'): with tf.variable_scope('output'): self.q_next = self._define_cnn_net(self.s, c_names=['target_net_params', tf.GraphKeys.GLOBAL_VARIABLES])
3、定义两个网络的参数更新操作
因为目标网络的参数是定期从当前网络中复制 而来,所以我们需要继续在计算图中定义参数更新操作:
t_params = tf.get_collection('target_net_params') e_params = tf.get_collection('eval_net_params') self.replace_target_op = [tf.assign(t, e) for t, e in zip(t_params, e_params)]二、定义学习机制
算法的学习机制基本就是三步操作:
1、从经验重放池中进行批采样
2、根据批采样数据使用单步Q-learning公式计算目标Q值
3、将目标Q值和状态等输入评估网络,训练更新评估网络和目标网络
1、从经验重放池中进行批采样
批采样代码和我们的经验重播池定义紧密相关,这里我们的批采样代码如下:
minibatch =self.memory.sample(self.batch_size) # 获得一个batch的图片信息 state_batch = [data[0] for data in minibatch] # 获得状态信息, [80, 80, 4] action_batch = [np.argmax(data[1]) for data in minibatch]# 获得动作信息,即向上的索引值为[1, 0], 向下的为[0, 1] reward_batch = [data[2] for data in minibatch]# 获取奖励信息 nextState_batch = [data[3] for data in minibatch]# 获得下一状态信息, [80, 80, 4] terminal_batch = [data[4] for data in minibatch]# 获得是否合格, [80, 80, 4]
2、根据批采样数据使用单步Q-learning公式计算目标Q值
(1)使用目标网络获取后继状态的后继Q值:
q_next_batch = self.q_next.eval(feed_dict={self.s: nextState_batch})
(2)根据单步Q-Learning公式计算批数据的目标Q值:
# 根据Q-Learning机制计算目标Q值 q_eval_batch = self.sess.run(self.q_eval, {self.s: state_batch}) # 使用评估网络获取当前状态的当前Q值 q_target_batch = q_eval_batch.copy() # 目标Q值和当前Q值具有相同的矩阵结构,所以直接复制 for i in range(0, self.batch_size): terminal = terminal_batch[i] if terminal: q_target_batch[i, action_batch[i]] = reward_batch[i] + self.gamma * np.max(q_next_batch[i]) # 目标Q值=当前奖励+折扣因子*后继Q值 else: q_target_batch[i, action_batch[i]] = reward_batch[i] # 目标Q值=当前奖励
3、将目标Q值和状态等输入评估网络,训练更新评估网络和目标网络
(1)训练更新评估网络:使用计算图直接执行反向传播和损失计算:
_, cost = self.sess.run([self.trainStep, self.cost], feed_dict={self.s: state_batch, self.q_target: q_target_batch})
(2)定期更新目标网络:使用计算图直接执行参数更新操作(硬更新机制):
if self.learn_step_counter % self.replace_target_iter == 0: self.sess.run(self.replace_target_op) print('\ntarget_net的参数被更新\n')
【手游开发】