2015-04-30 57 views

回答

6
import org.apache.spark.mllib.linalg.{Vectors,Vector,Matrix,SingularValueDecomposition,DenseMatrix,DenseVector} 
import org.apache.spark.mllib.linalg.distributed.RowMatrix 

def computeInverse(X: RowMatrix): DenseMatrix = { 
    val nCoef = X.numCols.toInt 
    val svd = X.computeSVD(nCoef, computeU = true) 
    if (svd.s.size < nCoef) { 
    sys.error(s"RowMatrix.computeInverse called on singular matrix.") 
    } 

    // Create the inv diagonal matrix from S 
    val invS = DenseMatrix.diag(new DenseVector(svd.s.toArray.map(x => math.pow(x,-1)))) 

    // U cannot be a RowMatrix 
    val U = new DenseMatrix(svd.U.numRows().toInt,svd.U.numCols().toInt,svd.U.rows.collect.flatMap(x => x.toArray)) 

    // If you could make V distributed, then this may be better. However its alreadly local...so maybe this is fine. 
    val V = svd.V 
    // inv(X) = V*inv(S)*transpose(U) --- the U is already transposed. 
    (V.multiply(invS)).multiply(U) 
    } 
3

我使用這個功能與選項

conf.set("spark.sql.shuffle.partitions", "12") 

在RowMatrix該行得到了洗牌有問題。

下面是一個更新爲我工作

import org.apache.spark.mllib.linalg.{DenseMatrix,DenseVector} 
import org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix 

def computeInverse(X: IndexedRowMatrix) 
: DenseMatrix = 
{ 
    val nCoef = X.numCols.toInt 
    val svd = X.computeSVD(nCoef, computeU = true) 
    if (svd.s.size < nCoef) { 
    sys.error(s"IndexedRowMatrix.computeInverse called on singular matrix.") 
    } 

    // Create the inv diagonal matrix from S 
    val invS = DenseMatrix.diag(new DenseVector(svd.s.toArray.map(x => math.pow(x, -1)))) 

    // U cannot be a RowMatrix 
    val U = svd.U.toBlockMatrix().toLocalMatrix().multiply(DenseMatrix.eye(svd.U.numRows().toInt)).transpose 

    val V = svd.V 
    (V.multiply(invS)).multiply(U) 
} 
0

矩陣U由X.computeSVD返回的尺寸爲MXK其中是人們期望米原(分佈式)RowMatrix X的行數很大(可能大於k),所以如果我們希望我們的代碼縮放到非常大的值m,則不建議將它收集在驅動程序中。

我想說下面的兩個解決方案都會遇到這個缺陷。由@Alexander Kharlamov給出的答案叫做val U = svd.U.toBlockMatrix().toLocalMatrix(),它收集驅動程序中的矩陣。 @Climbs_lika_Spyder給出的答案也是一樣(順便說一句,你的暱稱是岩石!!),它叫svd.U.rows.collect.flatMap(x => x.toArray)。我寧願建議依靠分佈式矩陣乘法,例如發佈了here的Scala代碼。

+0

我沒有看到您添加的鏈接上的任何逆向計算。 –

+0

@Climbs_lika_Spyder該鏈接是關於分佈式矩陣乘法,用於替換解決方案最後一行中的局部矩陣乘法'(V.multiply(invS))。multiply(U)',這樣就不需要收集'U'在司機。我認爲'V'和'invS'不夠大會導致問題。 – Pablo

相關問題