# Observation, action and player settings (fragment; the enclosing definition is not visible here).
self.observation_shape=(3,96,96)# Dimensions of the game observation; must be 3D (channel, height, width). For a 1D array, reshape it to (1, 1, length of array)
self.observation_shape=(1,12,12)# Dimensions of the game observation; must be 3D (channel, height, width). For a 1D array, reshape it to (1, 1, length of array). NOTE(review): this reassigns the attribute set on the previous line, so (1, 12, 12) wins if both lines run; looks like the old/new sides of a diff, confirm which one is intended
self.action_space=list(range(7))# Fixed list of all possible actions (here 0..6). Only the length should be edited
self.players=list(range(1))# List of players (here [0]). Only the length should be edited
self.stacked_observations=20# Number of previous observations and previous actions to add to the current observation
@@ -33,7 +35,7 @@ class MuZeroConfig:
# Self-play settings (fragment; the enclosing definition is not visible here).
self.num_workers=1# Number of simultaneous threads/workers self-playing to feed the replay buffer
self.max_moves=1000# Maximum number of moves if the game has not finished before then
self.num_simulations=100# Number of future moves self-simulated
self.num_simulations=50# Number of future moves self-simulated. NOTE(review): this reassigns the attribute set on the previous line, so 50 wins if both lines run. The original comment also carried the trailing text "Chronological discount of the reward", which presumably documents a separate discount setting whose assignment is not visible here; confirm and restore that line
self.temperature_threshold=None# Number of moves before dropping the temperature given by visit_softmax_temperature_fn to 0 (i.e. selecting the best action). If None, visit_softmax_temperature_fn is used every time
@@ -77,7 +79,7 @@ class MuZeroConfig:
# Training settings (fragment; the enclosing definition is not visible here).
self.save_model=True# Save the checkpoint in results_path as model.checkpoint
self.training_steps=100000# Total number of training steps (i.e. weight updates according to a batch)
self.batch_size=128# Number of parts of games to train on at each training step
self.checkpoint_interval=100# Number of training steps before using the model for self-playing
self.checkpoint_interval=10# Number of training steps before using the model for self-playing. NOTE(review): this reassigns the attribute set on the previous line, so 10 wins if both lines run; looks like the old/new sides of a diff, confirm which one is intended
self.value_loss_weight=0.25# Scale the value loss to avoid overfitting of the value function; the paper recommends 0.25 (see paper appendix, Reanalyze)
self.train_on_gpu=torch.cuda.is_available()# Train on GPU if available
@@ -134,7 +136,7 @@ class Game(AbstractGame):
@@ -212,10 +214,23 @@ class Game(AbstractGame):
observation: the (96, 96, 3) numpy array of the observation to format
The corresponding (3, 96, 96) numpy array of the observation
The corresponding (1, 96, 96) numpy array of the observation