Commit 87ff899a authored by Hurk, J.R. van den (Jason)

Tune parameters

parent 0a23dcce
@@ -24,7 +24,7 @@ class MuZeroConfig:
# 5 for gas, 5 for brakes; [0.0, 0.25, 0.5, 0.75, 1.0]
self.action_space = list(range(9 + 5 + 5)) # Fixed list of all possible actions. You should only edit the length
self.players = list(range(1)) # List of players. You should only edit the length
-self.stacked_observations = 0 # Number of previous observations and previous actions to add to the current observation
+self.stacked_observations = 5 # Number of previous observations and previous actions to add to the current observation
# Evaluate
self.muzero_player = 0 # Turn Muzero begins to play (0: MuZero plays first, 1: MuZero plays second)
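
Note on the stacked_observations bump from 0 to 5: with this setting, each input to the representation network presumably concatenates the current frame with the five previous frames and the five previous actions encoded as extra planes. A minimal illustrative sketch (shapes and the action-plane encoding are assumptions, not this repository's code):

```python
import numpy as np

def stack_observations(past_obs, past_actions, num_stacked=5):
    """past_obs: list of (C, H, W) arrays, newest last; past_actions: list of ints."""
    obs = list(past_obs[-(num_stacked + 1):])               # current + 5 previous frames
    _, h, w = obs[-1].shape
    action_planes = [np.full((1, h, w), a, dtype=obs[-1].dtype)
                     for a in past_actions[-num_stacked:]]  # 5 previous actions as constant planes
    return np.concatenate(obs + action_planes, axis=0)      # stack along the channel axis
```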
@@ -37,7 +37,7 @@ class MuZeroConfig:
self.selfplay_on_gpu = False
self.max_moves = 500 # Maximum number of moves if game is not finished before
self.num_simulations = 50 # Number of future moves self-simulated
-self.discount = 0.997 # Chronological discount of the reward
+self.discount = 0.999 # Chronological discount of the reward
self.temperature_threshold = None # Number of moves before dropping the temperature given by visit_softmax_temperature_fn to 0 (ie selecting the best action). If None, visit_softmax_temperature_fn is used every time
# Root prior exploration noise
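
On the discount change from 0.997 to 0.999, a quick back-of-the-envelope on the effective planning horizon, roughly 1 / (1 - discount); the higher value presumably keeps rewards near the 500-move episode cap contributing meaningfully:

```python
# Rough effective horizon implied by each discount value.
for gamma in (0.997, 0.999):
    print(gamma, round(1 / (1 - gamma)))  # 0.997 -> ~333 steps, 0.999 -> ~1000 steps
```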
@@ -56,11 +56,11 @@ class MuZeroConfig:
# Residual Network
self.downsample = False # Downsample observations before representation network, False / "CNN" (lighter) / "resnet" (See paper appendix Network Architecture)
-self.blocks = 1 # Number of blocks in the ResNet
-self.channels = 2 # Number of channels in the ResNet
-self.reduced_channels_reward = 2 # Number of channels in reward head
-self.reduced_channels_value = 2 # Number of channels in value head
-self.reduced_channels_policy = 2 # Number of channels in policy head
+self.blocks = 2 # Number of blocks in the ResNet
+self.channels = 16 # Number of channels in the ResNet
+self.reduced_channels_reward = 16 # Number of channels in reward head
+self.reduced_channels_value = 16 # Number of channels in value head
+self.reduced_channels_policy = 16 # Number of channels in policy head
self.resnet_fc_reward_layers = [] # Define the hidden layers in the reward head of the dynamic network
self.resnet_fc_value_layers = [] # Define the hidden layers in the value head of the prediction network
self.resnet_fc_policy_layers = [] # Define the hidden layers in the policy head of the prediction network
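
For a sense of scale on the residual-tower change (1 block of 2 channels to 2 blocks of 16 channels), assuming the usual two 3x3 convolutions per residual block, which is an assumption about the architecture and not something visible in this diff:

```python
# Back-of-the-envelope conv weight count: blocks * 2 convs * (3*3 * channels^2)
def conv_weights(blocks, channels):
    return blocks * 2 * (3 * 3 * channels * channels)

print(conv_weights(1, 2))   # old tower:   ~72 conv weights
print(conv_weights(2, 16))  # new tower: ~9216 conv weights
```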
@@ -68,20 +68,20 @@ class MuZeroConfig:
# Fully Connected Network
self.encoding_size = 8
self.fc_representation_layers = [] # Define the hidden layers in the representation network
-self.fc_dynamics_layers = [16] # Define the hidden layers in the dynamics network
-self.fc_reward_layers = [16] # Define the hidden layers in the reward network
-self.fc_value_layers = [16] # Define the hidden layers in the value network
-self.fc_policy_layers = [16] # Define the hidden layers in the policy network
+self.fc_dynamics_layers = [64] # Define the hidden layers in the dynamics network
+self.fc_reward_layers = [64] # Define the hidden layers in the reward network
+self.fc_value_layers = [64] # Define the hidden layers in the value network
+self.fc_policy_layers = [64] # Define the hidden layers in the policy network
### Training
self.results_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../results", os.path.basename(__file__)[:-3], datetime.datetime.now().strftime("%Y-%m-%d--%H-%M-%S")) # Path to store the model weights and TensorBoard logs
self.save_model = True # Save the checkpoint in results_path as model.checkpoint
-self.training_steps = 2000 # Total number of training steps (ie weights update according to a batch)
+self.training_steps = 100000 # Total number of training steps (ie weights update according to a batch)
self.batch_size = 128 # Number of parts of games to train on at each training step
self.checkpoint_interval = 10 # Number of training steps before using the model for self-playing
-self.value_loss_weight = 1 # Scale the value loss to avoid overfitting of the value function, paper recommends 0.25 (See paper appendix Reanalyze)
+self.value_loss_weight = 0.25 # Scale the value loss to avoid overfitting of the value function, paper recommends 0.25 (See paper appendix Reanalyze)
self.train_on_gpu = torch.cuda.is_available() # Train on GPU if available
self.optimizer = "Adam" # "Adam" or "SGD". Paper uses SGD
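
On value_loss_weight dropping to the paper's recommended 0.25: the weight typically scales only the value term of the combined loss, roughly as sketched below (illustrative, not this repository's training code):

```python
def muzero_loss(value_loss, reward_loss, policy_loss, value_loss_weight=0.25):
    # Down-weighting the value term is the paper's remedy for the value head
    # overfitting and dominating the reward and policy losses.
    return value_loss_weight * value_loss + reward_loss + policy_loss
```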
@@ -96,7 +96,8 @@ class MuZeroConfig:
### Replay Buffer
-self.replay_buffer_size = 500 # Number of self-play games to keep in the replay buffer
+self.window_size = 1
+self.replay_buffer_size = 2000 # Number of self-play games to keep in the replay buffer
self.num_unroll_steps = 10 # Number of game moves to keep for every batch element
self.td_steps = 50 # Number of steps in the future to take into account for calculating the target value
self.PER = True # Prioritized Replay (See paper appendix Training), select in priority the elements in the replay buffer which are unexpected for the network
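
Since PER stays enabled while the buffer grows to 2000 games, here is a minimal sketch of how prioritized sampling usually works; the alpha exponent and the priority definition are assumptions, not taken from this commit:

```python
import numpy as np

def sample_prioritized(priorities, batch_size, alpha=1.0):
    """Sample buffer indices with probability proportional to priority ** alpha."""
    p = np.asarray(priorities, dtype=np.float64) ** alpha
    p /= p.sum()
    return np.random.choice(len(priorities), size=batch_size, p=p)
```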
@@ -111,7 +112,7 @@ class MuZeroConfig:
### Adjust the self play / training ratio to avoid over/underfitting
self.self_play_delay = 0 # Number of seconds to wait after each played game
self.training_delay = 0 # Number of seconds to wait after each training step
-self.ratio = 1.5 # Desired training steps per self played step ratio. Equivalent to a synchronous version, training can take much longer. Set it to None to disable it
+self.ratio = None # Desired training steps per self played step ratio. Equivalent to a synchronous version, training can take much longer. Set it to None to disable it
def visit_softmax_temperature_fn(self, trained_steps):
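
Setting ratio = None turns off the self-play throttle entirely; when it is a number, the check is roughly the one sketched here (names are illustrative, not the repository's own):

```python
def should_pause_self_play(trained_steps, played_steps, ratio):
    # With ratio = None, self-play never waits for the trainer.
    if ratio is None or played_steps == 0:
        return False
    # Otherwise pause self-play until the trainer reaches `ratio` steps per self-played step.
    return trained_steps / played_steps < ratio
```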
@@ -191,7 +192,6 @@ class Game(AbstractGame):
Display the game observation.
"""
self.env.render()
input("Press enter to take a step ")
def action_to_string(self, action_number):
"""