2015-07-10 78 views

I want to add a "sphere" to my data clusters

My clusters currently look like this, without any "sphere".

enter image description here

Here is my code:

import numpy as np 
import matplotlib.pyplot as plt 
from matplotlib import style 
style.use('ggplot') 
import pandas as pd 
from sklearn.cluster import KMeans 

MY_FILE='total_watt.csv' 
date = [] 
consumption = [] 

df = pd.read_csv(MY_FILE, parse_dates=[0], index_col=[0]) 
# daily totals; min_count=1 keeps days without readings as NaN so dropna() removes them
df = df.resample('1D').sum(min_count=1) 
df = df.dropna() 

date = df.index.tolist() 
date = [x.strftime('%Y-%m-%d') for x in date] 
from sklearn.preprocessing import LabelEncoder 

encoder = LabelEncoder() 
date_numeric = encoder.fit_transform(date) 
consumption = df[df.columns[0]].values 

X = np.array([date_numeric, consumption]).T 

kmeans = KMeans(n_clusters=3) 
kmeans.fit(X) 

centroids = kmeans.cluster_centers_ 
labels = kmeans.labels_ 

print(centroids) 
print(labels) 

fig, ax = plt.subplots(figsize=(10,8)) 
rect = fig.patch 
rect.set_facecolor('#2D2B2B') 



colors = ["b.","r.","g."] 

for i in range(len(X)): 
    # inverse_transform expects a 1-D array-like, so wrap the single label in a list
    print("coordinate:", encoder.inverse_transform([int(X[i, 0])])[0], X[i, 1], "label:", labels[i]) 
    ax.plot(X[i][0], X[i][1], colors[labels[i]], markersize = 10) 
ax.scatter(centroids[:, 0],centroids[:, 1], marker = "x", s=150, linewidths = 5, zorder = 10) 
a = np.arange(0, len(X), 5) 
ax.set_xticks(a) 
ax.set_xticklabels(encoder.inverse_transform(a.astype(int))) 
ax.tick_params(axis='x', colors='lightseagreen') 
ax.tick_params(axis='y', colors='lightseagreen') 
plt.scatter(centroids[:, 0],centroids[:, 1], marker = "x", s=100, c="black", linewidths = 5, zorder = 10) 
ax.set_title('Energy consumptions Clusters (high/medium/low)', color='gold') 
ax.set_xlabel('date (year 2011)', color='gold') 
ax.set_ylabel('consumption', color='gold') 


plt.show() 

By "sphere" I mean the shaded region drawn around each cluster of points, as shown in this figure.

enter image description here

I tried to Google it.

But when I search for "matplotlib sphere", I can't get any useful results.

Answer

1

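The example plot in your post looks like the result of a Gaussian mixture model (GMM), where each "sphere" is a 2-D Gaussian density.

I will write some sample code right away to demonstrate how to use a GMM on your dataset and produce this kind of plot.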

import numpy as np 
import matplotlib.pyplot as plt 
from matplotlib import style 
from matplotlib.patches import Ellipse 
style.use('ggplot') 
import pandas as pd 
# code changes here 
# =========================================== 
# GaussianMixture is the replacement for the GMM class, 
# which was deprecated and later removed from scikit-learn 
from sklearn.mixture import GaussianMixture 
# =========================================== 
from sklearn.preprocessing import LabelEncoder 

# replace it with your file path 
MY_FILE='/home/Jian/Downloads/total_watt.csv' 

df = pd.read_csv(MY_FILE, parse_dates=[0], index_col=[0]) 
# daily totals; min_count=1 keeps days without readings as NaN so dropna() removes them 
df = df.resample('1D').sum(min_count=1) 
df = df.dropna() 

date = df.index.tolist() 
date = [x.strftime('%Y-%m-%d') for x in date] 

encoder = LabelEncoder() 
date_numeric = encoder.fit_transform(date) 
consumption = df[df.columns[0]].values 

X = np.array([date_numeric, consumption]).T 


# code changes here 
# =========================================== 
# covariance_type='diag' matches the default of the old GMM class 
gmm = GaussianMixture(n_components=3, covariance_type='diag', random_state=0) 
gmm.fit(X) 
y_pred = gmm.predict(X) 

# the cluster centers are given by the means 
print(gmm.means_) 

# =========================================== 

fig, ax = plt.subplots(figsize=(10,8)) 

for i, color in enumerate('rgb'): 
    # sphere background: axis-aligned ellipse spanning +/-1.96 standard deviations per axis 
    # (with covariance_type='diag', covariances_[i] holds the per-feature variances) 
    width, height = 2 * 1.96 * np.sqrt(gmm.covariances_[i]) 
    ell = Ellipse(gmm.means_[i], width, height, color=color) 
    ell.set_alpha(0.1) 
    ax.add_artist(ell) 
    # data points 
    X_data = X[y_pred == i] 
    ax.scatter(X_data[:,0], X_data[:,1], color=color) 
    # center 
    ax.scatter(gmm.means_[i][0], gmm.means_[i][1], marker='x', s=100, c=color) 


ax.set_title('Energy consumptions Clusters (high/medium/low)', color='gold') 
ax.set_xlabel('date (year 2011)', color='gold') 
ax.set_ylabel('consumption', color='gold') 
a = np.arange(0, len(X), 5) 
ax.set_xticks(a) 
ax.set_xticklabels(encoder.inverse_transform(a.astype(int))) 
ax.tick_params(axis='x', colors='lightseagreen') 
ax.tick_params(axis='y', colors='lightseagreen') 

plt.show() 

enter image description here
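
The ellipses in the code above are axis-aligned because they use only the per-axis variances of each component. If the mixture is instead fit with covariance_type='full', each component has a full 2x2 covariance matrix and the "sphere" can be tilted along the direction in which the data actually spreads. The following is a minimal sketch of that idea, not part of the original answer: `covariance_ellipse` and `coverage_2d` are hypothetical helper names of my own, not scikit-learn or matplotlib functions, and the sketch assumes 2-D data.

```python
import numpy as np
from matplotlib.patches import Ellipse


def covariance_ellipse(mean, cov, n_std=1.96, **kwargs):
    """Build an Ellipse patch spanning n_std standard deviations of a 2-D Gaussian.

    Hypothetical helper: `mean` has shape (2,), `cov` is a symmetric 2x2 matrix.
    Extra keyword arguments (color, alpha, ...) are passed on to Ellipse.
    """
    cov = np.asarray(cov, dtype=float)
    # eigh returns eigenvalues in ascending order; they are the variances
    # along the principal axes, and the eigenvectors give the axis directions.
    eigvals, eigvecs = np.linalg.eigh(cov)
    major = eigvecs[:, -1]  # direction of largest variance
    angle = np.degrees(np.arctan2(major[1], major[0]))
    width = 2 * n_std * np.sqrt(eigvals[-1])   # full length along the major axis
    height = 2 * n_std * np.sqrt(eigvals[0])   # full length along the minor axis
    return Ellipse(xy=mean, width=width, height=height, angle=angle, **kwargs)


def coverage_2d(n_std):
    """Probability mass of a 2-D Gaussian that falls inside its n_std ellipse."""
    return 1.0 - np.exp(-0.5 * n_std ** 2)
```

Inside the plotting loop this would be used as `ax.add_patch(covariance_ellipse(gmm.means_[i], gmm.covariances_[i], color=color, alpha=0.1))`, assuming `gmm` was fit with covariance_type='full'. Note that 1.96 standard deviations gives a 95% interval along each axis separately, but the 2-D ellipse itself holds only about 85% of the density (`coverage_2d(1.96)`); a factor of roughly 2.45 is needed for the ellipse to hold 95%.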

Jianxun!!!! Thank you so much! I'll be waiting!!!

@SuzukiSoma I've just updated my post. Please take a look. :-)

How are you so smart.. Thank you very much! How did you learn all this?
