Spark External Data Source Demo

Date: 2023-01-03 19:06:35

1. Creating the Relation

package com.spark.datasource.demo

import org.apache.spark.sql.sources._
import org.apache.spark.sql._
import org.apache.spark.sql.types._
import org.apache.spark.rdd.RDD
import java.sql.{ DriverManager, ResultSet }
import org.apache.spark.sql.{ Row, SQLContext }
import scala.collection.mutable.ArrayBuffer
import org.slf4j.LoggerFactory
import java.io._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import scala.collection.JavaConversions._


/**
 * Steps to implement a user-defined data source:
 *   1.1 Create a DefaultSource class that extends RelationProvider.
 *       The class name must be DefaultSource.
 *   1.2 Implement a user-defined Relation.
 *       A Relation can support four scanning strategies:
 *       <1> full table scan                -> extend TableScan
 *       <2> column pruning                 -> extend PrunedScan
 *       <3> column pruning + row filtering -> extend PrunedFilteredScan
 *       <4> CatalystScan
 *   1.3 Implement a user-defined RDD.
 *   1.4 Implement a user-defined RDD Partition.
 *   1.5 Implement a user-defined RDD Iterator.
 */
class DefaultSource extends RelationProvider
    with SchemaRelationProvider with CreatableRelationProvider {

  // Called for CREATE TEMPORARY TABLE ... USING when no schema is supplied.
  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
    createRelation(sqlContext, parameters, null)
  }

  // Called when the user supplies a schema explicitly.
  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String], schema: StructType): BaseRelation = {
    MyRelation(parameters, schema)(sqlContext)
  }

  // Called for DataFrame writes (CreatableRelationProvider); this demo ignores
  // the incoming data and simply returns a relation with the data's schema.
  override def createRelation(sqlContext: SQLContext, mode: SaveMode, parameters: Map[String, String], data: DataFrame): BaseRelation = {
    createRelation(sqlContext, parameters, data.schema)
  }
}
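
As an optional aside (not part of the original demo): a data source can also expose a short alias through the DataSourceRegister trait (available from org.apache.spark.sql.sources, already imported above). The class name and alias below are illustrative assumptions, and using the alias additionally requires listing the class in META-INF/services/org.apache.spark.sql.sources.DataSourceRegister inside the jar.

// Illustrative sketch only: reuses the DefaultSource above and adds a short alias.
class ShortNameSource extends DefaultSource with DataSourceRegister {
  // Hypothetical alias; with the service file in place, a query could say
  // USING sparkdemo instead of USING com.spark.datasource.demo.
  override def shortName(): String = "sparkdemo"
}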

case class MyRelation(@transient val parameters: Map[String, String], @transient userSchema: StructType)(@transient val sqlContext: SQLContext)
    extends BaseRelation with TableScan with PrunedScan with PrunedFilteredScan with Serializable {

  private val logger = LoggerFactory.getLogger(getClass)
  private val sparkContext = sqlContext.sparkContext

  def printStackTraceStr(e: Exception, data: String) = {
    val sw: StringWriter = new StringWriter()
    val pw: PrintWriter = new PrintWriter(sw)
    e.printStackTrace(pw)
    println("======>>printStackTraceStr Exception: " + e.getClass() + "\n==>" + sw.toString() + "\n==>data=" + data)
  }

  override def schema: StructType = {
    if (userSchema != null) {
      userSchema
    } else {
      // Fall back to a single-column schema when the user does not supply one.
      StructType(Seq(StructField("data", IntegerType)))
    }
  }

  override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
    logger.info("unhandledFilters with filters " + filters.toList)
    // If unhandled returns true for a filter, Spark applies that filter itself;
    // if it returns false, the data source must apply it inside buildScan.
    def unhandled(filter: Filter): Boolean = {
      filter match {
        case EqualTo(col, v) =>
          // Log the pushed-down EqualTo, but still let Spark re-apply it.
          println("EqualTo col is :" + col + " value is :" + v)
          true
        case _ => true
      }
    }
    filters.filter(unhandled)
  }

  override def buildScan(): RDD[Row] = {
    logger.info("TableScan buildScan")
    new MyRDD[Row](sparkContext)
  }

  override def buildScan(requiredColumns: Array[String]): RDD[Row] = {
    logger.info("PrunedScan buildScan for columns " + requiredColumns.toList)
    new MyRDD[Row](sparkContext)
  }

  override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
    logger.info("PrunedFilteredScan buildScan for columns " + requiredColumns.toList + " with filters " + filters.toList)
    new MyRDD[Row](sparkContext)
  }
}
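
For reference, the relation can also be exercised from the DataFrame API instead of SQL. The object below is a minimal usage sketch (its name and setup are assumptions, not part of the demo): load() goes through DefaultSource.createRelation, and the select/filter drive the PrunedScan and PrunedFilteredScan buildScan overloads.

package com.spark.datasource.demo

import org.apache.spark.{ SparkConf, SparkContext }
import org.apache.spark.sql.SQLContext

// Minimal usage sketch; assumes the packaged jar is on the application classpath.
object RelationUsageDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("datasource-demo"))
    val sqlContext = new SQLContext(sc)

    // load() calls DefaultSource.createRelation; the select/filter below
    // exercise the PrunedScan / PrunedFilteredScan code paths above.
    val df = sqlContext.read.format("com.spark.datasource.demo").load()
    df.select("data").filter(df("data") === 100000).show()

    sc.stop()
  }
}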

2. Creating the RDD and Partition

package com.spark.datasource.demo

import org.apache.hadoop.fs.Path
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.FileSplit
import org.apache.hadoop.mapred.Reporter
import org.apache.spark._
import org.apache.spark.rdd._
import scala.reflect.ClassTag
import org.apache.spark.sql.{ Row, SQLContext }
import org.slf4j.LoggerFactory
import org.apache.spark.sql.types._
import scala.collection.JavaConversions._

case class MyPartition(index: Int) extends Partition

class MyRDD[T: ClassTag](
    @transient private val _sc: SparkContext) extends RDD[T](_sc, Nil) {

  private val logger = LoggerFactory.getLogger(getClass)

  override def compute(split: Partition, context: TaskContext): Iterator[T] = {
    logger.warn("call MyRDD compute function ")
    val currSplit = split.asInstanceOf[MyPartition]
    new MyIterator(currSplit,context)
  }

  override protected def getPartitions: Array[Partition] = {
    logger.warn("call MyRDD getPartitions function ")
    // A single partition; its index must match its position in the array.
    val partitions = new Array[Partition](1)
    partitions(0) = MyPartition(0)
    partitions
  }

  override protected def getPreferredLocations(split: Partition): Seq[String] = {
    logger.warn("call MyRDD getPreferredLocations function")
    val currSplit = split.asInstanceOf[MyPartition]
    Seq("localhost")
  }
}
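
The demo hard-codes a single partition. As a hedged variation (the class name and numPartitions parameter below are assumptions, not part of the demo), the partition count could be made configurable, for example by forwarding a value from the relation's OPTIONS map:

// Illustrative variant of MyRDD with a configurable number of partitions.
class MyPartitionedRDD[T: ClassTag](
    @transient private val _sc: SparkContext,
    numPartitions: Int) extends RDD[T](_sc, Nil) {

  // Reuses the MyIterator class defined in the next section.
  override def compute(split: Partition, context: TaskContext): Iterator[T] =
    new MyIterator[T](split.asInstanceOf[MyPartition], context)

  // Each partition's index must equal its position in the returned array.
  override protected def getPartitions: Array[Partition] =
    (0 until numPartitions).map(i => MyPartition(i): Partition).toArray
}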

3. Creating the Iterator

package com.spark.datasource.demo

import org.apache.spark._
import org.apache.spark.rdd._
import scala.reflect.ClassTag
import org.apache.spark.sql.{ Row, SQLContext }
import org.slf4j.LoggerFactory
import org.apache.spark.sql.types._
import java.io._

class MyIterator[T: ClassTag](
    split: MyPartition,
    context: TaskContext) extends Iterator[T] {
  private val logger = LoggerFactory.getLogger(getClass)
  private var index = 0

  // Emit exactly one row per partition.
  override def hasNext: Boolean = {
    if (index == 1) {
      false
    } else {
      index += 1
      true
    }
  }
  override def next(): T = {
    val r = Row(100000)
    r.asInstanceOf[T]
  }
}
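
The iterator above returns exactly one row. As a hedged variation (MyRangeIterator and rowCount are illustrative assumptions), an iterator that emits several rows while still matching the single IntegerType column declared by MyRelation could look like this:

// Illustrative variant: emits rowCount rows per partition instead of one.
class MyRangeIterator[T: ClassTag](split: MyPartition, rowCount: Int)
    extends Iterator[T] {

  private var emitted = 0

  override def hasNext: Boolean = emitted < rowCount

  override def next(): T = {
    emitted += 1
    // Still a single IntegerType column, matching MyRelation.schema.
    Row(emitted).asInstanceOf[T]
  }
}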

4. Eclipse Screenshot

(screenshot: Eclipse project for the Spark external data source demo)

5. SBT Directory Structure

(screenshot: SBT project directory layout)

The build.sbt file:
name := "SparkDataSourceDemo"

version := "0.1"

organization := "com.spark.datasource.demo"

scalaVersion := "2.10.4"

libraryDependencies += "org.apache.spark" %% "spark-sql" % "1.6.0" % "provided"

resolvers += "Spark Staging Repository" at "https://repository.apache.org/content/repositories/orgapachespark-1038/"

publishMavenStyle := true

publishTo := {
  val nexus = "https://oss.sonatype.org/"
  if (version.value.endsWith("SNAPSHOT"))
    Some("snapshots" at nexus + "content/repositories/snapshots")
  else
    Some("releases"  at nexus + "service/local/staging/deploy/maven2")
}

6. SBT Packaging Command

  1. Run the following in the same directory as build.sbt:
    /usr/local/sbt/sbt package
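
  With the name and version defined in build.sbt above, sbt package should produce the jar at target/scala-2.10/sparkdatasourcedemo_2.10-0.1.jar (sbt lower-cases the project name for the artifact).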

7. Running the Test

1. Run the following in the same directory as build.sbt:
/usr/local/spark/bin/spark-sql --jars target/scala-2.10/sparkdatasourcedemo_2.10-0.1.jar

2. Create a temporary table backed by the data source, then query it:
CREATE TEMPORARY TABLE test USING com.spark.datasource.demo OPTIONS ();

select * from test;
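
Given the single partition and the one-row iterator above, the query should return a single row with data = 100000. The same check can be made from spark-shell; a minimal sketch, assuming the jar is passed with --jars as in step 1:

// spark-shell --jars target/scala-2.10/sparkdatasourcedemo_2.10-0.1.jar
val df = sqlContext.read.format("com.spark.datasource.demo").load()
df.show()   // expected: one row with 100000 in the data column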