2017-08-31 66 views
1

我正在嘗試在Unity中實現DeepQ學習模擬螞蟻。遵循雅閣Animat示例,我設法實現了該算法的要點。捕獲狀態作爲DeepQ中的數組使用Accord.net學習

現在我的代理有5個狀態輸入 - 其中三個來自探測前方障礙物的傳感器(RayCasts在Unity中),剩下的兩個是它在地圖上的X和Y位置。

我的問題是,qLearning.GetAction(currentState)只接受一個int作爲參數。如何使用數組(或張量)爲代理當前狀態實現我的算法?

這是我的代碼:

using System.Collections; 
using System.Collections.Generic; 
using UnityEngine; 
using Accord.MachineLearning; 
using System; 

public class AntManager : MonoBehaviour { 
    float direction = 0.01f; 
    float rotation = 0; 

    // learning settings 
    int learningIterations = 100; 
    private double explorationRate = 0.5; 
    private double learningRate = 0.5; 

    private double moveReward = 0; 
    private double wallReward = -1; 
    private double goalReward = 1; 

    private float lastDistance = 0; 

    private RaycastHit hit; 
    private int hitInteger = 0; 

    // Q-Learning algorithm 
    private QLearning qLearning = null; 


    // Use this for initialization 
    void Start() { 
     qLearning = new QLearning(256, 4, new TabuSearchExploration(4, new EpsilonGreedyExploration(explorationRate))); 
    } 

    // Update is called once per frame 
    void Update() {   

     // curent coordinates of the agent 
     float agentCurrentX = transform.position.x; 
     float agentCurrentY = transform.position.y; 
     // exploration policy 
     TabuSearchExploration tabuPolicy = (TabuSearchExploration)qLearning.ExplorationPolicy; 

     EpsilonGreedyExploration explorationPolicy = (EpsilonGreedyExploration)tabuPolicy.BasePolicy; 

     // set exploration rate for this iteration 
     explorationPolicy.Epsilon = explorationRate - learningIterations * explorationRate; 
     // set learning rate for this iteration 
     qLearning.LearningRate = learningRate - learningIterations * learningRate; 
     // clear tabu list 
     tabuPolicy.ResetTabuList(); 

     // get agent's current state 
     int currentState = ((int)Math.Round(transform.position.x, 0) + (int)Math.Round(transform.position.y, 0) + hitInteger); 
     // get the action for this state 
     int action = qLearning.GetAction(currentState); 
     // update agent's current position and get his reward 
     double reward = UpdateAgentPosition(ref agentCurrentX, ref agentCurrentY, action); 
     // get agent's next state 
     int nextState = currentState; 
     // do learning of the agent - update his Q-function 
     qLearning.UpdateState(currentState, action, reward, nextState); 

     // set tabu action 
     tabuPolicy.SetTabuAction((action + 2) % 4, 1); 


    } 

    // Update agent position and return reward for the move 
    private double UpdateAgentPosition(ref float currentX, ref float currentY, int action) 
    { 
     // default reward is equal to moving reward 
     double reward = moveReward; 
     GameObject food = GameObject.FindGameObjectWithTag("Food"); 

     float distance = Vector3.Distance(transform.position, food.transform.position); 

     if (distance < lastDistance) 
      reward = 0.2f; 

     lastDistance = distance; 

     Debug.Log(distance); 

     switch (action) 
     { 
      case 0:   // go to north (up) 
       rotation += -1f; 
       break; 
      case 1:   // go to east (right) 
       rotation += 1f; 
       break; 
      case 2:   // go to south (down) 
       rotation += 1f; 
       break; 
      case 3:   // go to west (left) 
       rotation += -1f; 
       break; 
     } 

     //transform.eulerAngles = new Vector3(10, rotation, 0); 
     transform.Rotate(0, rotation * Time.deltaTime, 0); 
     transform.Translate(new Vector3(0, 0, 0.01f)); 



     float newX = transform.localRotation.x; 
     float newY = transform.localRotation.y; 

     Ray sensorForward = new Ray(transform.position, transform.forward); 
     Debug.DrawRay(transform.position, transform.forward * 1); 

     if (Physics.Raycast(sensorForward, out hit, 1)) 
     { 
      if (hit.collider.tag != "Terrain") 
      { 
       Debug.Log("Sensor Forward hit!"); 

       reward = wallReward; 
      } 
      if (hit.collider.tag == "Food") 
      { 
       Debug.Log("Sensor Found Food!"); 
       Destroy(food); 
       reward = goalReward; 
       hitInteger = 1; 
      } 
      hitInteger = 0; 
     } 

     return reward; 
    } 
} 

回答

0

documentation提供此作爲一個例子:

c1 | (c2 << 1) | (c3 << 2) | (c4 << 3) | (c5 << 4) | (c6 << 5) | (c7 << 6) | (c8 << 7) 

這似乎是位移動兩個值的整數進入狀態的二進制編碼。您的代碼可能需要這樣的事:

int currentState = ((int)Math.Round(transform.position.x, 0) | ((int)Math.Round(transform.position.y, 0) << 1) | (hitInteger << 2)) 

但是,您首先需要將您的狀態映射到二元變量,所以這段代碼將只與一個2x2的網格工作。儘管該示例聲明瞭整數,但它們是二進制值:將位移爲2或更大的值是沒有意義的。

Convert.ToString(1 | (0 << 1) | (1 << 2), 2) 

以可視化的狀態,一個有用的方法是直接在二進制找