# mst_ts_stock.py
# MST clustering of the Maotai (600519.SH) log-return time series.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.spatial.distance import pdist, squareform
import networkx as nx
# Load the stock data and keep only MAOTAI (ts_code 600519.SH).
# NOTE(review): the rows are assumed to already be in chronological order;
# confirm against the CSV, since diff() below depends on row order.
df = pd.read_csv('stockdata_adj.csv')
df = df[df.ts_code == '600519.SH']

# Daily log returns. diff() leaves a NaN in the first row and dropna()
# removes it, so logret is shorter than df.
logret = np.log(df['close']).diff().dropna()

# Take the dates from exactly the rows that survived dropna(), so timeline
# and logret have the same length and line up point for point.
# NOTE(review): if trade_date is stored as an integer such as 20200102,
# pd.to_datetime reads it as nanoseconds since the epoch -- confirm the CSV
# format (format='%Y%m%d' may be required).
timeline = pd.to_datetime(df.loc[logret.index, 'trade_date'])

# pdist expects one observation per row, hence the (N, 1) reshape.
logret_2d = logret.values.reshape(-1, 1)

# Dense N x N matrix of pairwise Euclidean distances between returns
# (for 1-D observations this is the absolute difference).
DM = squareform(pdist(logret_2d))

# MST clustering: build the minimum spanning tree of the distance graph and
# cut it at its KK - 1 heaviest edges.
KK = 20
# nx.Graph(DM) reads DM as a weighted adjacency matrix whose nodes are the
# 0-based positions 0..N-1; zero entries (the diagonal and pairs of identical
# returns) create no edge.
Tree1 = nx.minimum_spanning_tree(nx.Graph(DM), weight='weight')
T1 = pd.DataFrame(Tree1.edges(data='weight'),
                  columns=['Source', 'Target', 'Weight'])
T1 = T1.sort_values('Weight')

# tail(KK - 1) selects the heaviest edges and, unlike iloc[-KK + 1:], yields
# an empty frame (not every edge) when KK == 1.
heaviest = T1.tail(KK - 1)
n_points = len(logret)

# Segment boundaries along the time axis. The cut nodes come out in weight
# order, so they are sorted and de-duplicated (np.unique) before being used
# as slice bounds; boundaries at 0 or beyond the series would only produce
# empty segments and are dropped.
cuts = np.unique(heaviest['Source'].to_numpy(dtype=int))
cuts = cuts[(cuts > 0) & (cuts < n_points)]

# Segment i covers positions NodesStart[i]:NodesEnd[i]. Both arrays are built
# from the same cut list, so they always have equal length; the first segment
# starts at position 0 and the last one ends at the end of the series.
NodesStart = np.concatenate([[0], cuts])
NodesEnd = np.concatenate([cuts, [n_points]])
Kset_MST = NodesStart.copy()

# Color palette: 20 samples of the 'hsv' colormap with a hand-picked
# reordering of some rows (kept exactly as in the original script).
cmap = plt.get_cmap('hsv')(np.linspace(0, 1, 20))
cmap[[1, 2, 3, 4, 17, 4, 12], :] = cmap[[4, 17, 12, 3, 4, 1, 2], :]

# Plot MST clustering: one line per time segment, each in its own color.
plt.subplot(2, 1, 2)
# Plain arrays make the slicing below unambiguously positional.
dates = timeline.to_numpy()
returns = logret.to_numpy()
for i, (start, end) in enumerate(zip(NodesStart, NodesEnd)):
    # The modulo keeps the color lookup valid if KK is raised above the
    # palette size.
    plt.plot(dates[start:end], returns[start:end],
             color=cmap[i % len(cmap)])
plt.xlabel('Time')
plt.ylabel('Log Return')
plt.title('MST Clustering (Maotai)')
plt.tight_layout()
plt.savefig('mst_ts_stock.png', transparent=True, dpi=400)
plt.show()