前言
第1种设计方式:基于Numpy数组
编写经验回放池(Experience Replay Buffer)是深度强化学习(Deep RL)算法的必备技术之一,常见的实现多基于数组。本文列举3种常见的实现方式。
由于代码较为简单,本文不再逐行讲解;如有不理解之处,可以直接在评论区提问。
第1种设计方式:基于Numpy数组
import numpy as np


class ReplayBuffer(object):
    """Fixed-size experience replay buffer backed by a preallocated NumPy array.

    Each row holds one flattened transition laid out as
    ``[s..., a, r, s_...]`` and has ``2 * state_dims + 2`` columns. Once
    ``capacity`` transitions have been stored, new ones overwrite the oldest
    row (ring-buffer indexing via ``pointer % capacity``).
    """

    def __init__(self, capacity, state_dims):
        """Create an empty buffer.

        Args:
            capacity: Maximum number of transitions kept; must be positive.
            state_dims: Number of values in one state vector.

        Raises:
            ValueError: If ``capacity`` is not positive.
        """
        if capacity <= 0:
            # A zero capacity would otherwise only fail later, with a
            # ZeroDivisionError from the modulo in store_transition.
            raise ValueError("capacity must be a positive integer")
        self.capacity = capacity  # maximum number of stored transitions
        self.data = np.zeros((capacity, state_dims * 2 + 2))  # one transition per row
        self.pointer = 0  # total number of transitions ever stored

    def __len__(self):
        """Return the number of valid rows currently held (at most ``capacity``)."""
        return min(self.pointer, self.capacity)

    def store_transition(self, s, a, r, s_):
        """Store one transition, overwriting the oldest row when full.

        Args:
            s: Current state.
            a: Action, stored as a single scalar column.
            r: Reward, stored as a single scalar column.
            s_: Next state.

        ``np.hstack((s, [a, r], s_))`` must yield exactly
        ``2 * state_dims + 2`` values, otherwise the row assignment fails.
        """
        transition = np.hstack((s, [a, r], s_))  # concatenate into one flat row
        index = self.pointer % self.capacity  # wrap around once the buffer is full
        self.data[index, :] = transition
        self.pointer += 1

    def sample(self, batch_size):
        """Return ``batch_size`` rows drawn uniformly *with replacement*.

        Only rows that have actually been written are eligible, so sampling
        is allowed before the buffer is full.

        Raises:
            ValueError: If the buffer is empty.
        """
        size = len(self)
        if size == 0:
            raise ValueError("cannot sample from an empty replay buffer")
        batch_indexs = np.random.choice(size, size=batch_size)
        return self.data[batch_indexs, :]


# 第2种设计方式:基于Python数组
import random

import numpy as np


class ReplayBuffer:
    """Experience replay buffer stored in a Python list used as a ring buffer.

    Transitions are kept as ``(state, action, reward, next_state, done)``
    tuples; once ``capacity`` entries exist, the oldest one is overwritten.
    """

    def __init__(self, capacity):
        """Create an empty buffer holding at most ``capacity`` transitions.

        Raises:
            ValueError: If ``capacity`` is not positive.
        """
        if capacity <= 0:
            # With capacity 0, push() would otherwise fail with an IndexError
            # when writing into the still-empty list.
            raise ValueError("capacity must be a positive integer")
        self.capacity = capacity
        self.buffer = []
        self.position = 0  # slot the next push() writes to

    def push(self, state, action, reward, next_state, done):
        """Store one transition, overwriting the oldest one when full."""
        if len(self.buffer) < self.capacity:
            # Grow the list by one slot until it reaches capacity.
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        # int() keeps the index an int even if capacity was passed as a float.
        self.position = int((self.position + 1) % self.capacity)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions (without replacement).

        Returns:
            A tuple ``(state, action, reward, next_state, done)``; each element
            is a NumPy array produced by ``np.stack`` over the sampled batch.

        Raises:
            ValueError: If ``batch_size`` is not in ``1..len(self)``.
        """
        batch = random.sample(self.buffer, batch_size)
        # Transpose the list of tuples into per-field tuples, then stack each.
        state, action, reward, next_state, done = map(np.stack, zip(*batch))
        return state, action, reward, next_state, done

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)


# 第3种设计方式:基于队列
本项目采用基于队列(collections.deque)的设计,代码更加简洁:
from collections import deque
import random


class ReplayBuffer(object):
    """Experience replay buffer backed by a bounded ``collections.deque``.

    Transitions are stored as ``(state, action, reward, state_, terminal)``
    tuples; when the buffer is full, the oldest transition is discarded.
    """

    def __init__(self, capacity):
        """Create an empty buffer holding at most ``capacity`` transitions.

        ``capacity`` is converted with ``int()`` so integer-valued floats
        such as ``1e6`` are accepted.

        Raises:
            ValueError: If ``capacity`` is negative (raised by ``deque``).
        """
        self.memory_size = capacity  # maximum number of stored transitions
        self.num = 0  # number of transitions currently stored
        # maxlen makes the deque evict from the left automatically on append,
        # replacing the manual popleft() bookkeeping.
        self.data = deque(maxlen=int(capacity))

    def store_transition(self, state, action, reward, state_, terminal):
        """Append one transition, dropping the oldest one if the buffer is full."""
        self.data.append((state, action, reward, state_, terminal))
        self.num = len(self.data)  # keep the public counter in sync

    def sample(self, batch_size):
        """Return a list of ``batch_size`` distinct transitions drawn uniformly.

        Raises:
            ValueError: If ``batch_size`` is negative or larger than the number
                of stored transitions (raised by ``random.sample``).
        """
        minibatch = random.sample(self.data, batch_size)
        return minibatch

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.data)
【页游开发】