Spark学习---SparkSQL(概述、编程、数据的加载和保存、自定义UDFA、项目实战)

本文涉及的产品
RDS MySQL Serverless 基础系列,0.5-2RCU 50GB
云数据库 RDS MySQL,集群系列 2核4GB
推荐场景:
搭建个人博客
云数据库 RDS MySQL,高可用系列 2核4GB
简介: Spark学习---SparkSQL(概述、编程、数据的加载和保存、自定义UDFA、项目实战)

一、SparkSQL概述

1.1 什么是SparkSQL

Spark是用于结构化数据处理的Spark模块。与基本的Spark RDD API不同,SparkSQL提供的接口为Spark提供了有关数据结构和正在执行的计算的更多信息。在内部,SparkSQL使用这些额外的信息来执行额外的优化。与SparkSQL交互的方式有很多种,包括SQL和DatasetAPI。结算时,使用相同的执行引擎,与你用于表计算的API/语言无关。

它提供了⼀个编程抽象叫做DataFrame/Dataset,它可以理解为⼀个基于RDD数据模型的更⾼级数据
模型,带有结构化元信息(schema),DataFrame其实就是Dataset[Row],Spark SQL可以将针对
DataFrame/Dataset的各类SQL运算,翻译成RDD的各类算⼦执⾏计划,从⽽⼤⼤简化数据运算编程
(请联想Hive)

1.2 为什么要有SparkSQL

image.png

SparkSQL的发展

1、发展历史

RDD(Spark1.0)=> Dateframe(Spark1.3) =>Dataset(Spark1.6)

如果同样的数据都给到这三个数据结构,它们分别计算之后,都会给出相同的结果。

不同的是它们执行效率和执行方式。在现在的版本中,dataset性能最好,已经成为了唯一使用的接口。其中Dataframe已经在底层被看作是特殊泛型的DataSet。

2、三者的共性

(1)RDD、DataFrame、DataSet全都是Spark平台下的分布式弹性数据集,为处理大型数据通过便利。

(2)三者都有惰性机制,在进行创建、转换,如map方法时,不会立即执行,只有在遇到Action行动算子实,三者才会开始遍历运算。

(3)三者有许多共同的函数,例如filter,sortby等

(4)三者都会根据Spark的内存情况自动缓存运算。

(5)三者都有分区的概念

SparkSQL的特点

1、易整合:无缝的整合了SQL查询和Spark编程

2、统一的数据访问方式:使用相同的方式连接不同的数据源

3、兼容Hive:在已有的仓库上直接运行SQL或者HQL

4、标准的数据连接:通过JDBC或者ODBC来连接

二、SparkSQL 编程

2.1 SparkSession 新的起始点

在老的版本中,SparkSQL提供两种SQL查询起始点:

(1) 一个叫SQLContext,用于Spark自己提供的SQL查询;

(2)一个叫HiveContext,用于连接Hive的查询。

SparkSession是Spark最新的SQL查询起始点,实质上是SQLContext和HiveContext的组合,所以在SQLContext和HiveContext上可用的API在SparkSession上同样是可以使用的。

SparkSession内部封装了SparkContext,所以计算实际上是由SparkContext完成的。当我们使用spark-shell的时候,Spark框架会自动的创建一个名称叫做Spark的SparkSession,就像我们以前可以自动获取到一个sc来表示SparkContext。

image.png

image.png

从JSON⽂件加载DataFrame

package org.example

import org.apache.spark.sql.{DataFrame, SparkSession}

object S04_DataFrame读取复杂json文件 {
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder()
      .appName("json创建dataframe")
      .master("local")
      .config("spark.default.parallelism", 20)
      .getOrCreate()
    import spark.implicits._
    val df: DataFrame = spark.read.json("F:\\代码区\\sparkp\\datas\\c.txt")
    df.show(100,false)
    df.printSchema()
    df.createTempView("df1")
    //取每个人的mother
    spark.sql(
      """
        |select
        |family[0].name
        |from df1
        |""".stripMargin).show(100,false)

    spark.close()
  }

}


image.png


package org.example

import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}
import org.apache.spark.sql.{DataFrame, SparkSession}

object S05_DataFrame读取复杂json文件 {
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder()
      .appName("json创建dataframe")
      .master("local")
      .config("spark.default.parallelism", 20)
      .getOrCreate()


    //解析出来的结果比较丑陋,info被认为了struct类型,而struct类型的成员变量是统一的
    //手动指定schema来改善上面的问题
    val scheam=StructType(Seq(
      StructField("id",DataTypes.LongType),
      StructField("info",DataTypes.createMapType(DataTypes.StringType,DataTypes.StringType))
    ))
    val df: DataFrame = spark.read.schema(scheam).json("F:\\代码区\\sparkp\\datas\\d.txt")
    df.show(100, false)
    df.printSchema()
        //找出有年龄属性的数据后,求平均值
    df.createTempView("f4")
    spark.sql(
      """
        |select
        |avg(info['age']) as avg_age
        |from f4
        |where info['age'] is not null
        |""".stripMargin).show(100,false)

    spark.close()
  }

}

image.png

image.png

从非结构化⽂件加载DataFrame

sparksql创建wordcount

package org.example

import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
/*
* sparksql将输入数据视做非结构化数据读的时候
* 就是把整行内容当成一个字段(value:String)
* */
object S06_普通文本文件 {
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder()
      .appName("普通文本创建dataframe")
      .master("local")
      .config("spark.default.parallelism", 20)
      .config("spark.sql.shuffle.partitions",1)
      .getOrCreate()
    //dataframe就是dataset的一个特例:DataSet[Row]
    val df: DataFrame = spark.read.text("F:\\代码区\\sparkp\\datas\\f.txt")
    df.show(100,false)
    df.printSchema()
    val ds: Dataset[String] = spark.read.textFile("F:\\代码区\\sparkp\\datas\\f.txt")
    ds.show(100,false)
    ds.printSchema()
    //wordcount
    df.createTempView("ff")
    val words=spark.sql(
      """
        |select
        |words,count(1) as cnt
        |from
        |(select
        | explode(split(value,'\\s+')) as words
        | from ff)o
        | group by words
        |""".stripMargin
    )
    words.show(100,false)
    words.printSchema()
    import spark.implicits._
    val ds1:DataFrame = df.flatMap(
      row => {
        val line: String = row.getAs[String]("value")
        line.split("\\s+")
      }
    ).toDF("word")
    ds1.createTempView("ds1")
    spark.sql(
      """
        |select
        | word,count(1) as cnt
        | from df2
        | group by word
        |""".stripMargin
    ).show()




    spark.close()

  }

}

image.png

image.png

从Parquet⽂件进⾏创建

Parquet⽂件是⼀种列式存储⽂件格式,⽂件⾃带schema描述信息,自我描述

package org.example

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.types.{DataTypes, StructField, StructType}

object S07_Parque文件 {
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder()
      .appName("普通文本创建dataframe")
      .master("local")
      .config("spark.default.parallelism", 20)
      .config("spark.sql.shuffle.partitions", 1)
      .getOrCreate()
    val schema=StructType(
      Seq(
        StructField("id",DataTypes.LongType),
        StructField("country",DataTypes.StringType),
        StructField("name",DataTypes.StringType),
        StructField("battle",DataTypes.DoubleType),
        StructField("age",DataTypes.LongType)
      )
    )
    val df: DataFrame = spark.read.option("header", "true").schema(schema).csv("F:\\代码区\\sparkp\\datas\\a.txt")
    df.write.parquet("datas/parquet/")
    df.write.parquet("data/orc/")
   //读取上面job生成的parque文件
   val df2: DataFrame = spark.read.parquet("datas/parquet")
    df2.show(100,false)
    df2.printSchema()
    //读取上面job生成job生成的orc文件 //orc也是自我描述的列式存储文件格式
    val df3: DataFrame = spark.read.parquet("data/orc")
    df3.show(100,false)
    df3.printSchema()
    spark.close()
  }

}

外部存储服务创建DF

(1)从JDBC连接数据库服务器进⾏创建dataframe

package org.example

import org.apache.spark.sql.{DataFrameReader, SaveMode, SparkSession}

object Spark09_Mysql {
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder()
      .appName("创建JDBC")
      .master("local")
      .config("spark.default.parallelism", 20)
      .getOrCreate()
    val df= spark.read.format("jdbc").option("url", "jdbc:mysql://localhost:3306/atguigudb")
      .option("driver", "com.mysql.jdbc.Driver")
      .option("user", "root")
      .option("password", "915425")
      .option("dbtable", "user").load()
//      .load().show
    //保存数据
    df.write.format("jdbc")
      .option("url","jdbc:mysql://localhost:3306/atguigudb")
      .option("driver", "com.mysql.jdbc.Driver")
      .option("user", "root")
      .option("password", "915425")
      .option("dbtable", "user1")
      .mode(SaveMode.Append).save()
    spark.stop()
  }

}
val conf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL")

//创建SparkSession对象
val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()

import spark.implicits._

//方式1:通用的load方法读取
spark.read.format("jdbc")
  .option("url", "jdbc:mysql://linux1:3306/spark-sql")
  .option("driver", "com.mysql.jdbc.Driver")
  .option("user", "root")
  .option("password", "123123")
  .option("dbtable", "user")
  .load().show


//方式2:通用的load方法读取 参数另一种形式
spark.read.format("jdbc")
  .options(Map("url"->"jdbc:mysql://linux1:3306/spark-sql?user=root&password=123123",
    "dbtable"->"user","driver"->"com.mysql.jdbc.Driver")).load().show

//方式3:使用jdbc方法读取
val props: Properties = new Properties()
props.setProperty("user", "root")
props.setProperty("password", "123123")
val df: DataFrame = spark.read.jdbc("jdbc:mysql://linux1:3306/spark-sql", "user", props)
df.show

//释放资源

写入数据

case class User2(name: String, age: Long)
。。。
val conf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL")
 
//创建SparkSession对象
val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()
import spark.implicits._
 
val rdd: RDD[User2] = spark.sparkContext.makeRDD(List(User2("lisi", 20), User2("zs", 30)))
val ds: Dataset[User2] = rdd.toDS
//方式1:通用的方式  format指定写出类型
 
//方式2:通过jdbc方法
val props: Properties = new Properties()
props.setProperty("user", "root")
props.setProperty("password", "123123")
ds.write.mode(SaveMode.Append).jdbc("jdbc:mysql://linux1:3306/spark-sql", "user", props)
 
//释放资源
spark.stop()

spark整合hive原理--访问hive元数据服务原理

bin/hive --service metastore 1>/dev/null 2>&1 &

既然具备了hive的功能,那么就可以执⾏⼀切hive中能执⾏的动作:
 建表
 show 
 建库
 show 
 alter表
 ……
只不过,此时看⻅的表是spark中集成的hive的本地元数据库中的表!
如果想让spark中集成的hive,看⻅你外部集群中的hive的表,只要修改配置:把spark端的hive的元
数据服务地址,指向外部集群中hive的元数据服务地址;
有两种指定办法:
 在spark端加⼊hive-site.xml ,⾥⾯配置 ⽬标元数据库 mysql的连接信息
这会使得spark中集成的hive直接访问mysql元数据库
 在spark端加⼊hive-site.xml ,⾥⾯配置 ⽬标hive的元数据服务器地址
这会使得spark中集成的hive通过外部独⽴的hive元数据服务来访问元数据库

image.png

#在hive中创建表
use  atguigudb
CREATE TABLE user (
    id int,
    name VARCHAR(50) NOT NULL,
    age INT  
)
row format delimited fields terminated by ','
stored as textfile
tblproperties(
"external"="true");
insert into user(id,name,age)
values (1,'sun',60),
(2,'jie',80),
(6,'ss',90);
#加载数据
load data local inpath '/root/x.txt' into table t_sparkset;
package org.example

import org.apache.spark.sql.SparkSession

object Spark10_创建hive表 {
  def main(args: Array[String]): Unit = {

    val spark = SparkSession
    .builder()
    .appName(this.getClass.getSimpleName)
    .master("local[*]")
      .config("hive.metastore.uris","hdfs://node1:9083")
     // 启⽤hive⽀持,需要调⽤enableHiveSupport,还需要添加⼀个依赖 spark-hive
    // 默认sparksql内置了⾃⼰的hive
    // 如果程序能从classpath中加载到hive-site配置⽂件,那么它访问的hive元数据库就不是本地内
    // 如果程序能从classpath中加载到core-site配置⽂件,那么它访问的⽂件系统也不再是本地⽂件
    .enableHiveSupport()
    .getOrCreate()
    val res=spark.sql(
      """
        |select * from default.t_sparktset
        |""".stripMargin
    )
    res.show(100,false)
    //读取hive的分区表 并指定要读取的分区
    spark.sql(""" select * from t_acc_log where dt='2021-12-03' """).show(100,false)
    spark.read.table("t_acc_log").where("dt='2021-12-04'").show(100,false)
    spark.close()

  }

}

用户自定义函数

强类型的Dataset和弱类型的DataFrame都提供了相关的聚合函数, 如 count(),countDistinct(),avg(),max(),min()。除此之外,用户可以设定自己的自定义聚合函数。通过继承UserDefinedAggregateFunction来实现用户自定义弱类型聚合函数。从Spark3.0版本后,UserDefinedAggregateFunction已经不推荐使用了。可以统一采用强类型聚合函数

image.png

package org.example
import org.apache.spark.sql.{DataFrame, SparkSession}
object S11_SPARKSQL的UDF自定义函数 {
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder()
      .appName("自定义函数demo")
      .master("local")
      .getOrCreate()


    val df: DataFrame = spark.read
      .option("header", "true")
      .csv("file:///F:\\代码区\\sparkp\\datas\\a.txt")
      .toDF("id", "country", "name", "battle", "age")

    df.createTempView("df")

    val func = (c: String, n: String) => {
      val firstName: String = n.substring(0, 1)
      val lastName: String = n.substring(1)
      firstName + c + lastName
    }

    // 往sparksql的catalog中,注册函数名
    spark.udf.register("qiguai", func)


    // id,country,name,battle,age
    val res: DataFrame = spark.sql(
      """
        |select
        |  id,country,name,battle,age,qiguai(country,name) as new_name
        |from df
        |
        |""".stripMargin)

    res.show(100, false)


    spark.close()

  }

}

自定义UDF实战案例

package org.example
import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
import org.apache.spark.sql.{DataFrame, SparkSession}
object S11_SPARKQL的UDF自定义函数应用实战 {
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder()
      .appName("自定义函数demo")
      .master("local")
      .getOrCreate()

    val schema = StructType(Seq(
      StructField("id", DataTypes.IntegerType),
      StructField("f1", DataTypes.IntegerType),
      StructField("f2", DataTypes.IntegerType),
      StructField("f3", DataTypes.IntegerType),
      StructField("gender", DataTypes.StringType),
    ))

    val sample: DataFrame = spark.read.schema(schema).csv("datas/stu/input/sample.txt")


    val schema2 = StructType(Seq(
      StructField("id", DataTypes.StringType),
      StructField("f1", DataTypes.IntegerType),
      StructField("f2", DataTypes.IntegerType),
      StructField("f3", DataTypes.IntegerType)
    ))
    val test: DataFrame = spark.read.schema(schema2).csv("datas/stu/input/test.txt")



    sample.createTempView("sample")
    test.createTempView("test")


    // 首先写一个普通的scala函数
    //接收两个数组,返回一个距离
    val dist = (arr1: Array[Int], arr2: Array[Int]) => {
      //欧式距离
      //每个测试点与距离点的距离
      arr1.zip(arr2).map(tp => Math.pow(tp._1 - tp._2, 2)).sum
    }

    spark.udf.register("dist", dist)

    val distDf = spark.sql(
      """
        |
        |select
        |
        |sample.id as sample_id,
        |sample.gender as sample_gender,
        |test.id,
        |dist(array(sample.f1,sample.f2,sample.f3),array(test.f1,test.f2,test.f3)) as dist
        |
        |from sample cross join test
        |
        |

        |""".stripMargin)
    distDf.createTempView("dist_df")


    // TODO 距离算好,后续逻辑纯sql可以解决:
    // TODO 找到每个测试人距离最近的3个样本人,看这3个洋本人中,哪种性别最多,结果就是这种性别
    /**
     * dist_df
     * +---------+-------------+---+------+
     * |sample_id|sample_gender|id |dist  |
     * +---------+-------------+---+------+
     * |1        |m            |a  |221.0 |
     * |1        |m            |b  |874.0 |
     * |2        |m            |a  |46.0  |
     * |2        |m            |b  |1389.0|
     * |3        |m            |a  |264.0 |
     * |3        |m            |b  |973.0 |
     * |4        |f            |a  |1406.0|
     * |4        |f            |b  |59.0  |
     * |5        |f            |a  |1668.0|
     * |5        |f            |b  |21.0  |
     * |6        |f            |a  |2001.0|
     * |6        |f            |b  |4.0   |
     * +---------+-------------+---+------+
     */
    spark.sql(
      """
        |select
        |  sample_id,
        |  sample_gender,
        |  id,
        |  rn
        |from (
        |select
        |  sample_id,
        |  sample_gender,
        |  id,
        |  row_number() over(partition by id order by dist) as rn
        |from dist_df ) o
        |where rn <=3

        |
        |
        |""".stripMargin).createTempView("knn")
    /**
     * knn
     * +---------+-------------+---+---+
     * |sample_id|sample_gender|id |rn |
     * +---------+-------------+---+---+
     * |6        |f            |b  |1  |
     * |5        |f            |b  |2  |
     * |4        |m            |b  |3  |
     * |2        |m            |a  |1  |
     * |1        |f            |a  |2  |
     * |3        |m            |a  |3  |
     * +---------+-------------+---+---+
     */

    val res = spark.sql(
      """
        |select
        |   id,
        |   if(sum(if(sample_gender='f',0,1))>=2,'male','female') as gender
        |from knn
        |group by id
        |
        |""".stripMargin)

    res.show(100, false)


    spark.close()

  }

}

java开发spark快速上手

package javaspark;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.*;
import scala.Tuple2;
import java.util.Arrays;
import java.util.Iterator;

public class JavaWordCount {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf();
        conf.setMaster("local");
        conf.setAppName("java版wordcount");
        JavaSparkContext sc = new JavaSparkContext(conf);
        // 读数据
        JavaRDD<String> rdd = sc.textFile("F:\\代码区\\sparkp\\datas\\wordcount\\input\\a.txt");

        JavaRDD<String> rdd2 = rdd.flatMap(new FlatMapFunction<String, String>() {
            @Override
            public Iterator<String> call(String s) throws Exception {
                return Arrays.stream(s.split("\\s+")).iterator();
            }
        });
        JavaPairRDD<String, Integer> rdd3 = rdd2.mapToPair(new PairFunction<String, String, Integer>() {
            @Override
            public Tuple2<String, Integer> call(String word) throws Exception {
                return new Tuple2<>(word, 1);
            }
        });
        JavaPairRDD<String, Integer> res = rdd3.reduceByKey(new Function2<Integer, Integer, Integer>() {
            @Override
            public Integer call(Integer v1, Integer v2) throws Exception {
                return v1 + v2;
            }
        });
        res.foreach(new VoidFunction<Tuple2<String, Integer>>() {
            @Override
            public void call(Tuple2<String, Integer> stringIntegerTuple2) throws Exception {
                System.out.println(stringIntegerTuple2._1+","+stringIntegerTuple2._2);
            }
        });
//        res.foreach(new VoidFunction<Tuple2<String, Integer>>() {
//            @Override
//            public void call(Tuple2<String, Integer> stringIntegerTuple2) throws Exception {
//                System.out.println(stringIntegerTuple2._1+""+stringIntegerTuple2._2);
//            }
//        });
        sc.stop();

    }
}
package javaspark;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.api.java.function.VoidFunction;
import scala.Tuple2;

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
public class JavaWordCount2 {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf();
        conf.setMaster("local");
        conf.setAppName("java版wordcount");

        JavaSparkContext sc = new JavaSparkContext(conf);

        // 读数据
        JavaPairRDD<String, Integer> res = sc.textFile("F:\\代码区\\sparkp\\datas\\wordcount\\input\\a.txt")
                .flatMap(s -> Arrays.asList(s.split("\\s+")).iterator())
                .mapToPair(w -> new Tuple2<>(w, 1))
                .reduceByKey((v1, v2) -> v1 + v2);

        List<Tuple2<String, Integer>> lst = res.collect();
        System.out.println(lst);

        sc.stop();
    }
}

自定义UDFA工作逻辑

image.png

 UDAF - 弱类型

package org.example
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction, UserDefinedFunction}
import org.apache.spark.sql._
import org.apache.spark.sql.types.{DataType, LongType, StructField, StructType}
object SparkSQL04_弱UDAF {
  def main(args: Array[String]): Unit = {
    val sparkConf: SparkConf = new SparkConf().setAppName("SparkSQL").setMaster("local[*]")

    val spark: SparkSession =
      SparkSession.builder()
        //.appName("SparkSQL")
        //.master("local[*]")
        .config(sparkConf)
        .getOrCreate()
    val df: DataFrame = spark.read.json("datas/user.json")
    df.createOrReplaceTempView("user")
        spark.udf.register("ageAvg",new MyAvgAggregator())
    spark.sql("select ageAvg(age) from user").show()
    spark.close()

    //写代码不管用不用都导入。
    import spark.implicits._
  }
  class  MyAvgAggregator extends UserDefinedAggregateFunction{
    //输入数据结构
    override def inputSchema: StructType = {
      StructType(Array(
        StructField("age",LongType)
      ))
    }
    //缓存区数据的结构

    override def bufferSchema: StructType = {
      StructType(
        Array(
          StructField("total",LongType),
          StructField("count",LongType)
        )
      )
    }
    //输出:函数计算结果的数据类型
    override def dataType: DataType = LongType
    //函数的稳定性
    override def deterministic: Boolean = true
    //缓冲区初始化
    override def initialize(buffer: MutableAggregationBuffer): Unit =
      {
//        buffer(0)=0L
//        buffer(1)=0L
        buffer.update(0,0L)
        buffer.update(1,0L)
      }
    //根据输入的值更新缓冲区
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      buffer.update(0,buffer.getLong(0)+input.getLong(0))
      buffer.update(1,buffer.getLong(1)+1)
    }
    //合并
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1.update(0,buffer1.getLong(0)+buffer2.getLong(0))
      buffer1.update(1,buffer1.getLong(1)+buffer2.getLong(1))
    }
    //计算平均值
    override def evaluate(buffer: Row): Any = {
      buffer.getLong(0)/buffer.getLong(1)
    }
  }

}

UDAF - 强类型

package org.example
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql._
object SparkSQL04_UDAF {
  def main(args: Array[String]): Unit = {
    //新的起点: SparkSession
    //需求:计算平均工资

    val sparkConf: SparkConf = new SparkConf().setAppName("SparkSQL").setMaster("local[*]")

    val spark: SparkSession =
      SparkSession.builder()
        //.appName("SparkSQL")
        //.master("local[*]")
        .config(sparkConf)
        .getOrCreate()
    val df: DataFrame = spark.read.json("datas/user.json")
    df.createOrReplaceTempView("user")
    spark.udf.register("ageAvg",functions.udaf(new My_Avg_Aggregator()))
    spark.sql("select ageAvg(age) from user").show()

    //写代码不管用不用都导入。
//    import spark.implicits._
//
//    val df: DataFrame = spark.read.json("data/user.json")
//    val ds: Dataset[User] = df.as[User]
//
//    val my_Avg_Aggregator = new My_Avg_Aggregator
//    //将UDAF函数转化为查询的列对象
//    val column: TypedColumn[User, Double] = my_Avg_Aggregator.toColumn
//
//    ds.select(column).show()

    spark.stop()

  }

  case class User(id: Long, name: String, age: Long)

  case class AgeBuffer(var totalAge: Long, var totalCount: Long)

   /**
    * 自定义聚合函数类
   * 泛型:
   * IN: User
   * BUF: AgeBuffer
   * OUT: Double
   */
  class My_Avg_Aggregator extends Aggregator[Long, AgeBuffer, Long] {
     //初始值缓冲区的初始化
     override def zero: AgeBuffer = AgeBuffer(0L, 0L)

     //根据输入的数据更新缓冲区的数据
   

     override def reduce(b: AgeBuffer, a: Long): AgeBuffer = 
       {
         b.totalAge=b.totalAge+a
         b.totalCount=b.totalCount+1
         b
       }
       //合并缓冲区

     override def merge(b1: AgeBuffer, b2: AgeBuffer): AgeBuffer = {
       b1.totalAge=b1.totalAge+b2.totalAge
       b1.totalCount=b1.totalCount+b2.totalCount
       b1
     }
     //计算缓冲区

     override def finish(reduction: AgeBuffer): Long = {
       reduction.totalAge/reduction.totalCount
     }

     override def bufferEncoder: Encoder[AgeBuffer] = Encoders.product

     override def outputEncoder: Encoder[Long] = Encoders.scalaLong
   }

}

自定义UDFA实战案例

package org.example
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}
import util.BitMapUtil
object S11_SPARKQL的UDF自定义函数应用实战1 {
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder()
      .appName("自定义UDAF")
      .master("local")
      .config("spark.sql.shuffle.partitions", 2)
      .enableHiveSupport()
      .getOrCreate()


    // 加载待处理数据
    val schema = StructType(Seq(
      StructField("id", DataTypes.IntegerType),
      StructField("province", DataTypes.StringType),
      StructField("city", DataTypes.StringType),
      StructField("region", DataTypes.StringType),
      StructField("pv", DataTypes.IntegerType),
    ))
    val df: DataFrame = spark.read.schema(schema).csv("F:\\代码区\\sparkp\\datas\\input\\data.csv")
    df.createTempView("df")

    // 注册自定义UDAF
    import org.apache.spark.sql.functions.udaf
    spark.udf.register("gen_bitmap", udaf(BitMapGenUDA))
    spark.udf.register("merge_bitmap", udaf(BitMapOrMergeUDAF))

    val card = (bmBytes: Array[Byte]) => {
      BitMapUtil.deSerBitMap(bmBytes).getCardinality
    }
    spark.udf.register("card_bm", card)


    // 按省市区统计pv总数和uv总数并保存到hive中
    val pcrReport = spark.sql(
      """
        |
        |select
        |province,
        |city,
        |region,
        |sum(pv) as pv_amt,
        |card_bm(gen_bitmap(id)) as uv_cnt,
        |gen_bitmap(id) as bitmap
        |
        |from df
        |group by province,city,region
        |
        |
        |""".stripMargin)
    pcrReport.write.saveAsTable("pcr_report")


    // 读hive中省市区报表,聚合出省市报表
    spark.sql(
      """
        |
        |select
        |province,
        |city,
        |sum(pv_amt)  as pv_amt,
        |card_bm(merge_bitmap(bitmap)) as uv_cnt,
        |merge_bitmap(bitmap) as bitmap
        |
        |from pcr_report
        |group by province,city
        |
        |""".stripMargin).show(100, false)


    // 读hive中省市区报表,聚合出省报表
    spark.sql(
      """
        |
        |select
        |province,
        |sum(pv_amt)  as pv_amt,
        |card_bm(merge_bitmap(bitmap)) as uv_cnt,
        |merge_bitmap(bitmap) as bitmap
        |
        |from pcr_report
        |group by province
        |
        |""".stripMargin).show(100, false)


    spark.close()

  }

}

image.png

SparkSQL项目实战

我们这次 Spark-sql 操作中所有的数据均来自 Hive,首先在 Hive 中创建表,,并导入数据。

一共有3张表: 1张用户行为表,1张城市表,1 张产品表

image.png

image.png

image.png

一共有3张表: 1张用户行为表,1张城市表,1 张产品表
CREATE TABLE `user_visit_action`(
  `date` string,  
  `user_id` bigint,
  `session_id` string,
  `page_id` bigint,
  `action_time` string,
  `search_keyword` string,
  `click_category_id` bigint,
  `click_product_id` bigint,
  `order_category_ids` string,
  `order_product_ids` string,
  `pay_category_ids` string,
  `pay_product_ids` string,
  `city_id` bigint)
row format delimited fields terminated by '\t';
load data local inpath 'input/user_visit_action.txt' into table user_visit_action;

CREATE TABLE `product_info`(
  `product_id` bigint,
  `product_name` string,
  `extend_info` string)
row format delimited fields terminated by '\t';
load data local inpath 'input/product_info.txt' into table product_info;

CREATE TABLE `city_info`(
  `city_id` bigint,
  `city_name` string,
  `area` string)
row format delimited fields terminated by '\t';
load data local inpath 'input/city_info.txt' into table city_info;

需求:各区域热门商品 Top3

需求简介

这里的热门商品是从点击量的维度来看的,计算各个区域前三大热门商品,并备注上每个商品在主要城市中的分布比例,超过两个城市用其他显示。

例如:

地区

商品名称

点击次数

城市备注

华北

商品A

100000

北京21.2%,天津13.2%,其他65.6%

华北

商品P

80200

北京63.0%,太原10%,其他27.0%

华北

商品M

40000

北京63.0%,太原10%,其他27.0%

东北

商品J

92000

大连28%,辽宁17.0%,其他 55.0%


package org.example

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SparkSession

object SparkSQL_Req_2 {
  def main(args: Array[String]): Unit = {
    val sparkConf: SparkConf = new SparkConf().setAppName("SparkSql").setMaster("local[*]")
    val spark: SparkSession = SparkSession.builder()
      .config(sparkConf).
      enableHiveSupport().
      getOrCreate()
    val sc: SparkContext = spark.sparkContext
    //准备数据
    spark.sql(
      """
        |select
        |  t3.area,
        |  t3.product_name,
        |  t3.p_click_count,
        |  t3.rk
        |from
        |  (
        |select
        |  t2.area,
        |  t2.product_name,
        |  t2.p_click_count,
        |  rank() over( partition by t2.area order by t2.p_click_count desc ) rk
        |from
        |  (
        |select
        |  t1.area,
        |  t1.product_name,
        |  count(t1.click_product_id) p_click_count
        |
        |from
        |  (
        |select
        |    u.click_product_id ,
        |    p.product_name,
        |    c.city_name,
        |    c.area
        |from
        |   user_visit_action  u
        |join
        |   product_info p
        |on
        |   u.click_product_id = p.product_id
        |join
        |   city_info c
        |on
        |   u.city_id  = c.city_id
        |where
        |   u.click_product_id != -1
        |  )t1
        |group by t1.area , t1.product_name
        |  )t2
        |  )t3
        |where t3.rk <=3
          """.stripMargin).show()

    spark.stop()
    
  }

}

 需求分析

 查询出来所有的点击记录,并与 city_info 表连接,得到每个城市所在的地区,与 Product_info 表连接得到产品名称

 按照地区和商品 id 分组,统计出每个商品在每个地区的总点击次数

 每个地区内按照点击次数降序排列

 只取前三名

 城市备注需要自定义 UDAF 函数

package org.example
import org.apache.spark.SparkConf
import org.apache.spark.sql._
import org.apache.spark.sql.expressions.Aggregator

import scala.collection.mutable
import scala.collection.mutable.ListBuffer
object SparkSQL_Req_3 {
  def main(args: Array[String]): Unit = {
    System.setProperty("HADOOP_USER_NAME", "atguigu")

    //新的起点: SparkSession

    val sparkConf: SparkConf = new SparkConf().setAppName("SparkSQL").setMaster("local[*]")

    val spark: SparkSession =
      SparkSession.builder()
        //.appName("SparkSQL")
        //.master("local[*]")
        .config(sparkConf)
        .enableHiveSupport() // 启用hive的支持
        .getOrCreate()

    var sc = spark.sparkContext
    //写代码不管用不用都导入。

    spark.udf.register("cityMark", functions.udaf(new CityMarkAgg))

    spark.sql(
      """
        |select
        |  t3.area,
        |  t3.product_name,
        |  t3.p_click_count,
        |  t3.c_m,
        |  t3.rk
        |from
        |  (
        |select
        |  t2.area,
        |  t2.product_name,
        |  t2.p_click_count,
        |  t2.c_m ,
        |  rank() over( partition by t2.area order by t2.p_click_count desc ) rk
        |from
        |  (
        |select
        |  t1.area,
        |  t1.product_name,
        |  count(t1.click_product_id) p_click_count,
        |  cityMark(t1.city_name) c_m
        |
        |from
        |  (
        |select
        |    u.click_product_id ,
        |    p.product_name,
        |    c.city_name,
        |    c.area
        |from
        |   user_visit_action  u
        |join
        |   product_info p
        |on
        |   u.click_product_id = p.product_id
        |join
        |   city_info c
        |on
        |   u.city_id  = c.city_id
        |where
        |   u.click_product_id != -1
        |  )t1
        |group by t1.area , t1.product_name
        |  )t2
        |  )t3
        |where t3.rk <=3
         """.stripMargin).show(50, false)

    spark.stop()

  }

  case class CityBuffer(var toatlClick: Long, var cityMap: mutable.Map[String, Long])

  /**
   * 自定义函数(强类型)
   * 1. 继承Aggregator
   * 2. 确定泛型:
   * IN:  String
   * BUF: CityBuffer
   * OUT: String
   */
  class CityMarkAgg extends Aggregator[String, CityBuffer, String] {

    override def zero: CityBuffer = CityBuffer(0L, mutable.Map[String, Long]())

    override def reduce(buffer: CityBuffer, cityName: String): CityBuffer = {
      //总点击次数加1
      buffer.toatlClick += 1

      //城市点击次数加1
      val old: Long = buffer.cityMap.getOrElse(cityName, 0L)
      buffer.cityMap.put(cityName, old + 1)

      buffer
    }

    override def merge(b1: CityBuffer, b2: CityBuffer): CityBuffer = {
      //总点击次数
      b1.toatlClick += b2.toatlClick

      //城市的点击次数
      for ((cityName, cityCount) <- b2.cityMap) {
        val old: Long = b1.cityMap.getOrElse(cityName, 0L)
        b1.cityMap.put(cityName, old + cityCount)
      }
      b1
    }

    override def finish(buffer: CityBuffer): String = {

      val result: ListBuffer[String] = ListBuffer[String]()


      //总点击次数
      val totalClick: Long = buffer.toatlClick
      //城市点击次数
      val citys: List[(String, Long)] = buffer.cityMap.toList.sortBy(_._2)(Ordering.Long.reverse).take(2)

      var totalPer: Double = 100L

      //处理点击比例
      for ((cityName, clickCount) <- citys) {
        val per: Double = clickCount * 100 / totalClick.toDouble
        totalPer -= per
        val cityMark: String = cityName + " " + per + "%"
        result.append(cityMark)
      }

      //处理其他
      if (buffer.cityMap.size > 2) {
        result.append(s"其他 $totalPer%")
      }

      result.mkString(", ")

    }

    override def bufferEncoder: Encoder[CityBuffer] = Encoders.product

    override def outputEncoder: Encoder[String] = Encoders.STRING
  }

}
相关实践学习
如何在云端创建MySQL数据库
开始实验后,系统会自动创建一台自建MySQL的 源数据库 ECS 实例和一台 目标数据库 RDS。
全面了解阿里云能为你做什么
阿里云在全球各地部署高效节能的绿色数据中心,利用清洁计算为万物互联的新世界提供源源不断的能源动力,目前开服的区域包括中国(华北、华东、华南、香港)、新加坡、美国(美东、美西)、欧洲、中东、澳大利亚、日本。目前阿里云的产品涵盖弹性计算、数据库、存储与CDN、分析与搜索、云通信、网络、管理与监控、应用服务、互联网中间件、移动服务、视频服务等。通过本课程,来了解阿里云能够为你的业务带来哪些帮助 &nbsp; &nbsp; 相关的阿里云产品:云服务器ECS 云服务器 ECS(Elastic Compute Service)是一种弹性可伸缩的计算服务,助您降低 IT 成本,提升运维效率,使您更专注于核心业务创新。产品详情: https://www.aliyun.com/product/ecs
相关文章
|
2月前
|
存储 分布式计算 算法
大数据-106 Spark Graph X 计算学习 案例:1图的基本计算、2连通图算法、3寻找相同的用户
大数据-106 Spark Graph X 计算学习 案例:1图的基本计算、2连通图算法、3寻找相同的用户
68 0
|
2月前
|
分布式计算 算法 Spark
spark学习之 GraphX—预测社交圈子
spark学习之 GraphX—预测社交圈子
45 0
|
2月前
|
分布式计算 Scala Spark
educoder的spark算子学习
educoder的spark算子学习
19 0
|
2月前
|
存储 分布式计算 算法
大数据-105 Spark GraphX 基本概述 与 架构基础 概念详解 核心数据结构
大数据-105 Spark GraphX 基本概述 与 架构基础 概念详解 核心数据结构
54 0
|
2月前
|
消息中间件 分布式计算 Kafka
大数据-98 Spark 集群 Spark Streaming 基础概述 架构概念 执行流程 优缺点
大数据-98 Spark 集群 Spark Streaming 基础概述 架构概念 执行流程 优缺点
44 0
|
2月前
|
SQL 分布式计算 大数据
大数据-97 Spark 集群 SparkSQL 原理详细解析 Broadcast Shuffle SQL解析过程(一)
大数据-97 Spark 集群 SparkSQL 原理详细解析 Broadcast Shuffle SQL解析过程(一)
57 0
|
2月前
|
SQL 分布式计算 算法
大数据-97 Spark 集群 SparkSQL 原理详细解析 Broadcast Shuffle SQL解析过程(二)
大数据-97 Spark 集群 SparkSQL 原理详细解析 Broadcast Shuffle SQL解析过程(二)
85 0
|
2月前
|
SQL 分布式计算 Java
大数据-96 Spark 集群 SparkSQL Scala编写SQL操作SparkSQL的数据源:JSON、CSV、JDBC、Hive
大数据-96 Spark 集群 SparkSQL Scala编写SQL操作SparkSQL的数据源:JSON、CSV、JDBC、Hive
47 0
|
存储 分布式计算 Spark
Spark会把数据都载入到内存么?
这篇文章算是个科普贴。如果已经熟悉Spark的就略过吧。
1893 0
|
存储 分布式计算 Spark
Spark会把数据都载入到内存么?
前言         很多初学者其实对Spark的编程模式还是RDD这个概念理解不到位,就会产生一些误解。   比如,很多时候我们常常以为一个文件是会被完整读入到内存,然后做各种变换,这很可能是受两个概念的误导:   RDD的定义,RDD是一个分布式的不可变数据集合   Spark 是一个内
2451 0