提问者:小点点

如何在Scalatest中创建一个在测试套件之间持续存在的共享SparkSession夹具?


我对Scala和Scalatest很熟悉,但对Pyspark有一些经验,我正在尝试从Spark的角度学习Scala。

目前,我正在尝试在Scalatest中正确地设置和使用夹具。

我想象这种工作方式,这可能不是在 Scala 中完成的方式,我会将 SparkSession 设置为在测试套件之间共享的全局装置,然后可能将几个示例数据集连接到该 SparkSession,可用于具有多个测试等的单个测试套件。

目前,我有一些代码正在使用BeforeAndAfterAll特性使用共享夹具在同一套件中运行多个测试;但是,如果我同时运行多个套件,首先完成的套件似乎会终止SparkSession,并且任何进一步的测试都会失败,并显示java.lang.IllegalStateException:无法在停止的SparkContext上调用方法

所以,我想知道是否有一种方法可以创建SparkSession,以便只有在所有运行套件完成后才会停止它;或者如果我找错了地方,有一个更好的方法——正如我所说,我对Scala非常陌生,所以这可能不是你这样做的方式,在这种情况下,非常欢迎替代建议。

首先,我有一个包testSetup,我正在为SparkSession创建一个特征:

package com.example.test

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

import org.scalatest._
import org.scalatest.FixtureSuite
import org.scalatest.funsuite.FixtureAnyFunSuite

package testSetup {

 trait SparkSetup  {
    val spark = SparkSession
      .builder
      .master("local")
      .appName(getClass.getSimpleName.replace("$", ""))
      .getOrCreate()
      
    spark.sparkContext.setLogLevel("ERROR")
  }
    

然后在特质中使用它来建立一些样本数据:

 trait TestData extends SparkSetup {

    def data(): DataFrame = {

      val testDataStruct = StructType(List(
                              StructField("date", StringType, true),
                              StructField("period", StringType, true),
                              StructField("ID", IntegerType, true),
                              StructField("SomeText", StringType, true)))

      val testData = Seq(Row("01012020", "10:00", 20, "Some Text"))

      spark.createDataFrame(spark.sparkContext.parallelize(testData), testDataStruct)
      
    }
  }

然后,我将这些放在一起,通过 withFixture 运行测试,并使用 afterAll 关闭 SparkSession;这显然是有些地方不太对劲,但我不太确定是什么:

  trait DataFixture extends funsuite.FixtureAnyFunSuite with TestData with BeforeAndAfterAll { this: FixtureSuite =>

    type FixtureParam = DataFrame

    def withFixture(test: OneArgTest) = {

      super.withFixture(test.toNoArgTest(data())) // "loan" the fixture to the test
    }

    override def afterAll() {
      spark.close()
    }        
  }
}

我目前正在测试一个基本函数来动态散列DataFrame中的列,并可以选择排除一些列;这是代码:

package com.example.utilities

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions._

object GeneralTransforms {
    def addHashColumn(inputDataFrame: DataFrame, exclusionCols: List[String]): DataFrame = {
        
        val columnsToHash = inputDataFrame.columns.filterNot(exclusionCols.contains(_)).toList
        
        inputDataFrame.withColumn("RowHash", sha2(concat_ws("|", columnsToHash.map(col) : _*), 256))
    }
}

和当前的测试用例:-

import testSetup._
import com.example.utilities.GeneralTransforms._

import org.apache.spark.sql.DataFrame

class TestData extends funsuite.FixtureAnyFunSuite with DataFixture {
  test("Test data has correct columns") { inputData => 
    val cols = inputData.columns.toSeq
    val expectedCols = Array("date", "period", "ID", "SomeText").toSeq
    
    assert(cols == expectedCols)
  }
}

class TestAddHashColumn extends funsuite.FixtureAnyFunSuite with DataFixture {
  
  test("Test new hash column added") { inputData =>
    val hashedDf = addHashColumn(inputData, List())
    val initialCols = inputData.columns.toSeq
    val cols = hashedDf.columns.toSeq
    
    assert(initialCols.contains("RowHash") == false)
    assert(cols.contains("RowHash") == true)
  }

  test("Test all columns hashed - no exclusion") { inputData =>
    val hashedDf = addHashColumn(inputData, List())
    val rowHashColumn = hashedDf.select("RowHash").first().getString(0)
    val checkString = "01012020|10:00|20|Some Text"
    val expectedHash = String.format("%064x", new java.math.BigInteger(1, java.security.MessageDigest.getInstance("SHA-256").digest(checkString.getBytes("UTF-8"))))

    assert(rowHashColumn == expectedHash)
  }

  test("Test all columns hashed - with exclusion") { inputData =>

    val excludedColumns = List("ID", "SomeText")
    val hashedDf = addHashColumn(inputData,excludedColumns)
    val rowHashColumn = hashedDf.select("RowHash").first().getString(0)
    val checkString = "01012020|10:00"
    val expectedHash = String.format("%064x", new java.math.BigInteger(1, java.security.MessageDigest.getInstance("SHA-256").digest(checkString.getBytes("UTF-8"))))

    assert(rowHashColumn == expectedHash)

  }
}

这两个测试套件在隔离状态下都工作得非常好;只有当两者一起运行时,我才有问题。这也可以通过在我的build.sbt中添加< code > parallel execution in Test:= false 来解决,但是当我添加越来越多的测试时,如果能够允许这种情况并行发生就更好了。

我还想知道,这是不是可以通过在BeforeAll/AfterAll中运行一些检查SparkSession的其他实例的上下文来解决的,但我不确定如何做到这一点,我想在我进入另一个兔子洞之前先用尽这条路!

自从发帖以来,我在这上面花了更多的时间,并且做了一些改变,用一个助手类来处理Spark设置。在这里,我使用< code > spark session . builder . getor create 方法创建了一个伪主spark会话,然后为实际测试创建了一个新的spark会话——这将允许我拥有不同的配置,并执行不同的临时表注册等操作。然而,我仍然无法解决spark的关闭问题——显然,如果我对< code>SparkContext上的任何正在运行的会话运行spark.stop(),它将停止所有会话的上下文。

在sbt退出之前,上下文似乎没有停止?


共1个答案

匿名用户

当我试图解决我自己关于Spark和我为它编写的ScalaTest fixture的问题时,我偶然发现了这个问题,特别是我在我编写的类似于您的代码的fixture的< code>afterAll()中调用了< code>SparkSession.stop()。具体来说,是这样开始的堆栈跟踪:

java.lang.IllegalArgumentException: Error while instantiating 'org.apache.spark.sql.internal.SessionStateBuilder':
    at org.apache.spark.sql.SparkSession$.org$apache$spark$sql$SparkSession$$instantiateSessionState(SparkSession.scala:1178)
    at org.apache.spark.sql.SparkSession.$anonfun$sessionState$2(SparkSession.scala:162)
    at scala.Option.getOrElse(Option.scala:189)
    at org.apache.spark.sql.SparkSession.sessionState$lzycompute(SparkSession.scala:160)
    at org.apache.spark.sql.SparkSession.sessionState(SparkSession.scala:157)
    at org.apache.spark.sql.SparkSession$.conf$lzycompute$1(SparkSession.scala:1069)
    at org.apache.spark.sql.SparkSession$.conf$1(SparkSession.scala:1069)
    at org.apache.spark.sql.SparkSession$.applyModifiableSettings(SparkSession.scala:1072)
    at org.apache.spark.sql.SparkSession$Builder.getOrCreate(SparkSession.scala:942)
    at codes.lyndon.spark.test.SharedSparkSessionFixture.beforeAll(SharedSparkSessionFixture.scala:32)

通过查看代码,无法避免SparkSession.stop()函数调用停止父SparkContext并干扰其他会话(这是预期行为)。因此,最好不要在测试夹具中调用SparkSession.stop()。Spark为正在运行的JVM添加了一个用于安全关闭Spark上下文的关闭钩子,这样您的测试就不会挂起。

至于保留大量SparkSession实例的潜在问题,对实例的唯一直接引用是在SparkSession伴随对象中,并与当前运行的线程和默认会话相关。这些设置是在SparkSession.BuildergetOrCreate方法中设置的,或者由您自己手动设置。如果您担心,可以使用SparkSession.clearActiveSession()SparkSession.cclearDefaultSession()在夹具的afterAll()方法中明确清除这些引用。