-1
我正在嘗試爲函數y=e^(-(x-u)^2)/(2*o^2))
創建一個神經網絡,其中u = 50
和o = 15
。神經網絡,獲取輸出小於1
我必須訓練我的神經網絡,所以我可以找到每個y的2個x。我創建了folling代碼,它似乎很好地學習它,但是一旦我測試輸出結果,我只能得到大約0.99到1的數字,我應該得到25和75,我只是不明白爲什麼。我最好的猜測是我的錯誤更正是錯誤的,但找不到錯誤。神經網絡使用反向傳播。
測試代碼和培訓設置
class Program
{
static void Main(string[] args)
{
args = new string[] {
"c:\\testTrain.csv",
"c:\\testValues.csv"
};
// Output File
string fileTrainPath = null;
string fileValuesPath = null;
if (args.Length > 0)
{
fileTrainPath = args[0];
if (File.Exists(fileTrainPath))
File.Delete(fileTrainPath);
fileValuesPath = args[1];
if (File.Exists(fileValuesPath))
File.Delete(fileValuesPath);
}
double learningRate = 0.1;
double u = 50;
double o = 15;
Random rand = new Random();
Network net = new Network(1, 8, 4, 2);
NetworkTrainer netTrainer = new NetworkTrainer(learningRate, net);
List<TrainerSet> TrainerSets = new List<TrainerSet>();
for(int i = 0; i <= 20; i++)
{
double random = rand.NextDouble();
TrainerSets.Add(new TrainerSet(){
Inputs = new double[] { random },
Outputs = getX(random, u, o)
});
}
// Train Network
string fileTrainValue = String.Empty;
for (int i = 0; i <= 10000; i++)
{
if (i == 5000)
{ }
double error = netTrainer.RunEpoch(TrainerSets);
Console.WriteLine("Epoch " + i + ": Error = " + error);
if(fileTrainPath != null)
fileTrainValue += i + "," + learningRate + "," + error + "\n";
}
if (fileTrainPath != null)
File.WriteAllText(fileTrainPath, fileTrainValue);
// Test Network
string fileValuesValue = String.Empty;
for (int i = 0; i <= 100; i++)
{
double y = rand.NextDouble();
double[] dOutput = getX(y, u, o);
double[] Output = net.Compute(new double[] { y });
if (fileValuesPath != null)
fileValuesValue += i + "," + y + "," + dOutput[0] + "," + dOutput[1] + "," + Output[0] + "," + Output[1] + "\n";
}
if (fileValuesPath != null)
File.WriteAllText(fileValuesPath, fileValuesValue);
}
public static double getResult(int x, double u, double o)
{
return Math.Exp(-Math.Pow(x-u,2)/(2*Math.Pow(o,2)));
}
public static double[] getX(double y, double u, double o)
{
return new double[] {
u + Math.Sqrt(2 * Math.Pow(o, 2) * Math.Log(1/y)),
u - Math.Sqrt(2 * Math.Pow(o, 2) * Math.Log(1/y)),
};
}
}
的代碼在網絡背後
public class Network
{
protected int inputsCount;
protected int layersCount;
protected NetworkLayer[] layers;
protected double[] output;
public int Count
{
get
{
return layers.Count();
}
}
public NetworkLayer this[int index]
{
get { return layers[index]; }
}
public Network(int inputsCount, params int[] neuronsCount)
{
this.inputsCount = Math.Max(1, inputsCount);
this.layersCount = Math.Max(1, neuronsCount.Length);
layers = new NetworkLayer[neuronsCount.Length];
for (int i = 0; i < layersCount; i++)
layers[i] = new NetworkLayer(neuronsCount[i],
(i == 0) ? inputsCount : neuronsCount[i - 1]);
}
public virtual double[] Compute(double[] input)
{
output = input;
foreach (NetworkLayer layer in layers)
output = layer.Compute(output);
return output;
}
}
public class NetworkLayer
{
protected int inputsCount = 0;
protected int neuronsCount = 0;
protected Neuron[] neurons;
protected double[] output;
public Neuron this[int index]
{
get { return neurons[index]; }
}
public int Count
{
get { return neurons.Length; }
}
public int Inputs
{
get { return inputsCount; }
}
public double[] Output
{
get { return output; }
}
public NetworkLayer(int neuronsCount, int inputsCount)
{
this.inputsCount = Math.Max(1, inputsCount);
this.neuronsCount = Math.Max(1, neuronsCount);
neurons = new Neuron[this.neuronsCount];
output = new double[this.neuronsCount];
// create each neuron
for (int i = 0; i < neuronsCount; i++)
neurons[i] = new Neuron(inputsCount);
}
public virtual double[] Compute(double[] input)
{
// compute each neuron
for (int i = 0; i < neuronsCount; i++)
output[i] = neurons[i].Compute(input);
return output;
}
}
public class Neuron
{
protected static Random rand = new Random((int)DateTime.Now.Ticks);
public int Inputs;
public double[] Input;
public double[] Weights;
public double Output = 0;
public double Threshold;
public double Error;
public Neuron(int inputs)
{
this.Inputs = inputs;
Weights = new double[inputs];
for (int i = 0; i < inputs; i++)
Weights[i] = rand.NextDouble() * 0.5;
}
public double Compute(double[] inputs)
{
Input = inputs;
double e = 0.0;
for (int i = 0; i < inputs.Length; i++)
e += Weights[i] * inputs[i];
e -= Threshold;
return (Output = sigmoid(e));
}
private double sigmoid(double value)
{
return (1/(1 + Math.Exp(-1 * value)));
//return 1/(1 + Math.Exp(-value));
}
}
我的教練
public class NetworkTrainer
{
private Network network;
private double learningRate = 0.1;
public NetworkTrainer(double a, Network network)
{
this.network = network;
this.learningRate = a;
}
public double Run(double[] input, double[] output)
{
network.Compute(input);
return CorrectErrors(output);
}
public double RunEpoch(List<TrainerSet> sets)
{
double error = 0.0;
for (int i = 0, n = sets.Count; i < n; i++)
error += Run(sets[i].Inputs, sets[i].Outputs);
// return summary error
return error;
}
private double CorrectErrors(double[] desiredOutput)
{
double[] errorLast = new double[desiredOutput.Length];
NetworkLayer lastLayer = network[network.Count - 1];
for (int i = 0; i < desiredOutput.Length; i++)
{
// S(p)=y(p)*[1-y(p)]*(yd(p)-y(p))
lastLayer[i].Error = lastLayer[i].Output * (1-lastLayer[i].Output)*(desiredOutput[i] - lastLayer[i].Output);
errorLast[i] = lastLayer[i].Error;
}
// Calculate errors
for (int l = network.Count - 2; l >= 0; l--)
{
for (int n = 0; n < network[l].Count; n++)
{
double newError = 0;
for (int np = 0; np < network[l + 1].Count; np++)
{
newError += network[l + 1][np].Weights[n] * network[l + 1][np].Error;
}
network[l][n].Error = newError;
}
}
// Update Weights
// w = w + (a * input * error)
for (int l = network.Count - 1; l >= 0; l--)
{
for (int n = 0; n < network[l].Count; n++)
{
for (int i = 0; i < network[l][n].Inputs; i++)
{
// deltaW = a * y(p) * s(p)
double deltaW = learningRate * network[l][n].Output * network[l][n].Error;
network[l][n].Weights[i] += deltaW;
}
}
}
double returnError = 0;
foreach (double e in errorLast)
returnError += e;
return returnError;
}
}
您遇到的具體問題是什麼?這是相當長的代碼.. http://sscce.org/ – Coffee 2012-04-18 20:22:42
正如我所說,我只得到0.99的輸出,它應該是eks。 75或25,我不知道但我必須執行錯誤。 – Androme 2012-04-18 20:37:40
我被告知,這是因爲我的激活函數給出了一個介於0和1之間的值,我需要調整它。但我找不到任何關於此的內容。 – Androme 2012-04-18 20:53:37