-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathsimulation.py
269 lines (216 loc) · 9.8 KB
/
simulation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
"""This is the module for the primary simulation objects. Simulation is
for single simulations (with visualization) and BatchSimulation is for
experiments with multiple runs.
"""
import logging
import multiprocessing
from multiprocessing.pool import Pool
import sys
from copy import deepcopy
from threading import Thread
import numpy as np
import tqdm
import pygame # pylint: disable=E0401
import pygame.locals as pgl # pylint: disable=E0401
from aggregate_visualization import AggregatePlot
from custom_disease_model import DistemperModel, Kennels
from interventions import * # pylint: disable=W0401,W0614
class Simulation(object):
    '''This is the primary simulation class.

    It is responsible for both computation and rendering: it owns the kennel
    layout (Kennels), the disease model (DistemperModel), an optional pygame
    window (spatial visualization) and an optional AggregatePlot (aggregate
    visualization).
    '''

    def __init__(self, params,
                 spatial_visualization=True,
                 aggregate_visualization=True,
                 return_on_equillibrium=False):
        '''Create a simulation.

        Arguments:
            params {dict} -- simulation parameters. The dict is modified in
                place: string values of 'infection_kernel_function' and
                'intervention' are evaluated, and missing/None values are
                replaced with defaults (a kernel returning 0.0, and None).

        Keyword Arguments:
            spatial_visualization {bool} -- open a pygame window and draw the
                kennels every step (default: {True})
            aggregate_visualization {bool} -- attach an AggregatePlot that is
                updated every step (default: {True})
            return_on_equillibrium {bool} -- make run() stop once the disease
                model is in equilibrium or meets its end conditions
                (default: {False})
        '''
        self.return_on_equillibrium = return_on_equillibrium
        self.spatial_visualization = spatial_visualization
        self.aggregate_visualization = aggregate_visualization
        if not (self.spatial_visualization or
                self.aggregate_visualization or
                self.return_on_equillibrium):
            logging.warning('Warning: No visualizations were set, it is '
                            'highly recommended you set return_on_equillibrium '
                            'to True otherwise you will have to manually manage '
                            'the simulation state.')
        self.params = params
        # Parameters may arrive as Python source strings and are evaluated here
        # (with this module's imports, including interventions, in scope).
        # SECURITY: eval() runs arbitrary code -- only load trusted parameters.
        # A missing or None entry falls back to a default so callers may omit it.
        if isinstance(self.params.get('infection_kernel_function'), str):
            self.params['infection_kernel_function'] = \
                eval(self.params['infection_kernel_function'])  # pylint: disable=W0123
        elif self.params.get('infection_kernel_function') is None:
            # Default kernel: always returns 0.0.
            self.params['infection_kernel_function'] = lambda node, k: 0.0
        if isinstance(self.params.get('intervention'), str):
            self.params['intervention'] = \
                eval(self.params['intervention'])  # pylint: disable=W0123
        elif self.params.get('intervention') is None:
            self.params['intervention'] = None
        self.kennels = Kennels()
        self.disease = DistemperModel(self.kennels.get_graph(), self.params)
        # Callables invoked after every successful disease-model step.
        self.update_hooks = []
        # Loop state used by run()/stop(); defined for every configuration.
        self.running = False
        self.async_thread = None
        if spatial_visualization:
            self.fps = 0  # Clock.tick(0) applies no frame-rate cap
            self.screen_width, self.screen_height = 640, 480
            pygame.init()  # pylint: disable=E1101
            self.fps_clock = pygame.time.Clock()
            self.screen = pygame.display.set_mode((self.screen_width, self.screen_height), 0, 32)
            self.surface = pygame.Surface(self.screen.get_size())  # pylint: disable=E1121
            self.surface = self.surface.convert()
            self.surface.fill((255, 255, 255))
            # NOTE(review): self.clock is not used in this class; _redraw ticks fps_clock.
            self.clock = pygame.time.Clock()
            pygame.key.set_repeat(1, 40)
            self.screen.blit(self.surface, (0, 0))
            self.font = pygame.font.Font(None, 36)
        if aggregate_visualization:
            self.plt = AggregatePlot(self.disease, self.kennels)
            self.update_hooks.append(self.plt.update)

    @staticmethod
    def copy(simulation):
        '''Performs a deep copy of the simulation.

        The kennels are deep-copied and a new DistemperModel is built on the
        copied kennel graph. The params dict is shared with the original, not
        copied. NOTE(review): any state DistemperModel keeps outside the kennel
        graph (e.g. its running totals) is not carried over -- confirm intended.

        Arguments:
            simulation {Simulation} -- the simulation to copy

        Returns:
            Simulation -- the copy of the simulation
        '''
        new_simulation = Simulation(simulation.params,
                                    simulation.spatial_visualization,
                                    simulation.aggregate_visualization,
                                    simulation.return_on_equillibrium)
        new_simulation.kennels = deepcopy(simulation.kennels)
        new_simulation.disease = DistemperModel(new_simulation.kennels.get_graph(),
                                                new_simulation.params)
        if simulation.aggregate_visualization:
            # The constructor already registered a plot hook bound to the model
            # that was just replaced; replace the hook list rather than append,
            # otherwise every update would also refresh that stale plot.
            new_simulation.plt = AggregatePlot(new_simulation.disease, new_simulation.kennels)
            new_simulation.update_hooks = [new_simulation.plt.update]
        return new_simulation

    def _check_events(self):
        '''Handle pending pygame events; quit on window close or Escape.

        sys.exit raises SystemExit, which (when run() is asynchronous) only
        ends the simulation thread.
        '''
        for event in pygame.event.get():
            if event.type == pgl.QUIT:  # pylint: disable=E1101
                pygame.quit()  # pylint: disable=E1101
                sys.exit(0)
            elif event.type == pgl.KEYDOWN:  # pylint: disable=E1101
                if event.key == pgl.K_ESCAPE:  # pylint: disable=E1101
                    pygame.quit()  # pylint: disable=E1101
                    sys.exit(0)

    def _redraw(self):
        '''Blit the off-screen surface to the window and tick the frame clock.'''
        self.screen.blit(self.surface, (0, 0))
        pygame.display.flip()
        pygame.display.update()
        self.fps_clock.tick(self.fps)

    def _draw_ui(self):
        '''Render the elapsed simulation time (disease.time is in hours).'''
        text = self.font.render('{0} days, {1} hours'.format(int(np.floor(self.disease.time/24.0)),
                                                             self.disease.time%24), 1, (10, 10, 10))
        textpos = text.get_rect()
        textpos.centerx = 200
        self.surface.blit(text, textpos)

    def _get_disease_state(self):
        '''Return {state id: number of members} for every state in disease.id_map.'''
        return {sc: len(self.disease.get_state_node(sc)['members']) for sc in self.disease.id_map}

    def _get_disease_stats(self):
        '''Return the disease model's aggregate counters as a dict.'''
        return {'E': self.disease.total_intake,
                'S': 0,
                'IS': self.disease.total_discharged,
                'I': self.disease.total_infected,
                'SY': 0,
                'D': self.disease.total_died,
                'E2I': self.disease.E2I,
                'sum_S2D_IS2D': self.disease.sum_S2D_IS2D,
                'E2S': self.disease.E2S,
                'E2IS': self.disease.E2IS,
                'S2I': self.disease.S2I
               }

    @staticmethod
    def look_ahead(simulation, n=1, samples=1):
        '''Sample possible futures of a simulation without modifying it.

        Snapshots the simulation with copy(), then `samples` times copies the
        snapshot and calls update() on the copy n times.

        Arguments:
            simulation {Simulation} -- the simulation to look ahead from

        Keyword Arguments:
            n {int} -- the number of steps to look ahead (default: {1})
            samples {int} -- the number of independent look-aheads (default: {1})

        Returns:
            list(list) -- [total_intake, total_infected] of each sample after
            n steps from the current state
        '''
        simulation_copy = Simulation.copy(simulation)
        results = []
        for _ in range(0, samples):
            sim = Simulation.copy(simulation_copy)
            for _ in range(0, n):
                sim.update()
            results.append([sim.disease.total_intake, sim.disease.total_infected])
        return results

    def update(self):
        '''Advance the simulation by one step and redraw.

        While the disease model is neither in equilibrium nor at its end
        conditions, the intervention (if any) is applied, the model is stepped
        and every update hook is called. Otherwise, when return_on_equillibrium
        is set, running is cleared and the method returns without drawing.
        '''
        if self.spatial_visualization:
            self._check_events()
            self.surface.fill(self.kennels.background_color)
        if not self.disease.in_equilibrium() and not self.disease.end_conditions():
            intervention = self.params.get('intervention')
            if intervention is not None:
                intervention.update(simulation=self)
            self.disease.update()
            for hook in self.update_hooks:
                hook()
        elif self.return_on_equillibrium:
            self.running = False
            return
        if self.spatial_visualization:
            self.kennels.draw(self.surface, self.disease)
            self._draw_ui()
            self._redraw()

    def stop(self):
        '''Stop the simulation loop started by run().'''
        self.running = False

    def run(self, asynchronous=False):
        '''Run the simulation loop (calling update()) while self.running is True.

        The loop ends when stop() is called or, with return_on_equillibrium,
        when update() detects equilibrium/end conditions.

        Keyword Arguments:
            asynchronous {bool} -- if True, run the loop in a new thread stored
                in self.async_thread and return immediately (default: {False})

        Returns:
            dict -- the counters from _get_disease_stats(); when asynchronous is
            True they are read right after the thread starts, not at the end
        '''
        self.running = True
        if asynchronous:
            self.async_thread = Thread(target=self.run, args=(False,))
            self.async_thread.start()
        else:
            while self.running:
                self.update()
        return self._get_disease_stats()
class BatchSimulation(object):
    '''Run many independent, non-visual simulations in a multiprocessing pool.'''

    def __init__(self, params, runs, pool_size=-1):
        '''Configure a batch of simulation runs.

        Arguments:
            params {dict} -- parameter dictionary; every run receives its own deep copy
            runs {int} -- the number of simulations to run

        Keyword Arguments:
            pool_size {int|None} -- number of worker processes: None means 1,
                a value <= 0 means multiprocessing.cpu_count(), any positive
                value is used as given (default: {-1})
        '''
        self.params = params
        self.runs = runs
        if pool_size is None:
            self.pool_size = 1
        elif pool_size <= 0:
            self.pool_size = multiprocessing.cpu_count()
        else:
            # Previously this case left pool_size unset and run() crashed.
            self.pool_size = pool_size

    def run(self):
        '''Run all simulations in parallel worker processes, showing a tqdm progress bar.

        Results are collected with imap_unordered, so they arrive in completion
        order rather than submission order.

        Returns:
            list(dict) -- the stats dict returned by run_simulation for each run
        '''
        with Pool(self.pool_size) as process_pool:
            tasks = process_pool.imap_unordered(
                BatchSimulation.run_simulation,
                [deepcopy(self.params) for _ in range(0, self.runs)])
            results = list(tqdm.tqdm(tasks, total=self.runs))
            # Pool.__exit__ calls terminate(); close and join first so the
            # workers shut down cleanly.
            process_pool.close()
            process_pool.join()
        return results

    @staticmethod
    def run_simulation(params):
        '''Run one headless simulation until equilibrium/end conditions.

        Arguments:
            params {dict} -- the parameter dictionary for the simulation

        Returns:
            dict -- the simulation's final stats (see Simulation.run)
        '''
        return Simulation(params,
                          spatial_visualization=False,
                          aggregate_visualization=False,
                          return_on_equillibrium=True).run()