繼菲利普Malczak的和Seanny123的建議和意見,我在tensorflow實現的神經網絡來檢查當我們試圖教它來預測(和插值)的2次方會發生什麼。
訓練對連續間隔
我訓練網絡上的區間[-7,7-](以300點此區間內,使之連續的),然後測試了它在區間[ - 30,30]。激活函數是ReLu,網絡有3個隱藏層,每個隱藏層的大小爲50個。時期= 500。結果如下圖所示。因此,基本上,在(和也接近)區間[-7,7]內,該擬合是相當完美的,然後它在外部或多或少呈線性延伸。很高興看到至少在最初時,網絡輸出的斜率試圖「匹配」x^2
的斜率。如果我們提高了測試的時間間隔,這兩個圖表分歧頗多,作爲一個可以在下面的圖中看到:
在偶數培訓
最後,如果不是我訓練網絡在間隔[-100,100]的所有偶數集合中,並將其應用於此間隔中的所有整數集合(偶數和奇數),我得到:
當訓練網絡以生成圖像以上,我把時代增加到2500到獲得更好的準確性。其餘參數保持不變。因此,似乎在訓練間隔內「插入」工作得相當好(可能除了0附近的區域,其中擬合稍差)。
這裏是我使用的第一個數字代碼:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.python.framework.ops import reset_default_graph
#preparing training data
train_x=np.linspace(-7,7,300).reshape(-1,1)
train_y=train_x**2
#setting network features
dimensions=[50,50,50,1]
epochs=500
batch_size=5
reset_default_graph()
X=tf.placeholder(tf.float32, shape=[None,1])
Y=tf.placeholder(tf.float32, shape=[None,1])
weights=[]
biases=[]
n_inputs=1
#initializing variables
for i,n_outputs in enumerate(dimensions):
with tf.variable_scope("layer_{}".format(i)):
w=tf.get_variable(name="W",shape=[n_inputs,n_outputs],initializer=tf.random_normal_initializer(mean=0.0,stddev=0.02,seed=42))
b=tf.get_variable(name="b",initializer=tf.zeros_initializer(shape=[n_outputs]))
weights.append(w)
biases.append(b)
n_inputs=n_outputs
def forward_pass(X,weights,biases):
h=X
for i in range(len(weights)):
h=tf.add(tf.matmul(h,weights[i]),biases[i])
h=tf.nn.relu(h)
return h
output_layer=forward_pass(X,weights,biases)
cost=tf.reduce_mean(tf.squared_difference(output_layer,Y),1)
cost=tf.reduce_sum(cost)
optimizer=tf.train.AdamOptimizer(learning_rate=0.01).minimize(cost)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
#train the network
for i in range(epochs):
idx=np.arange(len(train_x))
np.random.shuffle(idx)
for j in range(len(train_x)//batch_size):
cur_idx=idx[batch_size*j:(batch_size+1)*j]
sess.run(optimizer,feed_dict={X:train_x[cur_idx],Y:train_y[cur_idx]})
#current_cost=sess.run(cost,feed_dict={X:train_x,Y:train_y})
#print(current_cost)
#apply the network on the test data
test_x=np.linspace(-30,30,300)
network_output=sess.run(output_layer,feed_dict={X:test_x.reshape(-1,1)})
plt.plot(test_x,test_x**2,color='r',label='y=x^2')
plt.plot(test_x,network_output,color='b',label='network output')
plt.legend(loc='center')
plt.show()
什麼是文本變量的數據格式? –
你確定神經網絡應該能夠做到這一點嗎?在我的經驗中,神經網絡並不擅長數學推理。 – Seanny123
@ Seanny123我不知道。我也試圖找到這個 - 如果這對神經網絡或AI來說是一個很好的問題。如果不是NN,線性分類器是否合適? – gammay