2017-06-21 23 views
0

使用Flink和Java來使我們的推薦系統使用我們的邏輯。如何使用我的模型在Flink中進行分組

所以我有一個數據集:

[user] [item] 
100  1 
100  2 
100  3 
100  4 
100  5 
200  1 
200  2 
200  3 
200  6 
300  1 
300  6 
400  7 

所以我都映射到一個元組:

// Count how many times each (customer, item) pair occurs in the input.
DataSet<Tuple3<Long, Long, Integer>> csv = text
    .flatMap(new LineSplitter())
    .groupBy(0, 1)
    .reduceGroup(new GroupReduceFunction<Tuple2<Long, Long>, Tuple3<Long, Long, Integer>>() {
        /**
         * Collapses one (customer, item) group into a single
         * (customer, item, occurrences) tuple.
         */
        @Override
        public void reduce(Iterable<Tuple2<Long, Long>> iterable, Collector<Tuple3<Long, Long, Integer>> collector) throws Exception {
            Long lastCustomer = 0L;
            Long lastItem = 0L;
            int occurrences = 0;

            // Every element of the group carries the same key fields; keep the
            // last one seen and count the elements.
            for (Tuple2<Long, Long> pair : iterable) {
                occurrences++;
                lastCustomer = pair.f0;
                lastItem = pair.f1;
            }

            Tuple3<Long, Long, Integer> counted =
                new Tuple3<Long, Long, Integer>(lastCustomer, lastItem, Integer.valueOf(occurrences));
            collector.collect(counted);
        }
    });

之後我取得所有客戶,每個客戶內含一個存放其項目的 ArrayList:

// Gather all item ids of each customer into one CustomerItems record.
DataSet<CustomerItems> customerItems = csv
    .groupBy(0)
    .reduceGroup(new GroupReduceFunction<Tuple3<Long, Long, Integer>, CustomerItems>() {
        /**
         * Emits one CustomerItems per customer group, holding the item ids
         * (field f1) of every tuple in the group, in iteration order.
         */
        @Override
        public void reduce(Iterable<Tuple3<Long, Long, Integer>> iterable, Collector<CustomerItems> collector) throws Exception {
            Long owner = 0L;
            ArrayList<Long> collected = new ArrayList<>();

            for (Tuple3<Long, Long, Integer> row : iterable) {
                collected.add(row.f1);
                owner = row.f0;
            }

            CustomerItems grouped = new CustomerItems(owner, collected);
            collector.collect(grouped);
        }
    });

現在我需要獲得所有「類似」客戶。但嘗試了很多東西之後,沒有什麼用。

邏輯將是:

for ci : CustomerItems 
    c1 = ci.customerId 

    for ci2 : CustomerItems 
     c2 = ci2.customerId 

     if c1 != c2 
     if ci2.getItems() have any item inside ci.getItems() 
      collector.collect(new Tuple2<c1, c2>) 

我嘗試過使用 reduce,但我無法對同一個 Iterable 迭代兩次(也就是內層迴圈)。

任何人都可以幫助我嗎?

回答

0

您可以將數據集與自身交叉,並基本上將您的邏輯1:1插入到交叉函數(不包括2個循環,因爲交叉爲您執行)。

+0

我做了,但結果是16個新項目。我把所有的代碼和結果放在這裏: https://gist.github.com/prsolucoes/b406ae98ea24120436954967e37103f6 –

0

我解決了這個問題,但是我需要在「cross」之後再做 group 和 reduce。我不知道這是不是最好的方法。任何人都可以提出建議嗎?

其結果是在這裏:

package org.myorg.quickstart; 

import org.apache.flink.api.common.functions.CrossFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.util.Collector;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Set;

/**
 * Flink batch job that computes, for every customer, the items owned by
 * "similar" customers (customers sharing at least one item) that the customer
 * does not own yet.
 *
 * <p>Pipeline: read tab-separated {@code customer<TAB>item} lines, count each
 * (customer, item) pair, group the items per customer, cross every customer
 * with every other customer to derive recommended items, then merge the
 * per-pair recommendations back into one record per customer.
 */
public class UserRecommendation {

    /** Input file used when no path is passed on the command line. */
    private static final String DEFAULT_INPUT_PATH = "/Users/paulo/Downloads/dataset.csv";

    /**
     * Runs the job and prints up to 100 results followed by the total count.
     *
     * @param args optional; {@code args[0]} is the path of the dataset file,
     *             otherwise {@link #DEFAULT_INPUT_PATH} is read
     * @throws Exception if the Flink job fails
     */
    public static void main(String[] args) throws Exception {
        final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        final String inputPath = (args != null && args.length > 0) ? args[0] : DEFAULT_INPUT_PATH;

        // Read the file containing the dataset.
        DataSet<String> text = env.readTextFile(inputPath);

        // Build tuples of: customer | item | count
        DataSet<Tuple3<Long, Long, Integer>> csv = text
                .flatMap(new LineFieldSplitter())
                .groupBy(0, 1)
                .reduceGroup(new GroupReduceFunction<Tuple2<Long, Long>, Tuple3<Long, Long, Integer>>() {
                    @Override
                    public void reduce(Iterable<Tuple2<Long, Long>> iterable, Collector<Tuple3<Long, Long, Integer>> collector) throws Exception {
                        Long customerId = 0L;
                        Long itemId = 0L;
                        int count = 0;

                        // All elements of the group share the same (customer, item) key.
                        for (Tuple2<Long, Long> item : iterable) {
                            customerId = item.f0;
                            itemId = item.f1;
                            count++;
                        }

                        collector.collect(new Tuple3<Long, Long, Integer>(customerId, itemId, Integer.valueOf(count)));
                    }
                });

        // Group each customer's items inside a single CustomerItems record.
        final DataSet<CustomerItems> customerItems = csv
                .groupBy(0)
                .reduceGroup(new GroupReduceFunction<Tuple3<Long, Long, Integer>, CustomerItems>() {
                    @Override
                    public void reduce(Iterable<Tuple3<Long, Long, Integer>> iterable, Collector<CustomerItems> collector) throws Exception {
                        ArrayList<Long> newItems = new ArrayList<Long>();
                        Long customerId = 0L;

                        for (Tuple3<Long, Long, Integer> item : iterable) {
                            customerId = item.f0;
                            newItems.add(item.f1);
                        }

                        collector.collect(new CustomerItems(customerId, newItems));
                    }
                });

        // For every pair of customers, collect the items of the second customer
        // that the first one lacks, provided the two share at least one item.
        DataSet<CustomerItems> ci = customerItems
                .cross(customerItems)
                .with(new CrossFunction<CustomerItems, CustomerItems, CustomerItems>() {

                    @Override
                    public CustomerItems cross(CustomerItems customerItems, CustomerItems customerItems2) throws Exception {
                        // Return a fresh object instead of mutating the input: the same
                        // left-hand instance may be passed to this function again (e.g.
                        // with object reuse enabled), and accumulating into it would leak
                        // recommendations from one pair into the next.
                        CustomerItems result = new CustomerItems(
                                customerItems.customerId, new ArrayList<Long>(customerItems.items));
                        result.ritems.addAll(customerItems.ritems);

                        // A customer is never compared with itself.
                        if (customerItems.customerId.equals(customerItems2.customerId)) {
                            return result;
                        }

                        Set<Long> owned = new HashSet<Long>(customerItems.items);
                        ArrayList<Long> candidates = new ArrayList<Long>();
                        boolean sharesItem = false;

                        for (Long item : customerItems2.items) {
                            if (owned.contains(item)) {
                                sharesItem = true;
                            } else {
                                candidates.add(item);
                            }
                        }

                        // Only customers with at least one common item count as similar.
                        if (sharesItem) {
                            result.ritems.addAll(candidates);
                        }

                        return result;
                    }

                })
                .groupBy(new KeySelector<CustomerItems, Long>() {

                    @Override
                    public Long getKey(CustomerItems customerItems) throws Exception {
                        return customerItems.customerId;
                    }

                })
                .reduceGroup(new GroupReduceFunction<CustomerItems, CustomerItems>() {

                    @Override
                    public void reduce(Iterable<CustomerItems> iterable, Collector<CustomerItems> collector) throws Exception {
                        // Only customerId and ritems are filled in on the merged record;
                        // its items list stays empty.
                        CustomerItems merged = new CustomerItems();
                        Set<Long> seen = new HashSet<Long>();

                        for (CustomerItems current : iterable) {
                            merged.customerId = current.customerId;

                            for (Long item : current.ritems) {
                                // Keep first-seen order while dropping duplicates.
                                if (seen.add(item)) {
                                    merged.ritems.add(item);
                                }
                            }
                        }

                        collector.collect(merged);
                    }

                });

        // NOTE: print() and count() are both eager actions, so the plan is
        // executed once for each of the two calls below.
        ci.first(100).print();
        System.out.println(ci.count());
    }

    /**
     * A customer together with the items it owns and the items recommended to it.
     */
    public static class CustomerItems implements Serializable {

        /** Identifier of the customer. */
        public Long customerId;
        /** Items the customer already owns. */
        public ArrayList<Long> items = new ArrayList<>();
        /** Recommended items: owned by similar customers but not by this one. */
        public ArrayList<Long> ritems = new ArrayList<>();

        /** Creates an empty record with no id and empty item lists. */
        public CustomerItems() {

        }

        /**
         * @param customerId identifier of the customer
         * @param items      items owned by the customer; stored as given, not copied
         */
        public CustomerItems(Long customerId, ArrayList<Long> items) {
            this.customerId = customerId;
            this.items = items;
        }

        /**
         * @return {@code [ID: <id>, Items: <a, b, ...>, RItems: <x, y, ...>]};
         *         a {@code null} list is rendered as an empty string
         */
        @Override
        public String toString() {
            return String.format("[ID: %d, Items: %s, RItems: %s]", customerId, join(items), join(ritems));
        }

        /** Joins the values with {@code ", "}; returns "" for {@code null} or empty input. */
        private static String join(Iterable<Long> values) {
            StringBuilder joined = new StringBuilder();

            if (values != null) {
                for (Long value : values) {
                    if (joined.length() > 0) {
                        joined.append(", ");
                    }
                    joined.append(value);
                }
            }

            return joined.toString();
        }
    }

    /**
     * Parses a tab-separated line into a (customer, item) tuple.
     * Lines with fewer than two fields are skipped.
     */
    public static final class LineFieldSplitter implements FlatMapFunction<String, Tuple2<Long, Long>> {

        /**
         * @param value one input line
         * @param out   receives at most one (customer, item) tuple
         * @throws NumberFormatException if either of the first two fields is not a valid long
         */
        @Override
        public void flatMap(String value, Collector<Tuple2<Long, Long>> out) {
            String[] tokens = value.split("\t");

            if (tokens.length > 1) {
                // Trim so that stray spaces around a field do not break parsing.
                out.collect(new Tuple2<Long, Long>(
                        Long.valueOf(tokens[0].trim()), Long.valueOf(tokens[1].trim())));
            }
        }
    }

}

鏈接與要點: https://gist.github.com/prsolucoes/b406ae98ea24120436954967e37103f6

相關問題