Skip to content
Snippets Groups Projects
Commit c78ec6e3 authored by Mats Gottenbos's avatar Mats Gottenbos Committed by Hurk,J.R. van den (Jason)
Browse files

Fixed racecar observation shape

parent a6d89d1b
No related branches found
No related tags found
No related merge requests found
......@@ -18,8 +18,8 @@ class MuZeroConfig:
### Game
self.observation_shape = (1, 1, 4) # Dimensions of the game observation, must be 3D (channel, height, width). For a 1D array, please reshape it to (1, 1, length of array)
self.action_space = list(range(2)) # Fixed list of all possible actions. You should only edit the length
self.observation_shape = (3, 96, 96) # Dimensions of the game observation, must be 3D (channel, height, width). For a 1D array, please reshape it to (1, 1, length of array)
self.action_space = list(range(3)) # Fixed list of all possible actions. You should only edit the length
self.players = list(range(1)) # List of players. You should only edit the length
self.stacked_observations = 0 # Number of previous observations and previous actions to add to the current observation
......@@ -147,8 +147,10 @@ class Game(AbstractGame):
Returns:
The new observation, the reward and a boolean if the game has ended.
"""
print('action', action)
observation, reward, done, _ = self.env.step(action)
return numpy.array([[observation]]), reward, done
observationFormatted = self.formatObservation(observation)
return observationFormatted, reward, done
def legal_actions(self):
"""
......@@ -170,7 +172,9 @@ class Game(AbstractGame):
Returns:
Initial observation of the game.
"""
return numpy.array([[self.env.reset()]])
initialObservation = numpy.array(self.env.reset())
initialObservationFormatted = self.formatObservation(initialObservation)
return initialObservationFormatted
def close(self):
"""
......@@ -200,3 +204,16 @@ class Game(AbstractGame):
1: "Push cart to the right",
}
return f"{action_number}. {actions[action_number]}"
def formatObservation(self, observation):
"""
Formats a game observation to the correct format
Args:
observation: the (96, 96, 3) numpy array of the observation to format
Returns:
The corresponding (3, 96, 96) numpy array of the observation
"""
return numpy.transpose(observation, [2, 0, 1])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment