2016-09-24 34 views
1

我無法理解二維中的梯度下降。說我有功能f(x,y)=x**2-xy其中df/dx = 2x-ydf/dy = -xPython中兩個維度的漸變下降

因此,對於點df(2,3),輸出向量是[1,-2] .T。矢量[1,-2]所指向的任何位置都在最陡上升(f(x,y)的輸出)的方向上。 我應該選擇一個固定的步長,並找到方向,這樣一個步驟的大小增加f(x,y)最多。如果我想下降,我想找到最快增加-f(x,y)的方向?

如果我的直覺是對的,你會如何編碼?假設我從點(x = 0,y = 5)開始,我想執行梯度下降以找到最小值。

step_size = 0.01 
precision = 0.00001 #stopping point 
enter code here?? 

回答

1

這裏是梯度下降與matplotlib可視化的實現:

import csv 
import math 
def loadCsv(filename): 
    lines = csv.reader(open(filename, "r")) 
    dataset = list(lines) 
    for i in range(len(dataset)): 
     dataset[i] = [float(x) for x in dataset[i]] 
    return dataset 

def h(o1,o2,x): 
    ans=o1+o2*x 
    return ans 

def costf(massiv,p1,p2): 
    sum1=0.0 
    sum2=0.0 
    for x,y in massiv: 
     sum1+=(math.pow(h(o1,o2,x)-y,2)) 
    sum2=(1.0/(2*len(massiv)))*sum1 
    return sum1,sum2 

def gradient(massiv,er,alpha,o1,o2,max_loop=1000): 
    i=0 
    J,e=costf(massiv,o1,o2) 
    conv=False 
    m=len(massiv) 
    while conv!=True: 
     sum1=0.0 
     sum2=0.0 
     for x,y in massiv: 
      sum1+=(o1+o2*x-y) 
      sum2+=(o1+o2*x-y)*x 
     grad0=1.0/m*sum1 
     grad1=1.0/m*sum2 

     temp0=o1-alpha*grad0 
     temp1=o2-alpha*grad1 
     print(temp0,temp1) 
     o1=temp0 
     o2=temp1 
     e=0.0 
     for x,y in massiv: 
      e+=(math.pow(h(o1,o2,x)-y,2)) 
     if abs(J-e)<=ep: 
      print('Successful\n') 
      conv=True 

     J=e 

     i+=1 
     if i>=max_loop: 
      print('Too much\n') 
      break 
    return o1,o2 


#data = massiv 
data=loadCsv('ex1data1.txt') 
o1=0.0 #temp0=0 
o2=1.0 #temp1=1 
alpha=0.01 
ep=0.01 
t0,t1=gradient(data,ep,alpha,o1,o2) 
print('temp0='+str(t0)+' \ntemp1='+str(t1)) 

x=35000 
while x<=70000: 
    y=h(t0,t1,x) 
    print('x='+str(x)+'\ny='+str(y)+'\n') 
    x+=5000 

maxx=data[0][0] 
for q,w in data: 
    maxx=max(maxx,q) 
maxx=round(maxx)+1 
line=[] 
ll=0 
while ll<maxx: 
    line.append(h(t0,t1,ll)) 
    ll+=1 
x=[] 
y=[] 
for q,w in data: 
    x.append(q) 
    y.append(w) 

import matplotlib.pyplot as plt 
plt.plot(x,y,'ro',line) 
plt.ylabel('some numbers') 
plt.show() 

Matplotlib輸出:

enter image description here

ex1data1.txt可以從這裏纔可下載: ex1data1.txt

該代碼可以在Python 3.5的Anaconda發行版中按原樣執行。