我必須使用python numpy庫實現隨機梯度下降。爲了這個目的,我給下面的函數定義:Python的numpy隨機梯度下降實現
def compute_stoch_gradient(y, tx, w):
"""Compute a stochastic gradient for batch data."""
def stochastic_gradient_descent(
y, tx, initial_w, batch_size, max_epochs, gamma):
"""Stochastic gradient descent algorithm."""
我也給予以下幫助功能:
def batch_iter(y, tx, batch_size, num_batches=1, shuffle=True):
"""
Generate a minibatch iterator for a dataset.
Takes as input two iterables (here the output desired values 'y' and the input data 'tx')
Outputs an iterator which gives mini-batches of `batch_size` matching elements from `y` and `tx`.
Data can be randomly shuffled to avoid ordering in the original data messing with the randomness of the minibatches.
Example of use :
for minibatch_y, minibatch_tx in batch_iter(y, tx, 32):
<DO-SOMETHING>
"""
data_size = len(y)
if shuffle:
shuffle_indices = np.random.permutation(np.arange(data_size))
shuffled_y = y[shuffle_indices]
shuffled_tx = tx[shuffle_indices]
else:
shuffled_y = y
shuffled_tx = tx
for batch_num in range(num_batches):
start_index = batch_num * batch_size
end_index = min((batch_num + 1) * batch_size, data_size)
if start_index != end_index:
yield shuffled_y[start_index:end_index], shuffled_tx[start_index:end_index]
我實現了以下兩個功能:
def compute_stoch_gradient(y, tx, w):
"""Compute a stochastic gradient for batch data."""
e = y - tx.dot(w)
return (-1/y.shape[0])*tx.transpose().dot(e)
def stochastic_gradient_descent(y, tx, initial_w, batch_size, max_epochs, gamma):
"""Stochastic gradient descent algorithm."""
ws = [initial_w]
losses = []
w = initial_w
for n_iter in range(max_epochs):
for minibatch_y,minibatch_x in batch_iter(y,tx,batch_size):
w = ws[n_iter] - gamma * compute_stoch_gradient(minibatch_y,minibatch_x,ws[n_iter])
ws.append(np.copy(w))
loss = y - tx.dot(w)
losses.append(loss)
return losses, ws
我不確定迭代應該在範圍內(max_epochs)還是在更大範圍內完成。我這樣說是因爲我讀到一個時代是「每次我們貫穿整個數據集」。所以我認爲一個時代包含更多的迭代......
對於第二個問題:讀了* *批**,**小批**和**時代**關於sgd。 – sascha
您在內部循環中調用'batch_iter',每次調用時都會實例化一個新的生成器對象。相反,你想要在循環外實例化一個單獨的生成器,然後迭代它, '對於minibatch_y,minibatch_x在batch_iter(...)'中。 –