前言

书写经验重放池是Deep Rl算法的必备技术之一,常见的是基于数组的形式,本文列举3种常见的实现方式

本文不会详细介绍代码,因为太过简单,不理解的同学可以直接在评论区提问。


第1种设计方式:基于Numpy数组
class ReplayBuffer(object):
    def __init__(self, capacity,state_dims):
        self.capacity = capacity # 经验池容量大小
        self.data = np.zeros((capacity, state_dims* 2+2))  # 经验池存放的经验数据
        self.pointer = 0 # 当前指针
    def store_transition(self, s, a, r, s_):
        # 检查是否存在
        if not hasattr(self, 'pointer'):
            self.pointer = 0
        # 存储数据
        transition = np.hstack((s, [a,r], s_))  # 按行连接
        index = self.pointer % self.capacity  # 如果超过该容量则自动从头开始
        self.data[index, :] = transition
        self.pointer += 1
    def sample(self, batch_size):
        if self.capacity < self.pointer:
            batch_indexs = np.random.choice(self.capacity, size=batch_size)
        else:
            batch_indexs = np.random.choice(self.pointer, size=batch_size)
            #assert (self.pointer >= self.capacity, '经验回放池还没有被装满')
            #print('经验回放池还没有被装满就开始采样')
        return self.data[batch_indexs, :]  # 获取n个采样
第2种设计方式:基于Python数组
class ReplayBuffer:
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.position = 0
 
    def push(self, state, action, reward, next_state, done):
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = int((self.position + 1) % self.capacity)  # as a ring buffer
 
    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = map(np.stack, zip(*batch))  # stack for each element
        return state, action, reward, next_state, done
 
    def __len__(self):
        return len(self.buffer)
第3种设计方式:基于队列

本项目使用队列来进行设计,其代码更加简洁:

from collections import deque
import random
class ReplayBuffer(object):
    def __init__(self, capacity):
        self.memory_size = capacity # 容量大小
        self.num = 0 # 存放的经验数据数量
        self.data = deque() # 存放经验数据的队列
 
    def store_transition(self, state,action,reward,state_,terminal):
        self.data.append((state, action, reward, state_, terminal))# 添加数据
        if len(self.data) > self.memory_size:
            self.data.popleft()
            self.num -= 1
        self.num += 1
 
    def sample(self, batch_size):
        minibatch = random.sample(self.data, batch_size)
        return minibatch  # 获取n个采样

页游开发】