Preface

An experience replay pool (replay buffer) is one of the essential components of deep RL algorithms, and the most common implementations are array based. This article lists three common ways to implement one.

This article will not walk through the code in detail, since it is fairly simple; if anything is unclear, feel free to ask in the comments section.

First design approach: based on NumPy arrays

import numpy as np

class ReplayBuffer(object):
    def __init__(self, capacity, state_dims):
        self.capacity = capacity
        # each row holds one transition: (s, a, r, s_) flattened together
        self.data = np.zeros((capacity, state_dims * 2 + 2))
        self.pointer = 0  # total number of transitions stored so far

    def store_transition(self, s, a, r, s_):
        transition = np.hstack((s, [a, r], s_))
        index = self.pointer % self.capacity  # overwrite the oldest entry when full
        self.data[index, :] = transition
        self.pointer += 1

    def sample(self, batch_size):
        # only sample from the part of the buffer that has been filled
        if self.pointer >= self.capacity:
            batch_indexs = np.random.choice(self.capacity, size=batch_size)
        else:
            batch_indexs = np.random.choice(self.pointer, size=batch_size)
        return self.data[batch_indexs, :]
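
As a quick check of the interface, here is a minimal usage sketch; the state dimension and the stored values are illustrative only, not taken from the original code:

buffer = ReplayBuffer(capacity=10000, state_dims=4)  # 4 is a hypothetical state size
s = np.random.randn(4)
s_ = np.random.randn(4)
buffer.store_transition(s, 1.0, 0.5, s_)             # a = 1.0, r = 0.5, purely illustrative
batch = buffer.sample(batch_size=1)                  # array of shape (1, state_dims * 2 + 2)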

Second design approach: based on Python lists

import random
import numpy as np

class ReplayBuffer:
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.position = 0
 
    def push(self, state, action, reward, next_state, done):
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = int((self.position + 1) % self.capacity)  # as a ring buffer
 
    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = map(np.stack, zip(*batch))  # stack for each element
        return state, action, reward, next_state, done
 
    def __len__(self):
        return len(self.buffer)
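A short usage sketch of this list-based buffer, with the environment interaction stubbed out by random data (the dimensions and hyper-parameters here are illustrative only):

buffer = ReplayBuffer(capacity=10000)
for step in range(200):
    state = np.random.randn(4)       # stands in for an environment observation
    next_state = np.random.randn(4)
    buffer.push(state, 0, 1.0, next_state, False)
if len(buffer) >= 32:
    states, actions, rewards, next_states, dones = buffer.sample(32)
    # states.shape == (32, 4); each component is stacked into a NumPy array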

Third design approach: based on a queue

This version is designed around a queue (collections.deque), and its code is the most concise:

from collections import deque
import random

class ReplayBuffer(object):
    def __init__(self, capacity):
        self.capacity = capacity
        self.num = 0           # number of transitions currently stored
        self.data = deque()    # the underlying queue

    def store_transition(self, state, action, reward, state_, terminal):
        self.data.append((state, action, reward, state_, terminal))  # add a transition
        self.num += 1
        if self.num > self.capacity:  # discard the oldest transition when over capacity
            self.data.popleft()
            self.num -= 1

    def sample(self, batch_size):
        minibatch = random.sample(self.data, batch_size)  # draw batch_size transitions
        return minibatch
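
Note that unlike the second version, sample here returns a list of transition tuples rather than stacked arrays. A minimal sketch of how one might unpack it on the training side, assuming numpy is imported as np:

minibatch = buffer.sample(batch_size=32)
states, actions, rewards, next_states, terminals = map(np.stack, zip(*minibatch))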
