diff --git a/muzero.py b/muzero.py index 0489db0529c4cd6afa8c7fa79fce7365ddb5d3f4..c07982c15cb089c52e1802b508470e19ddb12c95 100644 --- a/muzero.py +++ b/muzero.py @@ -672,27 +672,14 @@ if __name__ == "__main__": # Parametrization documentation: https://facebookresearch.github.io/nevergrad/parametrization.html muzero.terminate_workers() del muzero - - # This part has custom edits - budget = 100 - parallel_experiments = 5 - num_tests = 20 - - parametrization = nevergrad.p.Dict( - lr_init=nevergrad.p.Log(lower=0.0001, upper=0.1), - discount=nevergrad.p.Log(lower=0.95, upper=0.9999), - stacked_observations=nevergrad.p.Log(lower=1, upper=100).set_integer_casting(), - num_simulations=nevergrad.p.Log(lower=1, upper=200).set_integer_casting(), - batch_size=nevergrad.p.Log(lower=10, upper=200).set_integer_casting(), - checkpoint_interval=nevergrad.p.Log(lower=5, upper=200).set_integer_casting(), - value_loss_weight=nevergrad.p.Log(lower=0.1, upper=0.5), - ) - + budget = 20 + parallel_experiments = 2 + lr_init = nevergrad.p.Log(a_min=0.0001, a_max=0.1) + discount = nevergrad.p.Log(lower=0.95, upper=0.9999) + parametrization = nevergrad.p.Dict(lr_init=lr_init, discount=discount) best_hyperparameters = hyperparameter_search( - game_name, parametrization, budget, parallel_experiments, num_tests + game_name, parametrization, budget, parallel_experiments, 20 ) - # End custom edits - muzero = MuZero(game_name, best_hyperparameters) else: break