Implementing Chinese Text Similarity on Spark with TF-IDF

 

First, the use case: the user supplies a piece of text, and we return the names of the ten files most similar to it.

 

Second, what is TF-IDF? Only a brief introduction here, since there are plenty of write-ups online. TF-IDF stands for term frequency – inverse document frequency, and it captures the intuition behind how relevant a term is to a document: the more often a term appears in a document, the more important it is there. But terms are not equal: a rare term appearing in a document says more than a common one, so we also weight by the inverse of the term's document frequency. Term frequencies in a corpus are distributed roughly exponentially, and a common term can appear tens of times more often than a rare one, so dividing directly by the raw document frequency would give rare terms too much weight; the algorithm therefore takes the logarithm of the inverse document frequency, turning multiplicative differences in document frequency into additive ones.
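Concretely, Spark MLlib's IDF uses the smoothed form documented in the MLlib feature-extraction guide (the +1 smoothing avoids division by zero for terms that appear in no document):

$$\mathrm{tfidf}(t,d,D)=\mathrm{tf}(t,d)\cdot\mathrm{idf}(t,D),\qquad \mathrm{idf}(t,D)=\log\frac{|D|+1}{\mathrm{df}(t,D)+1}$$

where tf(t, d) is the number of occurrences of term t in document d, |D| is the number of documents, and df(t, D) is the number of documents containing t.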

 

Third, Spark MLlib already ships a TF-IDF implementation, so instead of writing our own we simply call it. The overall logic is:

 

(1) load the data; (2) segment the Chinese text with a word segmenter (ansj is used here); (3) compute TF and IDF; (4) compute the cosine similarity between vectors.
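Before the full listing, here is a minimal sketch of the ansj call used in step (2). The sample sentence is only illustrative; the call chain is the same one used in the code below:

import org.ansj.recognition.impl.StopRecognition
import org.ansj.splitWord.analysis.ToAnalysis

// Drop tokens whose part-of-speech nature is "w" (punctuation).
val filter = new StopRecognition()
filter.insertStopNatures("w")

// Segment one sentence and keep the words as a space-separated sequence.
val words: Seq[String] =
  ToAnalysis.parse("用户给出一段文字,我们返回最相似的文件。")
    .recognition(filter)
    .toStringWithOutNature(" ")
    .split(" ")
    .toSeq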

 

 

 

Without further ado, here is the code:

import org.ansj.recognition.impl.StopRecognition
import org.ansj.splitWord.analysis.ToAnalysis
import org.apache.spark.mllib.feature.{HashingTF, IDF}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.{SparseVector => SV}

import scala.collection.mutable.ArrayBuffer

object TfIdfs {

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("tfidf")
    val sc = new SparkContext(conf)

    // Read the 2600 legal case files
    val rdd = sc.wholeTextFiles("hdfs://master:8020/data2/*")

    val text = rdd.map { case (file, text) => text }
    val title = rdd.map { case (title, text) => title }.collect()

    val dim = math.pow(2, 18).toInt
    val hashingTF = new HashingTF(dim)

    val filter = new StopRecognition()
    filter.insertStopNatures("w") // filter out punctuation

    // Segment the documents with ansj
    val tokens2 = text.map(doc => ToAnalysis.parse(doc).recognition(filter).toStringWithOutNature(" ").split(" ").toSeq)

    // Compute TF
    val tf = hashingTF.transform(tokens2)

    // Cache the term-frequency vectors in memory
    tf.cache()

    // Compute IDF, then the TF-IDF vectors
    val idf = new IDF().fit(tf)
    val tfidf = idf.transform(tf)
    val vectorArr = tfidf.collect()

    // The text we want to find similar documents for
    val rdd3 = sc.parallelize(Seq("被告人王乃胜,高中文化,户籍所在地:白山市,居住地:白山市。"))

    val predictTF2 = rdd3.map(doc => hashingTF.transform(ToAnalysis.parse(doc).recognition(filter).toStringWithOutNature(" ").split(" ").toSeq))
    val predictTfIdf = idf.transform(predictTF2)

    import breeze.linalg._

    val predictSV = predictTfIdf.first.asInstanceOf[SV]
    val c = new ArrayBuffer[(String, Double)]()

    val breeze1 = new SparseVector(predictSV.indices, predictSV.values, predictSV.size)
    var tag = 0

    for (i <- vectorArr) {
      val Svector = i.asInstanceOf[SV]
      val breeze2 = new SparseVector(Svector.indices, Svector.values, Svector.size)
      // Cosine similarity: dot product divided by the product of the norms
      val cosineSim = breeze1.dot(breeze2) / (norm(breeze1) * norm(breeze2))
      c ++= Array((title(tag), cosineSim))
      tag += 1
    }

    // Keep the 10 most similar documents
    val cc = c.toArray.sortBy(_._2).reverse.take(10)
    println(cc.toBuffer)
    sc.stop()
  }
}
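A note on the loop above: tfidf.collect() pulls every document vector onto the driver, which is fine for 2,600 cases but will not scale. As a sketch of a more scalable variant (my addition, not part of the original post), the collect and the for loop could be replaced by broadcasting the query vector and letting Spark return only the top 10:

    // Broadcast the query vector and score each document on the executors.
    // The zip is valid because tfidf was produced from rdd purely by map
    // operations, so both RDDs have the same partitioning and element counts.
    val bcQuery = sc.broadcast(breeze1)
    val queryNorm = norm(breeze1)
    val top10 = rdd.map { case (file, _) => file }.zip(tfidf).map { case (file, vec) =>
      val sv = vec.asInstanceOf[SV]
      val b2 = new SparseVector(sv.indices, sv.values, sv.size)
      (file, bcQuery.value.dot(b2) / (queryNorm * norm(b2)))
    }.sortBy(_._2, ascending = false).take(10)
    println(top10.toBuffer)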

————————————————

 

Original article: https://blog.csdn.net/cap3396g/article/details/79256625

 

I have recently been responsible for our company's POI domain. POI stands for "point of interest"; its widest use is in mapping, where every labeled address on a map is a POI. At our company the notion has been narrowed to fit the business: the points of interest are food-and-beverage and new-retail merchants, supermarkets, and the like.
That sounds like nothing more than collecting and correcting merchant data, so why does machine learning come into it? The real reason is awkward and best left unsaid. We currently have data for a set of merchants but cannot obtain their categories, and the category matters a great deal to our business because it determines how different business lines divide the work. So we need some other means to infer each merchant's category.

Scenario

Given the merchant data we already have (merchant name and dish names), identify the merchant's category (Sichuan/Hunan cuisine, Japanese food, and so on).

Options

  1. Ask the algorithm team to compute merchant categories for us
    A similar project already exists to build on, so this would be quick
  2. Figure it out ourselves and study the algorithms
    Learning the algorithms and getting them into production would take longer

Final choice

The algorithm team had staffing problems, so I ended up doing the algorithm work myself; fortunately the related project was handed over to our team, so we could use it as a reference.
Balancing the need for a quick result now against where we want to go long term, we chose to proceed on two tracks: a Python script for a fast short-term implementation, and Spark ML for distributed computation as the long-term goal.

Approach

Treat all of a merchant's dish information plus the merchant information as a single piece of text. The problem then becomes one of "text similarity", for which TF-IDF, LSI, LDA and similar algorithms are candidates. The Python script uses TF-IDF and LSI for the similarity computation (simple example for reference); Spark uses TF-IDF plus cosine similarity as a first, validating implementation (the algorithm will be refined later). Since the long-term plan is to do this kind of machine-learning computation on Spark, the rest of this post focuses on how to do it there.

Algorithms

TF-IDF
Cosine similarity
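Cosine similarity measures the angle between two TF-IDF vectors A and B, so it compares the distribution of terms rather than the absolute length of the texts:

$$\mathrm{sim}(A,B)=\frac{A\cdot B}{\lVert A\rVert\,\lVert B\rVert}=\frac{\sum_i A_iB_i}{\sqrt{\sum_i A_i^2}\,\sqrt{\sum_i B_i^2}}$$

This is exactly the dot / (v1 * v2) computation in Task 2 below (and in the Scala example above).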

Spark ML implementation

The work is split into two Spark jobs. The first computes the TF-IDF values of the correctly categorized merchants and dishes that already exist online, and saves the computed values to a Hive table.

Task 1

Preprocess the data and compute TF-IDF over the reference data.
First, a Hive job flattens each merchant's dish data into one row per merchant. This step is simple; the flattened data looks like this:


 
[Figure: flattened merchant and dish data]

A separate Spark job then runs TF-IDF over the merchant dish text and saves the result to the table below. vector_indices and vector_values are arrays of the same length; together with vector_size they describe the sparse vectors.


 
[Figure: Hive table schema]
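To make the storage format concrete (the numbers below are made up for illustration): a TF-IDF vector over the hashing space, which defaults to 2^18 = 262144 dimensions, with non-zero values 0.52 at index 131 and 1.94 at index 70022 would be stored as one row with vector_size = 262144, vector_indices = "[131, 70022]" and vector_values = "[0.52, 1.94]". Task 2 parses these strings back into a SparseVector.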

The TF-IDF utility class:

public class TfidfUtil {
    /**
     * See the page below for more detail about TF-IDF
     * @see  <a href="http://dblab.xmu.edu.cn/blog/1261-2/">Spark入门:特征抽取: TF-IDF</a>
     * @param dataset
     * @return
     */
    public static Dataset<Row> tfidf(Dataset<Row> dataset) {
        Tokenizer tokenizer = new Tokenizer().setInputCol("goodsSegment").setOutputCol("words");
        Dataset<Row> wordsData = tokenizer.transform(dataset);
        HashingTF hashingTF = new HashingTF()
                .setInputCol("words")
                .setOutputCol("rawFeatures");
        Dataset<Row> featurizedData = hashingTF.transform(wordsData);
        IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features");
        IDFModel idfModel = idf.fit(featurizedData);
        return idfModel.transform(featurizedData);
    }
}

The Spark preprocessing job is below. Its main steps: load the merchants and their flattened dish data, run TF-IDF, and save the result to the Hive table.

public class CategorySuggestionTrainning {

    private static SparkSession spark;

    private static final String YESTERDAY = DateTimeUtil.getYesterdayStr();

    public static final String TRAINNING_DATA_SQL = "select id, coalesce(shop_name,'') as name,coalesce(category_id,0) as category_id, coalesce(food_name,'') as food " +
            "from dw.category_and_foodname where dt='%s' limit 100000";

    public static void main(String[] args){
        spark = initSpark();
        try{
            Dataset<Row> rawTranningDataset = getTrainDataSet();
            Dataset<Row> trainningTfidfDataset = TfidfUtil.tfidf(rawTranningDataset);
            JavaRDD<TrainningFeature> trainningFeatureRdd = getTrainningFeatureRDD(trainningTfidfDataset);
            Dataset<Row> trainningFeaturedataset = spark.createDataFrame(trainningFeatureRdd,TrainningFeature.class);

            saveToHive(trainningFeaturedataset);
            System.out.println("poi suggest trainning stopped");
            spark.stop();
        } catch (Exception e) {
            System.out.println("main method has error " + e.getMessage());
            e.printStackTrace();
        }
    }

    /**
     * Get the original ele shop data, including the category and the goods (separated by '|'),
     * then segment the goods into words.
     * @return Dataset<Row>
     */
    private static Dataset<Row> getTrainDataSet(){
        String trainningSql = String.format(TRAINNING_DATA_SQL,YESTERDAY);
        System.out.println("tranningData sql is "+trainningSql);
        spark.sql("use dw");
        Dataset<Row> rowRdd = spark.sql(trainningSql);
        JavaRDD<TrainningData> trainningDataJavaRDD = rowRdd.javaRDD().map((row) -> {
            String goods = (String) row.getAs("food");
            String shopName = (String) row.getAs("name");
            if (StringUtil.isBlank(shopName) || StringUtil.isBlank(goods) || goods.length() < 50) {
                System.out.println("some field is null " + row.toString());
                return null;
            }
            TrainningData data = new TrainningData();
            data.setShopId((Long) row.getAs("id"));
            data.setShopName(shopName);
            data.setCategory((Long) row.getAs("category_id"));

            List<Word> words = WordSegmenter.seg(goods);
            StringBuilder wordsOfGoods = new StringBuilder();
            for (Word word : words) {
                wordsOfGoods.append(word.getText()).append(" ");
            }
            data.setGoodsSegment(wordsOfGoods.toString());
            return data;
        }).filter((data) -> data != null);
        return spark.createDataFrame(trainningDataJavaRDD, TrainningData.class);
    }

    private static JavaRDD<TrainningFeature> getTrainningFeatureRDD(Dataset<Row> trainningTfidfDataset){
        return trainningTfidfDataset.javaRDD().map(new Function<Row, TrainningFeature>(){
            @Override
            public TrainningFeature call(Row row) throws Exception {
                TrainningFeature data = new TrainningFeature();
                data.setCategory(row.getAs("category"));
                data.setShopId(row.getAs("shopId"));
                data.setShopName(row.getAs("shopName"));
                SparseVector vector = row.getAs("features");
                data.setVectorSize(vector.size());
                data.setVectorIndices(Arrays.toString(vector.indices()));
                data.setVectorValues(Arrays.toString(vector.values()));
                return data;
            }
        });
    }

    private static SparkSession initSpark(){
        long startTime = System.currentTimeMillis();
        return SparkSession
                .builder()
                .appName("poi-spark-trainning")
                .enableHiveSupport()
                .getOrCreate();
    }

    private static void saveToHive(Dataset<Row> trainningTfidfDataset){
        try {
            trainningTfidfDataset.createTempView("trainData");
            String sqlInsert = "insert overwrite table dw.poi_category_pre_data " +
                    "select shopId,shopName,category,vectorSize,vectorIndices,vectorValues from trainData ";

            spark.sql("use dw");
            System.out.println(spark.sql(sqlInsert).count());
        } catch (AnalysisException e) {
            System.out.println("save tranning data to hive failed");
            e.printStackTrace();
        }
    }
}

Task 2

Load the preprocessed reference data and the merchants whose category is still unknown, compute the cosine similarity between each unknown merchant and every reference merchant, and take the category of the most similar reference merchant as the suggested category for the unknown one.
The code is as follows:

public class CategorySuggestion {

    private static SparkSession spark;

    private static final String YESTERDAY = DateTimeUtil.getYesterdayStr();

    private static boolean CALCULATE_ALL = false;

    private static long MT_SHOP_COUNT = 2000;

    public static final String TRAINNING_DATA_SQL = "select shop_id, coalesce(shop_name,'') as shop_name,coalesce(category,0) as category_id, " +
            "vector_size, coalesce(vector_indices,'[]') as vector_indices, coalesce(vector_values,'[]') as vector_values " +
            "from dw.poi_category_pre_data limit %s ";

    public static final String COMPETITOR_DATA_SQL = "select id,coalesce(name,'') as name,coalesce(food,'') as food from dw.unknow_category_restaurant " +
            "where dt='%s' and id is not null limit %s ";

    public static void main(String[] args){
        spark = initSpark();
        try{
            MiniTrainningData[] miniTrainningDataArray = getTrainData();
            final Broadcast<MiniTrainningData[]> trainningData = spark.sparkContext().broadcast(miniTrainningDataArray, ClassTag$.MODULE$.<MiniTrainningData[]>apply(MiniTrainningData[].class));
            System.out.println("broadcast success and list is "+trainningData.value().length);

            Dataset<Row> rawMeituanDataset = getMeituanDataSet();
            Dataset<Row> meituanTfidDataset = TfidfUtil.tfidf(rawMeituanDataset);
            Dataset<SimilartyData> similartyDataList = pickupTheTopSimilarShop(meituanTfidDataset, trainningData);

            saveToHive(similartyDataList);
            System.out.println("poi suggest stopped");
            spark.stop();
        } catch (Exception e) {
            System.out.println("main method has error " + e.getMessage());
            e.printStackTrace();
        }
    }

    private static SparkSession initSpark(){
        long startTime = System.currentTimeMillis();
        return SparkSession
                .builder()
                .appName("poi-spark")
                .enableHiveSupport()
                .getOrCreate();
    }

    /**
     * Load the precomputed TF-IDF features of the reference (ele) shops from Hive
     * and rebuild each row's indices/values arrays into a SparseVector.
     * @return MiniTrainningData[]
     */
    private static MiniTrainningData[] getTrainData(){
        String trainningSql = String.format(TRAINNING_DATA_SQL,20001);
        System.out.println("tranningData sql is "+trainningSql);
        spark.sql("use dw");
        Dataset<Row> rowRdd = spark.sql(trainningSql);
        List<MiniTrainningData> trainningDataList = rowRdd.javaRDD().map((row) -> {
            MiniTrainningData data = new MiniTrainningData();
            data.setEleShopId( row.getAs("shop_id"));
            data.setCategory( row.getAs("category_id"));
            Long vectorSize = row.getAs("vector_size");
            List<Integer> vectorIndices = JSON.parseArray(row.getAs("vector_indices"),Integer.class);
            List<Double> vectorValues = JSON.parseArray(row.getAs("vector_values"),Double.class);
            SparseVector vector = new SparseVector(vectorSize.intValue(),integerListToArray(vectorIndices),doubleListToArray(vectorValues));
            data.setFeatures(vector);
            return data;
        }).collect();
        MiniTrainningData[] miniTrainningDataArray = new MiniTrainningData[trainningDataList.size()];
        return trainningDataList.toArray(miniTrainningDataArray);
    }

    private static int[] integerListToArray(List<Integer> integerList){
        int[] intArray = new int[integerList.size()];
        for (int i = 0; i < integerList.size(); i++) {
            intArray[i] = integerList.get(i).intValue();
        }
        return intArray;
    }

    private static double[] doubleListToArray(List<Double> doubleList){
        double[] doubleArray = new double[doubleList.size()];
        for (int i = 0; i < doubleList.size(); i++) {
            doubleArray[i] = doubleList.get(i).doubleValue();
        }
        return doubleArray;
    }

    private static Dataset<Row> getMeituanDataSet() {
        String meituanSql = String.format(COMPETITOR_DATA_SQL, YESTERDAY, 10000);
        System.out.println("meituan sql is " + meituanSql);
        spark.sql("use dw");
        Dataset<Row> rowRdd = spark.sql(meituanSql);
        JavaRDD<MeiTuanData> meituanDataJavaRDD = rowRdd.javaRDD().map((row) -> {
            MeiTuanData data = new MeiTuanData();
            String goods = (String) row.getAs("food");
            String shopName = (String) row.getAs("name");
            data.setShopId((Long) row.getAs("id"));
            data.setShopName(shopName);
            if (StringUtil.isBlank(goods)) {
                return null;
            }
            StringBuilder wordsOfGoods = new StringBuilder();
            try {
                List<Word> words = WordSegmenter.seg(goods.replace("|", " "));
                for (Word word : words) {
                    wordsOfGoods.append(word.getText()).append(" ");
                }
            } catch (Exception e) {
                System.out.println("exception in segment " + data);
            }
            data.setGoodsSegment(wordsOfGoods.toString());
            return data;
        }).filter((data) -> data != null);
        System.out.println("meituan data count is " + meituanDataJavaRDD.count());
        return spark.createDataFrame(meituanDataJavaRDD, MeiTuanData.class);
    }

    private static Dataset<SimilartyData> pickupTheTopSimilarShop(Dataset<Row> meituanTfidDataset, Broadcast<MiniTrainningData[]> trainningData){
         return meituanTfidDataset.map(new MapFunction<Row, SimilartyData>() {
            @Override
            public SimilartyData call(Row row) throws Exception {
                SimilartyData similartyData = new SimilartyData();
                Long mtShopId = row.getAs("shopId");
                Vector meituanfeatures = row.getAs("features");
                similartyData.setMtShopId(mtShopId);
                MiniTrainningData[] trainDataArray = trainningData.value();
                if(ArrayUtils.isEmpty(trainDataArray)){
                    return similartyData;
                }
                double maxSimilarty = 0;
                long maxSimilarCategory = 0L;
                long maxSimilareleShopId = 0;
                for (MiniTrainningData trainData : trainDataArray) {
                    Vector trainningFeatures = trainData.getFeatures();
                    long categoryId = trainData.getCategory();
                    long eleShopId = trainData.getEleShopId();
                    double dot = BLAS.dot(meituanfeatures.toSparse(), trainningFeatures.toSparse());
                    double v1 = Vectors.norm(meituanfeatures.toSparse(), 2.0);
                    double v2 = Vectors.norm(trainningFeatures.toSparse(), 2.0);
                    double similarty = dot / (v1 * v2);
                    if(similarty>maxSimilarty){
                        maxSimilarty = similarty;
                        maxSimilarCategory = categoryId;
                        maxSimilareleShopId = eleShopId;
                    }
                }
                similartyData.setEleShopId(maxSimilareleShopId);
                similartyData.setSimilarty(maxSimilarty);
                similartyData.setCategoryId(maxSimilarCategory);
                return similartyData;
            }
        }, Encoders.bean(SimilartyData.class));
    }

    private static void saveToHive(Dataset<SimilartyData> similartyDataset){
        try {
            similartyDataset.createTempView("records");
            String sqlInsert = "insert overwrite table dw.poi_category_suggest  PARTITION (dt = '"+DateTimeUtil.getYesterdayStr()+"') \n" +
                    "select mtShopId,eleShopId,shopName,similarty,categoryId from records ";
            System.out.println(spark.sql(sqlInsert).count());
        } catch (AnalysisException e) {
            System.out.println("create SimilartyData dataFrame failed");
            e.printStackTrace();
        }
        //Dataset<Row> resultSet = spark.createDataFrame(similartyDataset,SimilartyData.class);
        spark.sql("use platform_dw");
    }
}
 
 
 
 


Author: adam_go
Link: https://www.jianshu.com/p/52298e5e0473
 