2012-10-10 54 views
2

我寫了Pybrain神經網絡的這個簡單測試,但它並沒有像我期望的那樣工作。這個想法是訓練一個數字集到4095的數據集,其中包含素數和非素數。神經網絡對不同的激活報告相同的響應

#!/usr/bin/env python 
# A simple feedforward neural network that attempts to learn Primes 

from pybrain.datasets import ClassificationDataSet 
from pybrain.tools.shortcuts import buildNetwork 
from pybrain.supervised import BackpropTrainer 

class PrimesDataSet(ClassificationDataSet): 
    """ A dataset for primes """ 

    def generatePrimes(self, n): 
     if n == 2: 
      return [2] 
     elif n < 2: 
      return [] 
     s = range(3, n + 1, 2) 
     mroot = n ** 0.5 
     half = (n + 1)/2 - 1 
     i = 0 
     m = 3 
     while m <= mroot: 
      if s[i]: 
       j = (m * m - 3)/2 
       s[j] = 0 
       while j < half: 
        s[j] = 0 
        j += m 
      i = i + 1 
      m = 2 * i + 3 
     return [2] + [x for x in s if x] 

    def binaryString(self, n): 
     return "{0:12b}".format(n) 

    def __init__(self): 
     ClassificationDataSet.__init__(self, 12, 1) 
     primes = self.generatePrimes(4095) 
     for prime in primes: 
      b = self.binaryString(prime).split() 
      self.addSample(b, [1]) 
     for n in range(4095): 
      if n not in primes: 
       b = self.binaryString(n).split() 
       self.addSample(b, [0]) 

def testTraining(): 
    d = PrimesDataSet() 
    d._convertToOneOfMany() 
    n = buildNetwork(d.indim, 12, d.outdim, recurrent=True) 
    t = BackpropTrainer(n, learningrate = 0.01, momentum = 0.99, verbose = True) 
    t.trainOnDataset(d, 100) 
    t.testOnData(verbose=True) 
    print "Is 7 prime? ", n.activate(d.binaryString(7).split()) 
    print "Is 6 prime? ", n.activate(d.binaryString(6).split()) 
    print "Is 100 prime? ", n.activate(d.binaryString(100).split()) 


if __name__ == '__main__': 
    testTraining() 

直索(請)這是否甚至有可能的問題,我的問題是,7,6,和100的最後三個報表打印測試是否是質都返回相同的:

Is 7 prime? [ 0.34435841 0.65564159] 
Is 6 prime? [ 0.34435841 0.65564159] 
Is 100 prime? [ 0.34435841 0.65564159] 

(或類似的東西) 我解釋這些結果的方式是神經網絡以65%的確定性預測這些數字中的每一個是是質數。我的神經網絡學會了如何處理所有的輸入,或者我做錯了什麼?

回答

0

看起來你實際上只使用一個輸入。

d.binaryString(7).split() 

相當於

"{0:12b}".format(7).split() 

計算結果爲

['111']. 

我想你打算是像

[int(c) for c in "{0:012b}".format(7)] 

其結果是

[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1] 

P.S.檢查你輸入統計模型到底是什麼總是個好主意:)

+0

謝謝!最簡單的事情 - 我甚至沒想過要檢查。 – lambda