From dc84843e2d681181173911e0df52fe9155b3cc12 Mon Sep 17 00:00:00 2001
From: wyy566 <531938832@qq.com>
Date: Fri, 25 Nov 2022 17:01:02 +0800
Subject: [PATCH] 1)updata the version number to 2.2.0 2)add new algorithm:
random forest classifier (RFC), gradient boosting decision tree (GBDT),
decision tree (DT), decision tree bucket(DTB) and Word2Vec.
---
README.md | 6 +-
ml-accelerator/pom.xml | 4 +-
.../DecisionTreeClassifier.scala | 327 ++
.../ml/classification/GBTClassifier.scala | 482 ++
.../RandomForestClassifier.scala | 567 ++
.../ml/feature/DecisionTreeBucketizer.scala | 440 ++
.../apache/spark/ml/feature/Word2Vec.scala | 463 ++
.../ml/regression/DecisionTreeRegressor.scala | 331 ++
.../spark/ml/regression/GBTRegressor.scala | 428 ++
.../ml/regression/RandomForestRegressor.scala | 359 ++
.../spark/ml/tree/impl/DecisionForest.scala | 1360 +++++
.../ml/tree/impl/DecisionTreeBucket.scala | 1361 +++++
.../ml/tree/impl/DecisionTreeMetadata.scala | 252 +
.../ml/tree/impl/GradientBoostedTrees.scala | 761 +++
.../spark/ml/tree/impl/RandomForest.scala | 1361 +++++
.../ml/tree/impl/RandomForest4GBDTX.scala | 689 +++
.../spark/ml/tree/impl/RandomForestRaw.scala | 1337 +++++
.../org/apache/spark/ml/tree/treeModels.scala | 561 ++
.../org/apache/spark/ml/tree/treeParams.scala | 626 ++
.../apache/spark/mllib/feature/Word2Vec.scala | 666 +++
.../spark/mllib/tree/DecisionTree.scala | 289 +
ml-core/pom.xml | 25 +-
.../main/java/dev/ludovic/netlib/BLAS.java | 240 +
.../dev/ludovic/netlib/InstanceBuilder.java | 77 +
.../java/dev/ludovic/netlib/JavaBLAS.java | 33 +
.../java/dev/ludovic/netlib/NativeBLAS.java | 33 +
.../dev/ludovic/netlib/blas/AbstractBLAS.java | 1689 ++++++
.../java/dev/ludovic/netlib/blas/F2jBLAS.java | 241 +
.../java/dev/ludovic/netlib/blas/JNIBLAS.java | 201 +
.../dev/ludovic/netlib/blas/Java8BLAS.java | 5157 +++++++++++++++++
.../scala/org/apache/spark/ml/tree/Node.scala | 480 ++
.../spark/ml/tree/impl/BaggedPoint.scala | 142 +
.../tree/impl/DTFeatureStatsAggregator.scala | 109 +
.../tree/impl/GradientBoostedTreesCore.scala | 255 +
.../spark/ml/tree/impl/TreePointX.scala | 41 +-
.../spark/ml/tree/impl/TreePointY.scala | 326 +-
.../spark/mllib/feature/VocabWord.scala | 28 +
ml-kernel-client-core/pom.xml | 4 +-
ml-kernel-client/pom.xml | 4 +-
.../ml/feature/DecisionTreeBucketizer.scala | 261 +
.../tree/impl/GradientBoostedTreesUtil.scala | 76 +
.../apache/spark/ml/tree/impl/RFUtils.scala | 7 +-
.../spark/mllib/feature/Word2VecSGHS.scala | 36 +
pom.xml | 2 +-
44 files changed, 21914 insertions(+), 223 deletions(-)
create mode 100644 ml-accelerator/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
create mode 100644 ml-accelerator/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
create mode 100644 ml-accelerator/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
create mode 100644 ml-accelerator/src/main/scala/org/apache/spark/ml/feature/DecisionTreeBucketizer.scala
create mode 100644 ml-accelerator/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
create mode 100644 ml-accelerator/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
create mode 100644 ml-accelerator/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
create mode 100644 ml-accelerator/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
create mode 100644 ml-accelerator/src/main/scala/org/apache/spark/ml/tree/impl/DecisionForest.scala
create mode 100644 ml-accelerator/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeBucket.scala
create mode 100644 ml-accelerator/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
create mode 100644 ml-accelerator/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
create mode 100644 ml-accelerator/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
create mode 100644 ml-accelerator/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest4GBDTX.scala
create mode 100644 ml-accelerator/src/main/scala/org/apache/spark/ml/tree/impl/RandomForestRaw.scala
create mode 100644 ml-accelerator/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
create mode 100644 ml-accelerator/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
create mode 100644 ml-accelerator/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
create mode 100644 ml-accelerator/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
create mode 100644 ml-core/src/main/java/dev/ludovic/netlib/BLAS.java
create mode 100644 ml-core/src/main/java/dev/ludovic/netlib/InstanceBuilder.java
create mode 100644 ml-core/src/main/java/dev/ludovic/netlib/JavaBLAS.java
create mode 100644 ml-core/src/main/java/dev/ludovic/netlib/NativeBLAS.java
create mode 100644 ml-core/src/main/java/dev/ludovic/netlib/blas/AbstractBLAS.java
create mode 100644 ml-core/src/main/java/dev/ludovic/netlib/blas/F2jBLAS.java
create mode 100644 ml-core/src/main/java/dev/ludovic/netlib/blas/JNIBLAS.java
create mode 100644 ml-core/src/main/java/dev/ludovic/netlib/blas/Java8BLAS.java
create mode 100644 ml-core/src/main/scala/org/apache/spark/ml/tree/Node.scala
create mode 100644 ml-core/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala
create mode 100644 ml-core/src/main/scala/org/apache/spark/ml/tree/impl/DTFeatureStatsAggregator.scala
create mode 100644 ml-core/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesCore.scala
create mode 100644 ml-core/src/main/scala/org/apache/spark/mllib/feature/VocabWord.scala
create mode 100644 ml-kernel-client/src/main/scala/org/apache/spark/ml/feature/DecisionTreeBucketizer.scala
create mode 100644 ml-kernel-client/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesUtil.scala
create mode 100644 ml-kernel-client/src/main/scala/org/apache/spark/mllib/feature/Word2VecSGHS.scala
diff --git a/README.md b/README.md
index 97493a2..eaf65d0 100644
--- a/README.md
+++ b/README.md
@@ -7,7 +7,7 @@ Introduction
The machine learning algorithm library running on Kunpeng processors is an acceleration library that provides a rich set of high-level tools for machine learning algorithms. It is based on the original APIs of Apache [Spark 3.1.1](https://github.com/apache/spark/tree/v3.1.1). The acceleration library for greatly improves the computing power in big data scenarios.
-The library provides 5 machine learning algorithms: latent dirichlet allocation (LDA), prefix-projected pattern prowth (Prefix-Span), alternating least squares (ALS), K-nearest neighbors (KNN), Density-based spatial clustering of applicaitons with noise (DBSCAN). You can find the latest documentation on the project web page. This README file contains only basic setup instructions.
+The library provides 10 machine learning algorithms: latent dirichlet allocation (LDA), prefix-projected pattern prowth (Prefix-Span), alternating least squares (ALS), K-nearest neighbors (KNN), Density-based spatial clustering of applicaitons with noise (DBSCAN), random forest classifier (RFC), gradient boosting decision tree (GBDT), decision tree (DT), decision tree bucket(DTB) and Word2Vec. You can find the latest documentation on the project web page. This README file contains only basic setup instructions.
You can find the latest documentation, including a programming guide, on the project web page. This README file only contains basic setup instructions.
@@ -21,9 +21,9 @@ Building And Packageing
mvn clean package
-(2) Obtain "boostkit-ml-core_2.12-2.1.0-spark3.1.1.jar" under the "Spark-ml-algo-lib/ml-core/target" directory.
+(2) Obtain "boostkit-ml-core_2.12-2.2.0-spark3.1.1.jar" under the "Spark-ml-algo-lib/ml-core/target" directory.
- Obtain "boostkit-ml-acc_2.12-2.1.0-spark3.1.1.jar" under the "Spark-ml-algo-lib/ml-accelerator/target" directory.
+ Obtain "boostkit-ml-acc_2.12-2.2.0-spark3.1.1.jar" under the "Spark-ml-algo-lib/ml-accelerator/target" directory.
Contribution Guidelines
diff --git a/ml-accelerator/pom.xml b/ml-accelerator/pom.xml
index 0af611b..800e7e1 100644
--- a/ml-accelerator/pom.xml
+++ b/ml-accelerator/pom.xml
@@ -2,12 +2,12 @@
org.apache.spark
boostkit-ml
- 2.1.0
+ 2.2.0
4.0.0
boostkit-ml-acc_2.12
- 2.1.0
+ 2.2.0
${project.artifactId}
Spark ml algo accelerator
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/ml-accelerator/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
new file mode 100644
index 0000000..9c22cb1
--- /dev/null
+++ b/ml-accelerator/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -0,0 +1,327 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.apache.hadoop.fs.Path
+import org.json4s.{DefaultFormats, JObject}
+import org.json4s.JsonDSL._
+
+import org.apache.spark.annotation.Since
+import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.tree._
+import org.apache.spark.ml.tree.{DecisionTreeModel, Node, TreeClassifierParams}
+import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._
+import org.apache.spark.ml.tree.impl.{DecisionForest, RandomForest}
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.Instrumentation.instrumented
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Decision tree learning algorithm (http://en.wikipedia.org/wiki/Decision_tree_learning)
+ * for classification.
+ * It supports both binary and multiclass labels, as well as both continuous and categorical
+ * features.
+ */
+@Since("1.4.0")
+class DecisionTreeClassifier @Since("1.4.0") (
+ @Since("1.4.0") override val uid: String)
+ extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
+ with DecisionTreeClassifierParams with DefaultParamsWritable {
+
+ @Since("1.4.0")
+ def this() = this(Identifiable.randomUID("dtc"))
+
+ // Override parameter setters from parent trait for Java API compatibility.
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setMaxDepth(value: Int): this.type = set(maxDepth, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setMaxBins(value: Int): this.type = set(maxBins, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
+
+ /** @group setParam */
+ @Since("3.0.0")
+ def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value)
+
+ @Since("1.4.0")
+ def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
+
+ /** @group expertSetParam */
+ @Since("1.4.0")
+ def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)
+
+ /** @group expertSetParam */
+ @Since("1.4.0")
+ def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)
+
+ /**
+ * Specifies how often to checkpoint the cached node IDs.
+ * E.g. 10 means that the cache will get checkpointed every 10 iterations.
+ * This is only used if cacheNodeIds is true and if the checkpoint directory is set in
+ * [[org.apache.spark.SparkContext]].
+ * Must be at least 1.
+ * (default = 10)
+ * @group setParam
+ */
+ @Since("1.4.0")
+ def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setImpurity(value: String): this.type = set(impurity, value)
+
+ /** @group setParam */
+ @Since("1.6.0")
+ def setSeed(value: Long): this.type = set(seed, value)
+
+ /**
+ * Sets the value of param [[weightCol]].
+ * If this is not set or empty, we treat all instance weights as 1.0.
+ * Default is not set, so all instances have weight one.
+ *
+ * @group setParam
+ */
+ @Since("3.0.0")
+ def setWeightCol(value: String): this.type = set(weightCol, value)
+
+ override protected def train(
+ dataset: Dataset[_]): DecisionTreeClassificationModel = instrumented { instr =>
+ instr.logPipelineStage(this)
+ instr.logDataset(dataset)
+ val categoricalFeatures = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
+ val numClasses = getNumClasses(dataset)
+
+ if (isDefined(thresholds)) {
+ require($(thresholds).length == numClasses, this.getClass.getSimpleName +
+ ".train() called with non-matching numClasses and thresholds.length." +
+ s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
+ }
+ validateNumClasses(numClasses)
+ val instances = extractInstances(dataset, numClasses)
+ val strategy = getOldStrategy(categoricalFeatures, numClasses)
+ require(!strategy.bootstrap, "DecisionTreeClassifier does not need bootstrap sampling")
+ instr.logNumClasses(numClasses)
+ instr.logParams(this, labelCol, featuresCol, predictionCol, rawPredictionCol,
+ probabilityCol, leafCol, maxDepth, maxBins, minInstancesPerNode, minInfoGain,
+ maxMemoryInMB, cacheNodeIds, checkpointInterval, impurity, seed, thresholds)
+
+ val trees = DecisionForest.run(instances, strategy, numTrees = 1, featureSubsetStrategy = "all",
+ seed = $(seed), instr = Some(instr), parentUID = Some(uid))
+
+ trees.head.asInstanceOf[DecisionTreeClassificationModel]
+ }
+
+ /** (private[ml]) Create a Strategy instance to use with the old API. */
+ private[ml] def getOldStrategy(
+ categoricalFeatures: Map[Int, Int],
+ numClasses: Int): OldStrategy = {
+ super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity,
+ subsamplingRate = 1.0)
+ }
+
+ @Since("1.4.1")
+ override def copy(extra: ParamMap): DecisionTreeClassifier = defaultCopy(extra)
+}
+
+@Since("1.4.0")
+object DecisionTreeClassifier extends DefaultParamsReadable[DecisionTreeClassifier] {
+ /** Accessor for supported impurities: entropy, gini */
+ @Since("1.4.0")
+ final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
+
+ @Since("2.0.0")
+ override def load(path: String): DecisionTreeClassifier = super.load(path)
+}
+
+/**
+ * Decision tree model (http://en.wikipedia.org/wiki/Decision_tree_learning) for classification.
+ * It supports both binary and multiclass labels, as well as both continuous and categorical
+ * features.
+ */
+@Since("1.4.0")
+class DecisionTreeClassificationModel private[ml] (
+ @Since("1.4.0")override val uid: String,
+ @Since("1.4.0")override val rootNode: Node,
+ @Since("1.6.0")override val numFeatures: Int,
+ @Since("1.5.0")override val numClasses: Int)
+ extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel]
+ with DecisionTreeModel with DecisionTreeClassifierParams with MLWritable with Serializable {
+
+ require(rootNode != null,
+ "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.")
+
+ /**
+ * Construct a decision tree classification model.
+ *
+ * @param rootNode Root node of tree, with other nodes attached.
+ */
+ private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) =
+ this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses)
+
+ override def predict(features: Vector): Double = {
+ rootNode.predictImpl(features).prediction
+ }
+
+ @Since("3.0.0")
+ override def transformSchema(schema: StructType): StructType = {
+ var outputSchema = super.transformSchema(schema)
+ if ($(leafCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateField(outputSchema, getLeafField($(leafCol)))
+ }
+ outputSchema
+ }
+
+ override def transform(dataset: Dataset[_]): DataFrame = {
+ val outputSchema = transformSchema(dataset.schema, logging = true)
+
+ val outputData = super.transform(dataset)
+ if ($(leafCol).nonEmpty) {
+ val leafUDF = udf { features: Vector => predictLeaf(features) }
+ outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))),
+ outputSchema($(leafCol)).metadata)
+ } else {
+ outputData
+ }
+ }
+
+ @Since("3.0.0")
+ override def predictRaw(features: Vector): Vector = {
+ Vectors.dense(rootNode.predictImpl(features).impurityStats.stats.clone())
+ }
+
+ override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
+ rawPrediction match {
+ case dv: DenseVector =>
+ ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(dv)
+ dv
+ case sv: SparseVector =>
+ throw new RuntimeException("Unexpected error in DecisionTreeClassificationModel:" +
+ " raw2probabilityInPlace encountered SparseVector")
+ }
+ }
+
+ @Since("1.4.0")
+ override def copy(extra: ParamMap): DecisionTreeClassificationModel = {
+ copyValues(new DecisionTreeClassificationModel(uid, rootNode, numFeatures, numClasses), extra)
+ .setParent(parent)
+ }
+
+ @Since("1.4.0")
+ override def toString: String = {
+ s"DecisionTreeClassificationModel: uid=$uid, depth=$depth, numNodes=$numNodes, " +
+ s"numClasses=$numClasses, numFeatures=$numFeatures"
+ }
+
+ /**
+ * Estimate of the importance of each feature.
+ *
+ * This generalizes the idea of "Gini" importance to other losses,
+ * following the explanation of Gini importance from "Random Forests" documentation
+ * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ *
+ * This feature importance is calculated as follows:
+ * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
+ * where gain is scaled by the number of instances passing through node
+ * - Normalize importances for tree to sum to 1.
+ *
+ * @note Feature importance for single decision trees can have high variance due to
+ * correlated predictor variables. Consider using a [[RandomForestClassifier]]
+ * to determine feature importance instead.
+ */
+ @Since("2.0.0")
+ lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(this, numFeatures)
+
+ /** Convert to spark.mllib DecisionTreeModel (losing some information) */
+ override private[spark] def toOld: OldDecisionTreeModel = {
+ new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Classification)
+ }
+
+ @Since("2.0.0")
+ override def write: MLWriter =
+ new DecisionTreeClassificationModel.DecisionTreeClassificationModelWriter(this)
+}
+
+@Since("2.0.0")
+object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassificationModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[DecisionTreeClassificationModel] =
+ new DecisionTreeClassificationModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): DecisionTreeClassificationModel = super.load(path)
+
+ private[DecisionTreeClassificationModel]
+ class DecisionTreeClassificationModelWriter(instance: DecisionTreeClassificationModel)
+ extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ val extraMetadata: JObject = Map(
+ "numFeatures" -> instance.numFeatures,
+ "numClasses" -> instance.numClasses)
+ DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
+ val (nodeData, _) = NodeData.build(instance.rootNode, 0)
+ val dataPath = new Path(path, "data").toString
+ sparkSession.createDataFrame(nodeData).write.parquet(dataPath)
+ }
+ }
+
+ private class DecisionTreeClassificationModelReader
+ extends MLReader[DecisionTreeClassificationModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[DecisionTreeClassificationModel].getName
+
+ override def load(path: String): DecisionTreeClassificationModel = {
+ implicit val format = DefaultFormats
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
+ val numClasses = (metadata.metadata \ "numClasses").extract[Int]
+ val root = loadTreeNodes(path, metadata, sparkSession)
+ val model = new DecisionTreeClassificationModel(metadata.uid, root, numFeatures, numClasses)
+ metadata.getAndSetParams(model)
+ model
+ }
+ }
+
+ /** Convert a model from the old API */
+ private[ml] def fromOld(
+ oldModel: OldDecisionTreeModel,
+ parent: DecisionTreeClassifier,
+ categoricalFeatures: Map[Int, Int],
+ numFeatures: Int = -1): DecisionTreeClassificationModel = {
+ require(oldModel.algo == OldAlgo.Classification,
+ s"Cannot convert non-classification DecisionTreeModel (old API) to" +
+ s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}")
+ val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
+ val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc")
+ // Can't infer number of features from old model, so default to -1
+ new DecisionTreeClassificationModel(uid, rootNode, numFeatures, -1)
+ }
+}
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/ml-accelerator/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
new file mode 100644
index 0000000..37d386b
--- /dev/null
+++ b/ml-accelerator/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -0,0 +1,482 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import com.github.fommil.netlib.BLAS.{getInstance => blas}
+import org.json4s.{DefaultFormats, JObject}
+import org.json4s.JsonDSL._
+
+import org.apache.spark.annotation.Since
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.regression.DecisionTreeRegressionModel
+import org.apache.spark.ml.tree._
+import org.apache.spark.ml.tree.impl.GradientBoostedTrees
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
+import org.apache.spark.ml.util.Instrumentation.instrumented
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
+import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Gradient-Boosted Trees (GBTs) (http://en.wikipedia.org/wiki/Gradient_boosting)
+ * learning algorithm for classification.
+ * It supports binary labels, as well as both continuous and categorical features.
+ *
+ * The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999.
+ *
+ * Notes on Gradient Boosting vs. TreeBoost:
+ * - This implementation is for Stochastic Gradient Boosting, not for TreeBoost.
+ * - Both algorithms learn tree ensembles by minimizing loss functions.
+ * - TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes
+ * based on the loss function, whereas the original gradient boosting method does not.
+ * - We expect to implement TreeBoost in the future:
+ * [https://issues.apache.org/jira/browse/SPARK-4240]
+ *
+ * @note Multiclass labels are not currently supported.
+ */
+@Since("1.4.0")
+class GBTClassifier @Since("1.4.0") (
+ @Since("1.4.0") override val uid: String)
+ extends ProbabilisticClassifier[Vector, GBTClassifier, GBTClassificationModel]
+ with GBTClassifierParams with DefaultParamsWritable with Logging {
+
+ @Since("1.4.0")
+ def this() = this(Identifiable.randomUID("gbtc"))
+
+ // Override parameter setters from parent trait for Java API compatibility.
+
+ // Parameters from TreeClassifierParams:
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setMaxDepth(value: Int): this.type = set(maxDepth, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setMaxBins(value: Int): this.type = set(maxBins, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
+
+ /** @group setParam */
+ @Since("3.0.0")
+ def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
+
+ /** @group expertSetParam */
+ @Since("1.4.0")
+ def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)
+
+ /** @group expertSetParam */
+ @Since("1.4.0")
+ def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)
+
+ /**
+ * Specifies how often to checkpoint the cached node IDs.
+ * E.g. 10 means that the cache will get checkpointed every 10 iterations.
+ * This is only used if cacheNodeIds is true and if the checkpoint directory is set in
+ * [[org.apache.spark.SparkContext]].
+ * Must be at least 1.
+ * (default = 10)
+ * @group setParam
+ */
+ @Since("1.4.0")
+ def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
+
+ /**
+ * The impurity setting is ignored for GBT models.
+ * Individual trees are built using impurity "Variance."
+ *
+ * @group setParam
+ */
+ @Since("1.4.0")
+ def setImpurity(value: String): this.type = {
+ logWarning("GBTClassifier.setImpurity should NOT be used")
+ this
+ }
+
+ // Parameters from TreeEnsembleParams:
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setSeed(value: Long): this.type = set(seed, value)
+
+ // Parameters from GBTParams:
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setMaxIter(value: Int): this.type = set(maxIter, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setStepSize(value: Double): this.type = set(stepSize, value)
+
+ /** @group setParam */
+ @Since("2.3.0")
+ def setFeatureSubsetStrategy(value: String): this.type =
+ set(featureSubsetStrategy, value)
+
+ // Parameters from GBTClassifierParams:
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setLossType(value: String): this.type = set(lossType, value)
+
+ /** @group setParam */
+ @Since("2.4.0")
+ def setValidationIndicatorCol(value: String): this.type = {
+ set(validationIndicatorCol, value)
+ }
+
+ /**
+ * Sets the value of param [[weightCol]].
+ * If this is not set or empty, we treat all instance weights as 1.0.
+ * By default the weightCol is not set, so all instances have weight 1.0.
+ *
+ * @group setParam
+ */
+ @Since("3.0.0")
+ def setWeightCol(value: String): this.type = set(weightCol, value)
+
+ override protected def train(
+ dataset: Dataset[_]): GBTClassificationModel = instrumented { instr =>
+ val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty
+
+ val validateInstance = (instance: Instance) => {
+ val label = instance.label
+ require(label == 0 || label == 1, s"GBTClassifier was given" +
+ s" dataset with invalid label $label. Labels must be in {0,1}; note that" +
+ s" GBTClassifier currently only supports binary classification.")
+ }
+
+ val (trainDataset, validationDataset) = if (withValidation) {
+ (extractInstances(dataset.filter(not(col($(validationIndicatorCol)))), validateInstance),
+ extractInstances(dataset.filter(col($(validationIndicatorCol))), validateInstance))
+ } else {
+ (extractInstances(dataset, validateInstance), null)
+ }
+
+ val numClasses = 2
+ if (isDefined(thresholds)) {
+ require($(thresholds).length == numClasses, this.getClass.getSimpleName +
+ ".train() called with non-matching numClasses and thresholds.length." +
+ s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
+ }
+
+ instr.logPipelineStage(this)
+ instr.logDataset(dataset)
+ instr.logParams(this, labelCol, weightCol, featuresCol, predictionCol, leafCol,
+ impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain,
+ minInstancesPerNode, minWeightFractionPerNode, seed, stepSize, subsamplingRate, cacheNodeIds,
+ checkpointInterval, featureSubsetStrategy, validationIndicatorCol, validationTol, thresholds)
+ instr.logNumClasses(numClasses)
+
+ val categoricalFeatures = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
+ val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
+ val (doUseAcc, setUseAccFlag) = super.getDoUseAcc
+ val (baseLearners, learnerWeights) = if (withValidation) {
+ if (setUseAccFlag) {
+ GradientBoostedTrees.runWithValidationX(trainDataset, validationDataset, boostingStrategy,
+ $(seed), $(featureSubsetStrategy), doUseAcc, Some(instr))
+ } else {
+ GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy,
+ $(seed), $(featureSubsetStrategy), Some(instr))
+ }
+ } else {
+ if (setUseAccFlag) {
+ GradientBoostedTrees.runX(trainDataset, boostingStrategy, $(seed), $(featureSubsetStrategy),
+ doUseAcc, Some(instr))
+ } else {
+ GradientBoostedTrees.run(trainDataset, boostingStrategy, $(seed), $(featureSubsetStrategy),
+ Some(instr))
+ }
+ }
+ baseLearners.foreach(copyValues(_))
+
+ val numFeatures = baseLearners.head.numFeatures
+ instr.logNumFeatures(numFeatures)
+
+ new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures)
+ }
+
+ @Since("1.4.1")
+ override def copy(extra: ParamMap): GBTClassifier = defaultCopy(extra)
+}
+
+@Since("1.4.0")
+object GBTClassifier extends DefaultParamsReadable[GBTClassifier] {
+
+ /** Accessor for supported loss settings: logistic */
+ @Since("1.4.0")
+ final val supportedLossTypes: Array[String] = GBTClassifierParams.supportedLossTypes
+
+ @Since("2.0.0")
+ override def load(path: String): GBTClassifier = super.load(path)
+}
+
+/**
+ * Gradient-Boosted Trees (GBTs) (http://en.wikipedia.org/wiki/Gradient_boosting)
+ * model for classification.
+ * It supports binary labels, as well as both continuous and categorical features.
+ *
+ * @param _trees Decision trees in the ensemble.
+ * @param _treeWeights Weights for the decision trees in the ensemble.
+ *
+ * @note Multiclass labels are not currently supported.
+ */
+@Since("1.6.0")
+class GBTClassificationModel private[ml](
+ @Since("1.6.0") override val uid: String,
+ private val _trees: Array[DecisionTreeRegressionModel],
+ private val _treeWeights: Array[Double],
+ @Since("1.6.0") override val numFeatures: Int,
+ @Since("2.2.0") override val numClasses: Int)
+ extends ProbabilisticClassificationModel[Vector, GBTClassificationModel]
+ with GBTClassifierParams with TreeEnsembleModel[DecisionTreeRegressionModel]
+ with MLWritable with Serializable {
+
+ require(_trees.nonEmpty, "GBTClassificationModel requires at least 1 tree.")
+ require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" +
+ s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")
+
+ /**
+ * Construct a GBTClassificationModel
+ *
+ * @param _trees Decision trees in the ensemble.
+ * @param _treeWeights Weights for the decision trees in the ensemble.
+ * @param numFeatures The number of features.
+ */
+ private[ml] def this(
+ uid: String,
+ _trees: Array[DecisionTreeRegressionModel],
+ _treeWeights: Array[Double],
+ numFeatures: Int) =
+ this(uid, _trees, _treeWeights, numFeatures, 2)
+
+ /**
+ * Construct a GBTClassificationModel
+ *
+ * @param _trees Decision trees in the ensemble.
+ * @param _treeWeights Weights for the decision trees in the ensemble.
+ */
+ @Since("1.6.0")
+ def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) =
+ this(uid, _trees, _treeWeights, -1, 2)
+
+ @Since("1.4.0")
+ override def trees: Array[DecisionTreeRegressionModel] = _trees
+
+ /**
+ * Number of trees in ensemble
+ */
+ @Since("2.0.0")
+ val getNumTrees: Int = trees.length
+
+ @Since("1.4.0")
+ override def treeWeights: Array[Double] = _treeWeights
+
+ @Since("1.6.0")
+ override def transformSchema(schema: StructType): StructType = {
+ var outputSchema = super.transformSchema(schema)
+ if ($(leafCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateField(outputSchema, getLeafField($(leafCol)))
+ }
+ outputSchema
+ }
+
+ override def transform(dataset: Dataset[_]): DataFrame = {
+ val outputSchema = transformSchema(dataset.schema, logging = true)
+
+ val outputData = super.transform(dataset)
+ if ($(leafCol).nonEmpty) {
+ val leafUDF = udf { features: Vector => predictLeaf(features) }
+ outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))),
+ outputSchema($(leafCol)).metadata)
+ } else {
+ outputData
+ }
+ }
+
+ override def predict(features: Vector): Double = {
+ // If thresholds defined, use predictRaw to get probabilities, otherwise use optimization
+ if (isDefined(thresholds)) {
+ super.predict(features)
+ } else {
+ if (margin(features) > 0.0) 1.0 else 0.0
+ }
+ }
+
+ @Since("3.0.0")
+ override def predictRaw(features: Vector): Vector = {
+ val prediction: Double = margin(features)
+ Vectors.dense(Array(-prediction, prediction))
+ }
+
+ override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
+ rawPrediction match {
+ case dv: DenseVector =>
+ dv.values(0) = loss.computeProbability(dv.values(0))
+ dv.values(1) = 1.0 - dv.values(0)
+ dv
+ case sv: SparseVector =>
+ throw new RuntimeException("Unexpected error in GBTClassificationModel:" +
+ " raw2probabilityInPlace encountered SparseVector")
+ }
+ }
+
+ @Since("1.4.0")
+ override def copy(extra: ParamMap): GBTClassificationModel = {
+ copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures, numClasses),
+ extra).setParent(parent)
+ }
+
+ @Since("1.4.0")
+ override def toString: String = {
+ s"GBTClassificationModel: uid = $uid, numTrees=$getNumTrees, numClasses=$numClasses, " +
+ s"numFeatures=$numFeatures"
+ }
+
+ /**
+ * Estimate of the importance of each feature.
+ *
+ * Each feature's importance is the average of its importance across all trees in the ensemble
+ * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+ * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+ * and follows the implementation from scikit-learn.
+ *
+ * See `DecisionTreeClassificationModel.featureImportances`
+ */
+ @Since("2.0.0")
+ lazy val featureImportances: Vector =
+ TreeEnsembleModel.featureImportances(trees, numFeatures, perTreeNormalization = false)
+
+ /** Raw prediction for the positive class. */
+ private def margin(features: Vector): Double = {
+ val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
+ blas.ddot(getNumTrees, treePredictions, 1, _treeWeights, 1)
+ }
+
+ /** (private[ml]) Convert to a model in the old API */
+ private[ml] def toOld: OldGBTModel = {
+ new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights)
+ }
+
+ // hard coded loss, which is not meant to be changed in the model
+ private val loss = getOldLossType
+
+ /**
+ * Method to compute error or loss for every iteration of gradient boosting.
+ *
+ * @param dataset Dataset for validation.
+ */
+ @Since("2.4.0")
+ def evaluateEachIteration(dataset: Dataset[_]): Array[Double] = {
+ val data = extractInstances(dataset)
+ GradientBoostedTrees.evaluateEachIteration(data, trees, treeWeights, loss,
+ OldAlgo.Classification)
+ }
+
+ @Since("2.0.0")
+ override def write: MLWriter = new GBTClassificationModel.GBTClassificationModelWriter(this)
+}
+
+@Since("2.0.0")
+object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
+
+ private val numFeaturesKey: String = "numFeatures"
+ private val numTreesKey: String = "numTrees"
+
+ @Since("2.0.0")
+ override def read: MLReader[GBTClassificationModel] = new GBTClassificationModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): GBTClassificationModel = super.load(path)
+
+ private[GBTClassificationModel]
+ class GBTClassificationModelWriter(instance: GBTClassificationModel) extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+
+ val extraMetadata: JObject = Map(
+ numFeaturesKey -> instance.numFeatures,
+ numTreesKey -> instance.getNumTrees)
+ EnsembleModelReadWrite.saveImpl(instance, path, sparkSession, extraMetadata)
+ }
+ }
+
+ private class GBTClassificationModelReader extends MLReader[GBTClassificationModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[GBTClassificationModel].getName
+ private val treeClassName = classOf[DecisionTreeRegressionModel].getName
+
+ override def load(path: String): GBTClassificationModel = {
+ implicit val format = DefaultFormats
+ val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
+ EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
+ val numFeatures = (metadata.metadata \ numFeaturesKey).extract[Int]
+ val numTrees = (metadata.metadata \ numTreesKey).extract[Int]
+
+ val trees = treesData.map {
+ case (treeMetadata, root) =>
+ val tree = new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
+ treeMetadata.getAndSetParams(tree)
+ tree
+ }
+ require(numTrees == trees.length, s"GBTClassificationModel.load expected $numTrees" +
+ s" trees based on metadata but found ${trees.length} trees.")
+ val model = new GBTClassificationModel(metadata.uid,
+ trees, treeWeights, numFeatures)
+ // We ignore the impurity while loading models because in previous models it was wrongly
+ // set to gini (see SPARK-25959).
+ metadata.getAndSetParams(model, Some(List("impurity")))
+ model
+ }
+ }
+
+ /** Convert a model from the old API */
+ private[ml] def fromOld(
+ oldModel: OldGBTModel,
+ parent: GBTClassifier,
+ categoricalFeatures: Map[Int, Int],
+ numFeatures: Int = -1,
+ numClasses: Int = 2): GBTClassificationModel = {
+ require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" +
+ s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).")
+ val newTrees = oldModel.trees.map { tree =>
+ // parent for each tree is null since there is no good way to set this.
+ DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
+ }
+ val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc")
+ new GBTClassificationModel(uid, newTrees, oldModel.treeWeights, numFeatures, numClasses)
+ }
+}
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/ml-accelerator/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
new file mode 100644
index 0000000..f9ce62b
--- /dev/null
+++ b/ml-accelerator/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -0,0 +1,567 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.json4s.{DefaultFormats, JObject}
+import org.json4s.JsonDSL._
+
+import org.apache.spark.annotation.Since
+import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.tree._
+import org.apache.spark.ml.tree.{TreeClassifierParams, TreeEnsembleModel}
+import org.apache.spark.ml.tree.impl.RandomForest
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
+import org.apache.spark.ml.util.Instrumentation.instrumented
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
+import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Random Forest learning algorithm for
+ * classification.
+ * It supports both binary and multiclass labels, as well as both continuous and categorical
+ * features.
+ */
+@Since("1.4.0")
+class RandomForestClassifier @Since("1.4.0") (
+ @Since("1.4.0") override val uid: String)
+ extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
+ with RandomForestClassifierParams with DefaultParamsWritable {
+
+ @Since("1.4.0")
+ def this() = this(Identifiable.randomUID("rfc"))
+
+ // Override parameter setters from parent trait for Java API compatibility.
+
+ // Parameters from TreeClassifierParams:
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setMaxDepth(value: Int): this.type = set(maxDepth, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setMaxBins(value: Int): this.type = set(maxBins, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
+
+ /** @group setParam */
+ @Since("3.0.0")
+ def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
+
+ /** @group expertSetParam */
+ @Since("1.4.0")
+ def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)
+
+ /** @group expertSetParam */
+ @Since("1.4.0")
+ def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)
+
+ /**
+ * Specifies how often to checkpoint the cached node IDs.
+ * E.g. 10 means that the cache will get checkpointed every 10 iterations.
+ * This is only used if cacheNodeIds is true and if the checkpoint directory is set in
+ * [[org.apache.spark.SparkContext]].
+ * Must be at least 1.
+ * (default = 10)
+ * @group setParam
+ */
+ @Since("1.4.0")
+ def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setImpurity(value: String): this.type = set(impurity, value)
+
+ // Parameters from TreeEnsembleParams:
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setSeed(value: Long): this.type = set(seed, value)
+
+ // Parameters from RandomForestParams:
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setNumTrees(value: Int): this.type = set(numTrees, value)
+
+ /** @group setParam */
+ @Since("3.0.0")
+ def setBootstrap(value: Boolean): this.type = set(bootstrap, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setFeatureSubsetStrategy(value: String): this.type =
+ set(featureSubsetStrategy, value)
+
+ /**
+ * Sets the value of param [[weightCol]].
+ * If this is not set or empty, we treat all instance weights as 1.0.
+ * By default the weightCol is not set, so all instances have weight 1.0.
+ *
+ * @group setParam
+ */
+ @Since("3.0.0")
+ def setWeightCol(value: String): this.type = set(weightCol, value)
+
+ override protected def train(
+ dataset: Dataset[_]): RandomForestClassificationModel = instrumented { instr =>
+ instr.logPipelineStage(this)
+ instr.logDataset(dataset)
+ val categoricalFeatures: Map[Int, Int] =
+ MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
+ val numClasses: Int = getNumClasses(dataset)
+
+ if (isDefined(thresholds)) {
+ require($(thresholds).length == numClasses, this.getClass.getSimpleName +
+ ".train() called with non-matching numClasses and thresholds.length." +
+ s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
+ }
+
+ val instances = extractInstances(dataset, numClasses)
+ val strategy =
+ super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
+ strategy.bootstrap = $(bootstrap)
+
+ instr.logParams(this, labelCol, featuresCol, weightCol, predictionCol, probabilityCol,
+ rawPredictionCol, leafCol, impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins,
+ maxMemoryInMB, minInfoGain, minInstancesPerNode, minWeightFractionPerNode, seed,
+ subsamplingRate, thresholds, cacheNodeIds, checkpointInterval, bootstrap)
+
+ val trees = RandomForest
+ .run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
+ .map(_.asInstanceOf[DecisionTreeClassificationModel])
+ trees.foreach(copyValues(_))
+
+ val numFeatures = trees.head.numFeatures
+ instr.logNumClasses(numClasses)
+ instr.logNumFeatures(numFeatures)
+ createModel(dataset, trees, numFeatures, numClasses)
+ }
+
+ private def createModel(
+ dataset: Dataset[_],
+ trees: Array[DecisionTreeClassificationModel],
+ numFeatures: Int,
+ numClasses: Int): RandomForestClassificationModel = {
+ val model = copyValues(new RandomForestClassificationModel(uid, trees, numFeatures, numClasses))
+ val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
+
+ val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
+ val rfSummary = if (numClasses <= 2) {
+ new BinaryRandomForestClassificationTrainingSummaryImpl(
+ summaryModel.transform(dataset),
+ probabilityColName,
+ predictionColName,
+ $(labelCol),
+ weightColName,
+ Array(0.0))
+ } else {
+ new RandomForestClassificationTrainingSummaryImpl(
+ summaryModel.transform(dataset),
+ predictionColName,
+ $(labelCol),
+ weightColName,
+ Array(0.0))
+ }
+ model.setSummary(Some(rfSummary))
+ }
+
+ @Since("1.4.1")
+ override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra)
+}
+
+@Since("1.4.0")
+object RandomForestClassifier extends DefaultParamsReadable[RandomForestClassifier] {
+ /** Accessor for supported impurity settings: entropy, gini */
+ @Since("1.4.0")
+ final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
+
+ /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */
+ @Since("1.4.0")
+ final val supportedFeatureSubsetStrategies: Array[String] =
+ TreeEnsembleParams.supportedFeatureSubsetStrategies
+
+ @Since("2.0.0")
+ override def load(path: String): RandomForestClassifier = super.load(path)
+}
+
+/**
+ * Random Forest model for classification.
+ * It supports both binary and multiclass labels, as well as both continuous and categorical
+ * features.
+ *
+ * @param _trees Decision trees in the ensemble.
+ * Warning: These have null parents.
+ */
+@Since("1.4.0")
+class RandomForestClassificationModel private[ml] (
+ @Since("1.5.0") override val uid: String,
+ private val _trees: Array[DecisionTreeClassificationModel],
+ @Since("1.6.0") override val numFeatures: Int,
+ @Since("1.5.0") override val numClasses: Int)
+ extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
+ with RandomForestClassifierParams with TreeEnsembleModel[DecisionTreeClassificationModel]
+ with MLWritable with Serializable
+ with HasTrainingSummary[RandomForestClassificationTrainingSummary] {
+
+ require(_trees.nonEmpty, "RandomForestClassificationModel requires at least 1 tree.")
+
+ /**
+ * Construct a random forest classification model, with all trees weighted equally.
+ *
+ * @param trees Component trees
+ */
+ private[ml] def this(
+ trees: Array[DecisionTreeClassificationModel],
+ numFeatures: Int,
+ numClasses: Int) =
+ this(Identifiable.randomUID("rfc"), trees, numFeatures, numClasses)
+
+ @Since("1.4.0")
+ override def trees: Array[DecisionTreeClassificationModel] = _trees
+
+ // Note: We may add support for weights (based on tree performance) later on.
+ private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0)
+
+ @Since("1.4.0")
+ override def treeWeights: Array[Double] = _treeWeights
+
+ /**
+ * Gets summary of model on training set. An exception is thrown
+ * if `hasSummary` is false.
+ */
+ @Since("3.1.0")
+ override def summary: RandomForestClassificationTrainingSummary = super.summary
+
+ /**
+ * Gets summary of model on training set. An exception is thrown
+ * if `hasSummary` is false or it is a multiclass model.
+ */
+ @Since("3.1.0")
+ def binarySummary: BinaryRandomForestClassificationTrainingSummary = summary match {
+ case b: BinaryRandomForestClassificationTrainingSummary => b
+ case _ =>
+ throw new RuntimeException("Cannot create a binary summary for a non-binary model" +
+ s"(numClasses=${numClasses}), use summary instead.")
+ }
+
+ /**
+ * Evaluates the model on a test dataset.
+ *
+ * @param dataset Test dataset to evaluate model on.
+ */
+ @Since("3.1.0")
+ def evaluate(dataset: Dataset[_]): RandomForestClassificationSummary = {
+ val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
+ // Handle possible missing or invalid prediction columns
+ val (summaryModel, probabilityColName, predictionColName) = findSummaryModel()
+ if (numClasses > 2) {
+ new RandomForestClassificationSummaryImpl(summaryModel.transform(dataset),
+ predictionColName, $(labelCol), weightColName)
+ } else {
+ new BinaryRandomForestClassificationSummaryImpl(summaryModel.transform(dataset),
+ probabilityColName, predictionColName, $(labelCol), weightColName)
+ }
+ }
+
+ @Since("1.4.0")
+ override def transformSchema(schema: StructType): StructType = {
+ var outputSchema = super.transformSchema(schema)
+ if ($(leafCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateField(outputSchema, getLeafField($(leafCol)))
+ }
+ outputSchema
+ }
+
+ override def transform(dataset: Dataset[_]): DataFrame = {
+ val outputSchema = transformSchema(dataset.schema, logging = true)
+
+ val outputData = super.transform(dataset)
+ if ($(leafCol).nonEmpty) {
+ val leafUDF = udf { features: Vector => predictLeaf(features) }
+ outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))),
+ outputSchema($(leafCol)).metadata)
+ } else {
+ outputData
+ }
+ }
+
+ @Since("3.0.0")
+ override def predictRaw(features: Vector): Vector = {
+ // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128
+ // Classifies using majority votes.
+ // Ignore the tree weights since all are 1.0 for now.
+ val votes = Array.ofDim[Double](numClasses)
+ _trees.foreach { tree =>
+ val classCounts = tree.rootNode.predictImpl(features).impurityStats.stats
+ val total = classCounts.sum
+ if (total != 0) {
+ var i = 0
+ while (i < numClasses) {
+ votes(i) += classCounts(i) / total
+ i += 1
+ }
+ }
+ }
+ Vectors.dense(votes)
+ }
+
+ override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
+ rawPrediction match {
+ case dv: DenseVector =>
+ ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(dv)
+ dv
+ case sv: SparseVector =>
+ throw new RuntimeException("Unexpected error in RandomForestClassificationModel:" +
+ " raw2probabilityInPlace encountered SparseVector")
+ }
+ }
+
+ @Since("1.4.0")
+ override def copy(extra: ParamMap): RandomForestClassificationModel = {
+ copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra)
+ .setParent(parent)
+ }
+
+ @Since("1.4.0")
+ override def toString: String = {
+ s"RandomForestClassificationModel: uid=$uid, numTrees=$getNumTrees, numClasses=$numClasses, " +
+ s"numFeatures=$numFeatures"
+ }
+
+ /**
+ * Estimate of the importance of each feature.
+ *
+ * Each feature's importance is the average of its importance across all trees in the ensemble
+ * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+ * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+ * and follows the implementation from scikit-learn.
+ *
+ * @see `DecisionTreeClassificationModel.featureImportances`
+ */
+ @Since("1.5.0")
+ lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)
+
+ /** (private[ml]) Convert to a model in the old API */
+ private[ml] def toOld: OldRandomForestModel = {
+ new OldRandomForestModel(OldAlgo.Classification, _trees.map(_.toOld))
+ }
+
+ @Since("2.0.0")
+ override def write: MLWriter =
+ new RandomForestClassificationModel.RandomForestClassificationModelWriter(this)
+}
+
+@Since("2.0.0")
+object RandomForestClassificationModel extends MLReadable[RandomForestClassificationModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[RandomForestClassificationModel] =
+ new RandomForestClassificationModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): RandomForestClassificationModel = super.load(path)
+
+ private[RandomForestClassificationModel]
+ class RandomForestClassificationModelWriter(instance: RandomForestClassificationModel)
+ extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ // Note: numTrees is not currently used, but could be nice to store for fast querying.
+ val extraMetadata: JObject = Map(
+ "numFeatures" -> instance.numFeatures,
+ "numClasses" -> instance.numClasses,
+ "numTrees" -> instance.getNumTrees)
+ EnsembleModelReadWrite.saveImpl(instance, path, sparkSession, extraMetadata)
+ }
+ }
+
+ private class RandomForestClassificationModelReader
+ extends MLReader[RandomForestClassificationModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[RandomForestClassificationModel].getName
+ private val treeClassName = classOf[DecisionTreeClassificationModel].getName
+
+ override def load(path: String): RandomForestClassificationModel = {
+ implicit val format = DefaultFormats
+ val (metadata: Metadata, treesData: Array[(Metadata, Node)], _) =
+ EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
+ val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
+ val numClasses = (metadata.metadata \ "numClasses").extract[Int]
+ val numTrees = (metadata.metadata \ "numTrees").extract[Int]
+
+ val trees: Array[DecisionTreeClassificationModel] = treesData.map {
+ case (treeMetadata, root) =>
+ val tree =
+ new DecisionTreeClassificationModel(treeMetadata.uid, root, numFeatures, numClasses)
+ treeMetadata.getAndSetParams(tree)
+ tree
+ }
+ require(numTrees == trees.length, s"RandomForestClassificationModel.load expected $numTrees" +
+ s" trees based on metadata but found ${trees.length} trees.")
+
+ val model = new RandomForestClassificationModel(metadata.uid, trees, numFeatures, numClasses)
+ metadata.getAndSetParams(model)
+ model
+ }
+ }
+
+ /** Convert a model from the old API */
+ private[ml] def fromOld(
+ oldModel: OldRandomForestModel,
+ parent: RandomForestClassifier,
+ categoricalFeatures: Map[Int, Int],
+ numClasses: Int,
+ numFeatures: Int = -1): RandomForestClassificationModel = {
+ require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" +
+ s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).")
+ val newTrees = oldModel.trees.map { tree =>
+ // parent for each tree is null since there is no good way to set this.
+ DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures)
+ }
+ val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc")
+ new RandomForestClassificationModel(uid, newTrees, numFeatures, numClasses)
+ }
+}
+
+/**
+ * Abstraction for multiclass RandomForestClassification results for a given model.
+ */
+sealed trait RandomForestClassificationSummary extends ClassificationSummary {
+ /**
+ * Convenient method for casting to BinaryRandomForestClassificationSummary.
+ * This method will throw an Exception if the summary is not a binary summary.
+ */
+ @Since("3.1.0")
+ def asBinary: BinaryRandomForestClassificationSummary = this match {
+ case b: BinaryRandomForestClassificationSummary => b
+ case _ =>
+ throw new RuntimeException("Cannot cast to a binary summary.")
+ }
+}
+
+/**
+ * Abstraction for multiclass RandomForestClassification training results.
+ */
+sealed trait RandomForestClassificationTrainingSummary extends RandomForestClassificationSummary
+ with TrainingSummary
+
+/**
+ * Abstraction for BinaryRandomForestClassification results for a given model.
+ */
+sealed trait BinaryRandomForestClassificationSummary extends BinaryClassificationSummary
+
+/**
+ * Abstraction for BinaryRandomForestClassification training results.
+ */
+sealed trait BinaryRandomForestClassificationTrainingSummary extends
+ BinaryRandomForestClassificationSummary with RandomForestClassificationTrainingSummary
+
+/**
+ * Multiclass RandomForestClassification training results.
+ *
+ * @param predictions dataframe output by the model's `transform` method.
+ * @param predictionCol field in "predictions" which gives the prediction for a data instance as a
+ * double.
+ * @param labelCol field in "predictions" which gives the true label of each instance.
+ * @param weightCol field in "predictions" which gives the weight of each instance.
+ * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
+ */
+private class RandomForestClassificationTrainingSummaryImpl(
+ predictions: DataFrame,
+ predictionCol: String,
+ labelCol: String,
+ weightCol: String,
+ override val objectiveHistory: Array[Double])
+ extends RandomForestClassificationSummaryImpl(
+ predictions, predictionCol, labelCol, weightCol)
+ with RandomForestClassificationTrainingSummary
+
+/**
+ * Multiclass RandomForestClassification results for a given model.
+ *
+ * @param predictions dataframe output by the model's `transform` method.
+ * @param predictionCol field in "predictions" which gives the prediction for a data instance as a
+ * double.
+ * @param labelCol field in "predictions" which gives the true label of each instance.
+ * @param weightCol field in "predictions" which gives the weight of each instance.
+ */
+private class RandomForestClassificationSummaryImpl(
+ @transient override val predictions: DataFrame,
+ override val predictionCol: String,
+ override val labelCol: String,
+ override val weightCol: String)
+ extends RandomForestClassificationSummary
+
+/**
+ * Binary RandomForestClassification training results.
+ *
+ * @param predictions dataframe output by the model's `transform` method.
+ * @param scoreCol field in "predictions" which gives the probability of each class as a vector.
+ * @param predictionCol field in "predictions" which gives the prediction for a data instance as a
+ * double.
+ * @param labelCol field in "predictions" which gives the true label of each instance.
+ * @param weightCol field in "predictions" which gives the weight of each instance.
+ * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
+ */
+private class BinaryRandomForestClassificationTrainingSummaryImpl(
+ predictions: DataFrame,
+ scoreCol: String,
+ predictionCol: String,
+ labelCol: String,
+ weightCol: String,
+ override val objectiveHistory: Array[Double])
+ extends BinaryRandomForestClassificationSummaryImpl(
+ predictions, scoreCol, predictionCol, labelCol, weightCol)
+ with BinaryRandomForestClassificationTrainingSummary
+
+/**
+ * Binary RandomForestClassification for a given model.
+ *
+ * @param predictions dataframe output by the model's `transform` method.
+ * @param scoreCol field in "predictions" which gives the prediction of
+ * each class as a vector.
+ * @param labelCol field in "predictions" which gives the true label of each instance.
+ * @param weightCol field in "predictions" which gives the weight of each instance.
+ */
+private class BinaryRandomForestClassificationSummaryImpl(
+ predictions: DataFrame,
+ override val scoreCol: String,
+ predictionCol: String,
+ labelCol: String,
+ weightCol: String)
+ extends RandomForestClassificationSummaryImpl(
+ predictions, predictionCol, labelCol, weightCol)
+ with BinaryRandomForestClassificationSummary
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/ml/feature/DecisionTreeBucketizer.scala b/ml-accelerator/src/main/scala/org/apache/spark/ml/feature/DecisionTreeBucketizer.scala
new file mode 100644
index 0000000..f774acd
--- /dev/null
+++ b/ml-accelerator/src/main/scala/org/apache/spark/ml/feature/DecisionTreeBucketizer.scala
@@ -0,0 +1,440 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.json4s.{DefaultFormats, JObject}
+import org.json4s.JsonDSL._
+
+import org.apache.spark.SparkException
+import org.apache.spark.annotation.Since
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.classification.DecisionTreeClassificationModel
+import org.apache.spark.ml.linalg.{Vector, VectorUDT}
+import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, Params}
+import org.apache.spark.ml.param.shared.HasWeightCol
+import org.apache.spark.ml.tree._
+import org.apache.spark.ml.tree.impl.DecisionTreeBucket
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
+import org.apache.spark.ml.util.Instrumentation.instrumented
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{DoubleType, StructType}
+
+
+private[ml] trait DecisionTreeBucketizerParams extends Params
+ with DecisionTreeClassifierParams with HasWeightCol {
+
+ /**
+ * Param for bucketedFeatures column name.
+ * @group param
+ */
+ final val bucketedFeaturesCol: Param[String] =
+ new Param[String](this, "bucketedFeaturesCol", "bucketedFeatures column name")
+
+ final val prune: BooleanParam =
+ new BooleanParam(this, "prune", "if true, the algorithm will prune decision trees")
+
+ setDefault(bucketedFeaturesCol, "bucketedFeatures")
+ setDefault(prune, true)
+
+ /** @group getParam */
+ final def getBucketedFeaturesCol: String = $(bucketedFeaturesCol)
+
+ /** @group getParam */
+ final def getPrune: Boolean = $(prune)
+
+ /**
+ * Validates and transforms the input schema with the provided param map.
+ * copy from [[org.apache.spark.ml.PredictorParams]].validateAndTransformSchema
+ *
+ * @param schema input schema
+ * @param fitting whether this is in fitting
+ * @return output schema
+ */
+ protected def validateAndTransformSchema(
+ schema: StructType,
+ fitting: Boolean): StructType = {
+ SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
+ if (fitting) {
+ SchemaUtils.checkNumericType(schema, $(labelCol))
+
+ this match {
+ case p: HasWeightCol =>
+ if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) {
+ SchemaUtils.checkNumericType(schema, $(p.weightCol))
+ }
+ case _ =>
+ }
+ }
+ SchemaUtils.appendColumn(schema, $(bucketedFeaturesCol), new VectorUDT)
+ }
+
+}
+
+/**
+ * Decision tree bucketing algorithm for data discretization.
+ */
+@Since("1.4.0")
+class DecisionTreeBucketizer @Since("1.4.0") (
+ @Since("1.4.0") override val uid: String)
+ extends Estimator[DecisionTreeBucketModel]
+ with DecisionTreeBucketizerParams with DecisionTreeClassifierParams with DefaultParamsWritable {
+
+ @Since("1.4.0")
+ def this() = this(Identifiable.randomUID("dtb"))
+
+ def setLabelCol(value: String): this.type = set(labelCol, value)
+
+ def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+ def setBucketedFeaturesCol(value: String): this.type = set(bucketedFeaturesCol, value)
+
+ def setPrune(value: Boolean): this.type = set(prune, value)
+
+ // Override parameter setters from parent trait for Java API compatibility.
+ /** @group setParam */
+ @Since("1.4.0")
+ def setMaxDepth(value: Int): this.type = set(maxDepth, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setMaxBins(value: Int): this.type = set(maxBins, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
+
+ /** @group setParam */
+ @Since("3.0.0")
+ def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value)
+
+ @Since("1.4.0")
+ def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
+
+ /** @group expertSetParam */
+ @Since("1.4.0")
+ def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)
+
+ /** @group expertSetParam */
+ @Since("1.4.0")
+ def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)
+
+ /**
+ * Specifies how often to checkpoint the cached node IDs.
+ * E.g. 10 means that the cache will get checkpointed every 10 iterations.
+ * This is only used if cacheNodeIds is true and if the checkpoint directory is set in
+ * [[org.apache.spark.SparkContext]].
+ * Must be at least 1.
+ * (default = 10)
+ * @group setParam
+ */
+ @Since("1.4.0")
+ def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setImpurity(value: String): this.type = set(impurity, value)
+
+ /** @group setParam */
+ @Since("1.6.0")
+ def setSeed(value: Long): this.type = set(seed, value)
+
+ /** @group setParam */
+ @Since("3.0.0")
+ def setWeightCol(value: String): this.type = set(weightCol, value)
+
+ /**
+ * Get the number of classes. This looks in column metadata first, and if that is missing,
+ * then this assumes classes are indexed 0,1,...,numClasses-1 and computes numClasses
+ * by finding the maximum label value.
+ *
+ * Label validation (ensuring all labels are integers >= 0) needs to be handled elsewhere,
+ * such as in `extractLabeledPoints()`.
+ *
+ * @param dataset Dataset which contains a column [[labelCol]]
+ * @param maxNumClasses Maximum number of classes allowed when inferred from data. If numClasses
+ * is specified in the metadata, then maxNumClasses is ignored.
+ * @return number of classes
+ * @throws IllegalArgumentException if metadata does not specify numClasses, and the
+ * actual numClasses exceeds maxNumClasses
+ */
+ private[ml] def getNumClasses(dataset: Dataset[_], maxNumClasses: Int = 100): Int = {
+ MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
+ case Some(n: Int) => n
+ case None =>
+ // Get number of classes from dataset itself.
+ val maxLabelRow: Array[Row] = dataset.select(max($(labelCol))).take(1)
+ if (maxLabelRow.isEmpty || maxLabelRow.head.isNullAt(0)) {
+ throw new SparkException("ML algorithm was given empty dataset.")
+ }
+ val maxDoubleLabel: Double = maxLabelRow.head.getDouble(0)
+ require((maxDoubleLabel + 1).isValidInt, s"Classifier found max label value =" +
+ s" $maxDoubleLabel but requires integers in range [0, ... ${Int.MaxValue})")
+ val numClasses = maxDoubleLabel.toInt + 1
+ require(numClasses <= maxNumClasses, s"Classifier inferred $numClasses from label values" +
+ s" in column $labelCol, but this exceeded the max numClasses ($maxNumClasses) allowed" +
+ s" to be inferred from values. To avoid this error for labels with > $maxNumClasses" +
+ s" classes, specify numClasses explicitly in the metadata; this can be done by applying" +
+ s" StringIndexer to the label column.")
+ logInfo(this.getClass.getCanonicalName + s" inferred $numClasses classes for" +
+ s" labelCol=$labelCol since numClasses was not specified in the column metadata.")
+ numClasses
+ }
+ }
+
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema, true)
+ }
+
+ override def fit(dataset: Dataset[_]): DecisionTreeBucketModel = {
+ // This handles a few items such as schema validation.
+ // Developers only need to implement train().
+ transformSchema(dataset.schema, logging = true)
+
+ // Cast LabelCol to DoubleType and keep the metadata.
+ val labelMeta = dataset.schema($(labelCol)).metadata
+ val labelCasted = dataset.withColumn($(labelCol), col($(labelCol)).cast(DoubleType), labelMeta)
+
+ // Cast WeightCol to DoubleType and keep the metadata.
+ val casted = this match {
+ case p: HasWeightCol =>
+ if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) {
+ val weightMeta = dataset.schema($(p.weightCol)).metadata
+ labelCasted.withColumn($(p.weightCol), col($(p.weightCol)).cast(DoubleType), weightMeta)
+ } else {
+ labelCasted
+ }
+ case _ => labelCasted
+ }
+
+ copyValues(train(casted).setParent(this))
+ }
+
+ private[ml] def train(dataset: Dataset[_]): DecisionTreeBucketModel = instrumented { instr =>
+ instr.logPipelineStage(this)
+ instr.logDataset(dataset)
+ val categoricalFeatures: Map[Int, Int] =
+ MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
+ val numClasses: Int = getNumClasses(dataset)
+ val instances = extractInstances(dataset, numClasses)
+ val strategy = getOldStrategy(categoricalFeatures, numClasses)
+ instr.logParams(this, labelCol, featuresCol, predictionCol, rawPredictionCol,
+ probabilityCol, leafCol, maxDepth, maxBins, minInstancesPerNode, minInfoGain,
+ maxMemoryInMB, cacheNodeIds, checkpointInterval, impurity, seed, thresholds)
+
+ logInfo(s"prune: ${$(prune)}")
+ val trees = DecisionTreeBucket.run(instances, strategy, getSeed, Some(instr), prune = $(prune))
+ .map(_.asInstanceOf[DecisionTreeClassificationModel])
+ trees.foreach(copyValues(_))
+
+ val numFeatures = trees.head.numFeatures
+ instr.logNumClasses(numClasses)
+ instr.logNumFeatures(numFeatures)
+ val m = new DecisionTreeBucketModel(uid, trees, numFeatures, numClasses)
+ copyValues(m)
+ m
+ }
+
+ /** (private[ml]) Train decision trees on an RDD */
+ private[ml] def train(data: RDD[LabeledPoint],
+ oldStrategy: OldStrategy): DecisionTreeBucketModel = instrumented { instr =>
+ instr.logParams(this, labelCol, featuresCol, predictionCol, rawPredictionCol,
+ probabilityCol, leafCol, maxDepth, maxBins, minInstancesPerNode, minInfoGain,
+ maxMemoryInMB, cacheNodeIds, checkpointInterval, impurity, seed, thresholds)
+
+ logInfo(s"prune: ${$(prune)}")
+ val trees = DecisionTreeBucket.run(data.map(_.toInstance), oldStrategy, getSeed, Some(instr),
+ prune = $(prune))
+ .map(_.asInstanceOf[DecisionTreeClassificationModel])
+
+ val numFeatures = data.first().features.size
+ val m = new DecisionTreeBucketModel(uid, trees, numFeatures, oldStrategy.numClasses)
+ copyValues(m)
+ m
+ }
+
+ /** (private[ml]) Create a Strategy instance to use with the old API. */
+ private[ml] def getOldStrategy(
+ categoricalFeatures: Map[Int, Int],
+ numClasses: Int): OldStrategy = {
+ super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity,
+ subsamplingRate = 1.0)
+ }
+
+ @Since("1.4.1")
+ override def copy(extra: ParamMap): DecisionTreeBucketizer = defaultCopy(extra)
+}
+
+@Since("1.4.0")
+object DecisionTreeBucketizer extends DefaultParamsReadable[DecisionTreeBucketizer] {
+ /** Accessor for supported impurities: entropy, gini */
+ @Since("1.4.0")
+ final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
+
+ @Since("2.0.0")
+ override def load(path: String): DecisionTreeBucketizer = super.load(path)
+}
+
+/**
+ * Decision tree bucket model for data discretization.
+ * @param _trees Decision trees of all features.
+ */
+@Since("1.4.0")
+class DecisionTreeBucketModel private[ml] (
+ @Since("1.5.0") override val uid: String,
+ private val _trees: Array[DecisionTreeClassificationModel],
+ val numFeatures: Int,
+ val numClasses: Int)
+ extends Model[DecisionTreeBucketModel]
+ with DecisionTreeBucketizerParams with DecisionTreeClassifierParams
+ with TreeEnsembleModel[DecisionTreeClassificationModel]
+ with MLWritable with Serializable {
+
+ require(_trees.nonEmpty, "DecisionTreeBucketModel requires at least 1 tree.")
+
+ /**
+ * Construct a decision tree bucket model, with all trees weighted equally.
+ *
+ * @param trees Component trees
+ */
+ private[ml] def this(
+ trees: Array[DecisionTreeClassificationModel],
+ numFeatures: Int,
+ numClasses: Int) =
+ this(Identifiable.randomUID("dtb"), trees, numFeatures, numClasses)
+
+ def getNumTrees: Int = _trees.length
+
+ @Since("1.4.0")
+ override def trees: Array[DecisionTreeClassificationModel] = _trees
+
+ // Note: We may add support for weights (based on tree performance) later on.
+ private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0)
+
+ @Since("1.4.0")
+ override def treeWeights: Array[Double] = _treeWeights
+
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema, false)
+ }
+
+ override def transform(dataset: Dataset[_]): DataFrame = {
+ transformSchema(dataset.schema, logging = true)
+
+ // val outputData = super.transform(dataset)
+ val leafUDF = udf { features: Vector => predictLeaf(features) }
+ dataset.withColumn($(bucketedFeaturesCol), leafUDF(col($(featuresCol))))
+ }
+
+ @Since("1.4.0")
+ override def copy(extra: ParamMap): DecisionTreeBucketModel = {
+ copyValues(new DecisionTreeBucketModel(uid, _trees, numFeatures, numClasses), extra)
+ .setParent(parent)
+ }
+
+ @Since("1.4.0")
+ override def toString: String = {
+ s"DecisionTreeBucketModel (uid=$uid) with $getNumTrees trees"
+ }
+
+ /** (private[ml]) Convert to a model in the old API */
+ def toOld: OldRandomForestModel = {
+ new OldRandomForestModel(OldAlgo.Classification, _trees.map(_.toOld))
+ }
+
+ @Since("2.0.0")
+ override def write: MLWriter =
+ new DecisionTreeBucketModel.DecisionTreeBucketModelWriter(this)
+}
+
+@Since("2.0.0")
+object DecisionTreeBucketModel extends MLReadable[DecisionTreeBucketModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[DecisionTreeBucketModel] =
+ new DecisionTreeBucketModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): DecisionTreeBucketModel = super.load(path)
+
+ private[DecisionTreeBucketModel]
+ class DecisionTreeBucketModelWriter(instance: DecisionTreeBucketModel)
+ extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ // Note: numTrees is not currently used, but could be nice to store for fast querying.
+ val extraMetadata: JObject = Map(
+ "numFeatures" -> instance.numFeatures,
+ "numClasses" -> instance.numClasses,
+ "numTrees" -> instance.getNumTrees)
+ EnsembleModelReadWrite.saveImpl(instance, path, sparkSession, extraMetadata)
+ }
+ }
+
+ private class DecisionTreeBucketModelReader
+ extends MLReader[DecisionTreeBucketModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[DecisionTreeBucketModel].getName
+ private val treeClassName = classOf[DecisionTreeClassificationModel].getName
+
+ override def load(path: String): DecisionTreeBucketModel = {
+ implicit val format = DefaultFormats
+ val (metadata: Metadata, treesData: Array[(Metadata, Node)], _) =
+ EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
+ val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
+ val numClasses = (metadata.metadata \ "numClasses").extract[Int]
+ val numTrees = (metadata.metadata \ "numTrees").extract[Int]
+
+ val trees: Array[DecisionTreeClassificationModel] = treesData.map {
+ case (treeMetadata, root) =>
+ val tree =
+ new DecisionTreeClassificationModel(treeMetadata.uid, root, numFeatures, numClasses)
+ treeMetadata.getAndSetParams(tree)
+ tree
+ }
+ require(numTrees == trees.length, s"DecisionTreeBucketModel.load expected $numTrees" +
+ s" trees based on metadata but found ${trees.length} trees.")
+
+ val model = new DecisionTreeBucketModel(metadata.uid, trees, numFeatures, numClasses)
+ metadata.getAndSetParams(model)
+ model
+ }
+ }
+
+ /** Convert a model from the old API */
+ private[ml] def fromOld(
+ oldModel: OldRandomForestModel,
+ parent: DecisionTreeBucketizer,
+ categoricalFeatures: Map[Int, Int],
+ numClasses: Int,
+ numFeatures: Int = -1): DecisionTreeBucketModel = {
+ require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" +
+ s" with algo=${oldModel.algo} (old API) to DecisionTreeBucketModel (new API).")
+ val newTrees = oldModel.trees.map { tree =>
+ // parent for each tree is null since there is no good way to set this.
+ DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures)
+ }
+ val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtb")
+ new DecisionTreeBucketModel(uid, newTrees, numFeatures, numClasses)
+ }
+}
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/ml-accelerator/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
new file mode 100644
index 0000000..2bd98e6
--- /dev/null
+++ b/ml-accelerator/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -0,0 +1,463 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.Since
+import org.apache.spark.internal.config.Kryo.KRYO_SERIALIZER_MAX_BUFFER_SIZE
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors, VectorUDT}
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util._
+import org.apache.spark.mllib.feature
+import org.apache.spark.mllib.linalg.VectorImplicits._
+import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+import org.apache.spark.util.{Utils, VersionUtils}
+
+/**
+ * Params for [[Word2Vec]] and [[Word2VecModel]].
+ */
+private[feature] trait Word2VecBase extends Params
+ with HasInputCol with HasOutputCol with HasMaxIter with HasStepSize with HasSeed {
+
+ /**
+ * The dimension of the code that you want to transform from words.
+ * Default: 100
+ * @group param
+ */
+ final val vectorSize = new IntParam(
+ this, "vectorSize", "the dimension of codes after transforming from words (> 0)",
+ ParamValidators.gt(0))
+
+ /** @group getParam */
+ def getVectorSize: Int = $(vectorSize)
+
+ /**
+ * The window size (context words from [-window, window]).
+ * Default: 5
+ * @group expertParam
+ */
+ final val windowSize = new IntParam(
+ this, "windowSize", "the window size (context words from [-window, window]) (> 0)",
+ ParamValidators.gt(0))
+
+ /** @group expertGetParam */
+ def getWindowSize: Int = $(windowSize)
+
+ /**
+ * Number of partitions for sentences of words.
+ * Default: 1
+ * @group param
+ */
+ final val numPartitions = new IntParam(
+ this, "numPartitions", "number of partitions for sentences of words (> 0)",
+ ParamValidators.gt(0))
+
+ /** @group getParam */
+ def getNumPartitions: Int = $(numPartitions)
+
+ /**
+ * The minimum number of times a token must appear to be included in the word2vec model's
+ * vocabulary.
+ * Default: 5
+ * @group param
+ */
+ final val minCount = new IntParam(this, "minCount", "the minimum number of times a token must " +
+ "appear to be included in the word2vec model's vocabulary (>= 0)", ParamValidators.gtEq(0))
+
+ /** @group getParam */
+ def getMinCount: Int = $(minCount)
+
+ /**
+ * Sets the maximum length (in words) of each sentence in the input data.
+ * Any sentence longer than this threshold will be divided into chunks of
+ * up to `maxSentenceLength` size.
+ * Default: 1000
+ * @group param
+ */
+ final val maxSentenceLength = new IntParam(this, "maxSentenceLength", "Maximum length " +
+ "(in words) of each sentence in the input data. Any sentence longer than this threshold will " +
+ "be divided into chunks up to the size (> 0)", ParamValidators.gt(0))
+
+ /** @group getParam */
+ def getMaxSentenceLength: Int = $(maxSentenceLength)
+
+ /**
+ * Sets the regularization coefficient.
+ * Default: 0.05f
+ * @group param
+ */
+ final val regularization = new FloatParam(this, "regularization", "Regularization coefficient")
+
+ /** @group getParam */
+ def getRegularization: Float = $(regularization)
+
+ /**
+ * Sets the number of repetitions of data.
+ * Default: 3
+ * @group param
+ */
+ final val repetition = new IntParam(this, "repetition", "The number of repetitions of data",
+ ParamValidators.gtEq(0))
+
+ /** @group getParam */
+ def getRepetition: Int = $(repetition)
+
+ setDefault(vectorSize -> 100, windowSize -> 5, numPartitions -> 1, minCount -> 5,
+ maxSentenceLength -> 1000, stepSize -> 0.025, maxIter -> 1, regularization -> 0.05f,
+ repetition -> 3)
+
+ /**
+ * Validate and transform the input schema.
+ */
+ protected def validateAndTransformSchema(schema: StructType): StructType = {
+ val typeCandidates = List(new ArrayType(StringType, true), new ArrayType(StringType, false))
+ SchemaUtils.checkColumnTypes(schema, $(inputCol), typeCandidates)
+ SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
+ }
+}
+
+/**
+ * Word2Vec trains a model of `Map(String, Vector)`, i.e. transforms a word into a code for further
+ * natural language processing or machine learning process.
+ */
+@Since("1.4.0")
+final class Word2Vec @Since("1.4.0") (
+ @Since("1.4.0") override val uid: String)
+ extends Estimator[Word2VecModel] with Word2VecBase with DefaultParamsWritable {
+
+ @Since("1.4.0")
+ def this() = this(Identifiable.randomUID("w2v"))
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setInputCol(value: String): this.type = set(inputCol, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setOutputCol(value: String): this.type = set(outputCol, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setVectorSize(value: Int): this.type = set(vectorSize, value)
+
+ /** @group expertSetParam */
+ @Since("1.6.0")
+ def setWindowSize(value: Int): this.type = set(windowSize, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setStepSize(value: Double): this.type = set(stepSize, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setNumPartitions(value: Int): this.type = set(numPartitions, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setMaxIter(value: Int): this.type = set(maxIter, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setSeed(value: Long): this.type = set(seed, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setMinCount(value: Int): this.type = set(minCount, value)
+
+ /** @group setParam */
+ @Since("2.0.0")
+ def setMaxSentenceLength(value: Int): this.type = set(maxSentenceLength, value)
+
+ /** @group setParam */
+ def setRegularization(value: Float): this.type = set(regularization, value)
+
+ /** @group setParam */
+ def setRepetition(value: Int): this.type = set(repetition, value)
+
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): Word2VecModel = {
+ transformSchema(dataset.schema, logging = true)
+ val input =
+ dataset.select($(inputCol)).rdd.map(_.getSeq[String](0))
+ val wordVectors = new feature.Word2Vec()
+ .setLearningRate($(stepSize))
+ .setMinCount($(minCount))
+ .setNumIterations($(maxIter))
+ .setNumPartitions($(numPartitions))
+ .setSeed($(seed))
+ .setVectorSize($(vectorSize))
+ .setWindowSize($(windowSize))
+ .setMaxSentenceLength($(maxSentenceLength))
+ .setRegularization($(regularization))
+ .setRepetition($(repetition))
+ .fit(input)
+ copyValues(new Word2VecModel(uid, wordVectors).setParent(this))
+ }
+
+ @Since("1.4.0")
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema)
+ }
+
+ @Since("1.4.1")
+ override def copy(extra: ParamMap): Word2Vec = defaultCopy(extra)
+}
+
+@Since("1.6.0")
+object Word2Vec extends DefaultParamsReadable[Word2Vec] {
+
+ @Since("1.6.0")
+ override def load(path: String): Word2Vec = super.load(path)
+}
+
+/**
+ * Model fitted by [[Word2Vec]].
+ */
+@Since("1.4.0")
+class Word2VecModel private[ml] (
+ @Since("1.4.0") override val uid: String,
+ @transient private val wordVectors: feature.Word2VecModel)
+ extends Model[Word2VecModel] with Word2VecBase with MLWritable {
+
+ import Word2VecModel._
+
+ /**
+ * Returns a dataframe with two fields, "word" and "vector", with "word" being a String and
+ * and the vector the DenseVector that it is mapped to.
+ */
+ @Since("1.5.0")
+ @transient lazy val getVectors: DataFrame = {
+ val spark = SparkSession.builder().getOrCreate()
+ val wordVec = wordVectors.getVectors.mapValues(vec => Vectors.dense(vec.map(_.toDouble)))
+ spark.createDataFrame(wordVec.toSeq).toDF("word", "vector")
+ }
+
+ /**
+ * Find "num" number of words closest in similarity to the given word, not
+ * including the word itself.
+ * @return a dataframe with columns "word" and "similarity" of the word and the cosine
+ * similarities between the synonyms and the given word.
+ */
+ @Since("1.5.0")
+ def findSynonyms(word: String, num: Int): DataFrame = {
+ val spark = SparkSession.builder().getOrCreate()
+ spark.createDataFrame(findSynonymsArray(word, num)).toDF("word", "similarity")
+ }
+
+ /**
+ * Find "num" number of words whose vector representation is most similar to the supplied vector.
+ * If the supplied vector is the vector representation of a word in the model's vocabulary,
+ * that word will be in the results.
+ * @return a dataframe with columns "word" and "similarity" of the word and the cosine
+ * similarities between the synonyms and the given word vector.
+ */
+ @Since("2.0.0")
+ def findSynonyms(vec: Vector, num: Int): DataFrame = {
+ val spark = SparkSession.builder().getOrCreate()
+ spark.createDataFrame(findSynonymsArray(vec, num)).toDF("word", "similarity")
+ }
+
+ /**
+ * Find "num" number of words whose vector representation is most similar to the supplied vector.
+ * If the supplied vector is the vector representation of a word in the model's vocabulary,
+ * that word will be in the results.
+ * @return an array of the words and the cosine similarities between the synonyms given
+ * word vector.
+ */
+ @Since("2.2.0")
+ def findSynonymsArray(vec: Vector, num: Int): Array[(String, Double)] = {
+ wordVectors.findSynonyms(vec, num)
+ }
+
+ /**
+ * Find "num" number of words closest in similarity to the given word, not
+ * including the word itself.
+ * @return an array of the words and the cosine similarities between the synonyms given
+ * word vector.
+ */
+ @Since("2.2.0")
+ def findSynonymsArray(word: String, num: Int): Array[(String, Double)] = {
+ wordVectors.findSynonyms(word, num)
+ }
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setInputCol(value: String): this.type = set(inputCol, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setOutputCol(value: String): this.type = set(outputCol, value)
+
+ /**
+ * Transform a sentence column to a vector column to represent the whole sentence. The transform
+ * is performed by averaging all word vectors it contains.
+ */
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
+ val outputSchema = transformSchema(dataset.schema, logging = true)
+ val vectors = wordVectors.getVectors
+ .mapValues(vv => Vectors.dense(vv.map(_.toDouble)))
+ .map(identity).toMap // mapValues doesn't return a serializable map (SI-7005)
+ val bVectors = dataset.sparkSession.sparkContext.broadcast(vectors)
+ val d = $(vectorSize)
+ val emptyVec = Vectors.sparse(d, Array.emptyIntArray, Array.emptyDoubleArray)
+ val word2Vec = udf { sentence: Seq[String] =>
+ if (sentence.isEmpty) {
+ emptyVec
+ } else {
+ val sum = Vectors.zeros(d)
+ sentence.foreach { word =>
+ bVectors.value.get(word).foreach { v =>
+ BLAS.axpy(1.0, v, sum)
+ }
+ }
+ BLAS.scal(1.0 / sentence.size, sum)
+ sum
+ }
+ }
+ dataset.withColumn($(outputCol), word2Vec(col($(inputCol))),
+ outputSchema($(outputCol)).metadata)
+ }
+
+ @Since("1.4.0")
+ override def transformSchema(schema: StructType): StructType = {
+ var outputSchema = validateAndTransformSchema(schema)
+ if ($(outputCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+ $(outputCol), $(vectorSize))
+ }
+ outputSchema
+ }
+
+ @Since("1.4.1")
+ override def copy(extra: ParamMap): Word2VecModel = {
+ val copied = new Word2VecModel(uid, wordVectors)
+ copyValues(copied, extra).setParent(parent)
+ }
+
+ @Since("1.6.0")
+ override def write: MLWriter = new Word2VecModelWriter(this)
+
+ @Since("3.0.0")
+ override def toString: String = {
+ s"Word2VecModel: uid=$uid, numWords=${wordVectors.wordIndex.size}, " +
+ s"vectorSize=${$(vectorSize)}"
+ }
+}
+
+@Since("1.6.0")
+object Word2VecModel extends MLReadable[Word2VecModel] {
+
+ private[Word2VecModel] case class Data(word: String, vector: Array[Float])
+
+ private[Word2VecModel]
+ class Word2VecModelWriter(instance: Word2VecModel) extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+
+ val wordVectors = instance.wordVectors.getVectors
+ val dataPath = new Path(path, "data").toString
+ val bufferSizeInBytes = Utils.byteStringAsBytes(
+ sc.conf.get(KRYO_SERIALIZER_MAX_BUFFER_SIZE.key, "64m"))
+ val numPartitions = Word2VecModelWriter.calculateNumberOfPartitions(
+ bufferSizeInBytes, instance.wordVectors.wordIndex.size, instance.getVectorSize)
+ val spark = sparkSession
+ import spark.implicits._
+ spark.createDataset[(String, Array[Float])](wordVectors.toSeq)
+ .repartition(numPartitions)
+ .map { case (word, vector) => Data(word, vector) }
+ .toDF()
+ .write
+ .parquet(dataPath)
+ }
+ }
+
+ private[feature]
+ object Word2VecModelWriter {
+ /**
+ * Calculate the number of partitions to use in saving the model.
+ * [SPARK-11994] - We want to partition the model in partitions smaller than
+ * spark.kryoserializer.buffer.max
+ * @param bufferSizeInBytes Set to spark.kryoserializer.buffer.max
+ * @param numWords Vocab size
+ * @param vectorSize Vector length for each word
+ */
+ def calculateNumberOfPartitions(
+ bufferSizeInBytes: Long,
+ numWords: Int,
+ vectorSize: Int): Int = {
+ val floatSize = 4L // Use Long to help avoid overflow
+ val averageWordSize = 15
+ // Calculate the approximate size of the model.
+ // Assuming an average word size of 15 bytes, the formula is:
+ // (floatSize * vectorSize + 15) * numWords
+ val approximateSizeInBytes = (floatSize * vectorSize + averageWordSize) * numWords
+ val numPartitions = (approximateSizeInBytes / bufferSizeInBytes) + 1
+ require(numPartitions < 10e8, s"Word2VecModel calculated that it needs $numPartitions " +
+ s"partitions to save this model, which is too large. Try increasing " +
+ s"spark.kryoserializer.buffer.max so that Word2VecModel can use fewer partitions.")
+ numPartitions.toInt
+ }
+ }
+
+ private class Word2VecModelReader extends MLReader[Word2VecModel] {
+
+ private val className = classOf[Word2VecModel].getName
+
+ override def load(path: String): Word2VecModel = {
+ val spark = sparkSession
+ import spark.implicits._
+
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion)
+
+ val dataPath = new Path(path, "data").toString
+
+ val oldModel = if (major < 2 || (major == 2 && minor < 2)) {
+ val data = spark.read.parquet(dataPath)
+ .select("wordIndex", "wordVectors")
+ .head()
+ val wordIndex = data.getAs[Map[String, Int]](0)
+ val wordVectors = data.getAs[Seq[Float]](1).toArray
+ new feature.Word2VecModel(wordIndex, wordVectors)
+ } else {
+ val wordVectorsMap = spark.read.parquet(dataPath).as[Data]
+ .collect()
+ .map(wordVector => (wordVector.word, wordVector.vector))
+ .toMap
+ new feature.Word2VecModel(wordVectorsMap)
+ }
+
+ val model = new Word2VecModel(metadata.uid, oldModel)
+ metadata.getAndSetParams(model)
+ model
+ }
+ }
+
+ @Since("1.6.0")
+ override def read: MLReader[Word2VecModel] = new Word2VecModelReader
+
+ @Since("1.6.0")
+ override def load(path: String): Word2VecModel = super.load(path)
+}
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/ml-accelerator/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
new file mode 100644
index 0000000..825bb85
--- /dev/null
+++ b/ml-accelerator/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -0,0 +1,331 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import org.apache.hadoop.fs.Path
+import org.json4s.{DefaultFormats, JObject}
+import org.json4s.JsonDSL._
+
+import org.apache.spark.annotation.Since
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.tree._
+import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._
+import org.apache.spark.ml.tree.impl.{DecisionForest, RandomForest}
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.Instrumentation.instrumented
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
+import org.apache.spark.sql.{Column, DataFrame, Dataset}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Decision tree
+ * learning algorithm for regression.
+ * It supports both continuous and categorical features.
+ */
+@Since("1.4.0")
+class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
+ extends Regressor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel]
+ with DecisionTreeRegressorParams with DefaultParamsWritable {
+
+ @Since("1.4.0")
+ def this() = this(Identifiable.randomUID("dtr"))
+
+ // Override parameter setters from parent trait for Java API compatibility.
+ /** @group setParam */
+ @Since("1.4.0")
+ def setMaxDepth(value: Int): this.type = set(maxDepth, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setMaxBins(value: Int): this.type = set(maxBins, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
+
+ /** @group setParam */
+ @Since("3.0.0")
+ def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value)
+
+ @Since("1.4.0")
+ def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
+
+ /** @group expertSetParam */
+ @Since("1.4.0")
+ def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)
+
+ /** @group expertSetParam */
+ @Since("1.4.0")
+ def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)
+
+ /**
+ * Specifies how often to checkpoint the cached node IDs.
+ * E.g. 10 means that the cache will get checkpointed every 10 iterations.
+ * This is only used if cacheNodeIds is true and if the checkpoint directory is set in
+ * [[org.apache.spark.SparkContext]].
+ * Must be at least 1.
+ * (default = 10)
+ * @group setParam
+ */
+ @Since("1.4.0")
+ def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setImpurity(value: String): this.type = set(impurity, value)
+
+ /** @group setParam */
+ @Since("1.6.0")
+ def setSeed(value: Long): this.type = set(seed, value)
+
+ /** @group setParam */
+ @Since("2.0.0")
+ def setVarianceCol(value: String): this.type = set(varianceCol, value)
+
+ /**
+ * Sets the value of param [[weightCol]].
+ * If this is not set or empty, we treat all instance weights as 1.0.
+ * Default is not set, so all instances have weight one.
+ *
+ * @group setParam
+ */
+ @Since("3.0.0")
+ def setWeightCol(value: String): this.type = set(weightCol, value)
+
+ override protected def train(
+ dataset: Dataset[_]): DecisionTreeRegressionModel = instrumented { instr =>
+ val categoricalFeatures: Map[Int, Int] =
+ MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
+ val instances = extractInstances(dataset)
+ val strategy = getOldStrategy(categoricalFeatures)
+ require(!strategy.bootstrap, "DecisionTreeRegressor does not need bootstrap sampling")
+
+ instr.logPipelineStage(this)
+ instr.logDataset(instances)
+ instr.logParams(this, params: _*)
+
+ val trees = DecisionForest.run(instances, strategy, numTrees = 1, featureSubsetStrategy = "all",
+ seed = $(seed), instr = Some(instr), parentUID = Some(uid))
+
+ trees.head.asInstanceOf[DecisionTreeRegressionModel]
+ }
+
+ /** (private[ml]) Create a Strategy instance to use with the old API. */
+ private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = {
+ super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity,
+ subsamplingRate = 1.0)
+ }
+
+ @Since("1.4.0")
+ override def copy(extra: ParamMap): DecisionTreeRegressor = defaultCopy(extra)
+}
+
+@Since("1.4.0")
+object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor] {
+ /** Accessor for supported impurities: variance */
+ final val supportedImpurities: Array[String] = HasVarianceImpurity.supportedImpurities
+
+ @Since("2.0.0")
+ override def load(path: String): DecisionTreeRegressor = super.load(path)
+}
+
+/**
+ *
+ * Decision tree (Wikipedia) model for regression.
+ * It supports both continuous and categorical features.
+ *
+ * @param rootNode Root of the decision tree
+ */
+@Since("1.4.0")
+class DecisionTreeRegressionModel private[ml] (
+ override val uid: String,
+ override val rootNode: Node,
+ override val numFeatures: Int)
+ extends RegressionModel[Vector, DecisionTreeRegressionModel]
+ with DecisionTreeModel with DecisionTreeRegressorParams with MLWritable with Serializable {
+
+ /** @group setParam */
+ def setVarianceCol(value: String): this.type = set(varianceCol, value)
+
+ require(rootNode != null,
+ "DecisionTreeRegressionModel given null rootNode, but it requires a non-null rootNode.")
+
+ /**
+ * Construct a decision tree regression model.
+ *
+ * @param rootNode Root node of tree, with other nodes attached.
+ */
+ private[ml] def this(rootNode: Node, numFeatures: Int) =
+ this(Identifiable.randomUID("dtr"), rootNode, numFeatures)
+
+ override def predict(features: Vector): Double = {
+ rootNode.predictImpl(features).prediction
+ }
+
+ /** We need to update this function if we ever add other impurity measures. */
+ protected def predictVariance(features: Vector): Double = {
+ rootNode.predictImpl(features).impurityStats.calculate()
+ }
+
+ @Since("1.4.0")
+ override def transformSchema(schema: StructType): StructType = {
+ var outputSchema = super.transformSchema(schema)
+ if (isDefined(varianceCol) && $(varianceCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateNumeric(outputSchema, $(varianceCol))
+ }
+ if ($(leafCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateField(outputSchema, getLeafField($(leafCol)))
+ }
+ outputSchema
+ }
+
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
+ val outputSchema = transformSchema(dataset.schema, logging = true)
+
+ var predictionColNames = Seq.empty[String]
+ var predictionColumns = Seq.empty[Column]
+
+ if ($(predictionCol).nonEmpty) {
+ val predictUDF = udf { features: Vector => predict(features) }
+ predictionColNames :+= $(predictionCol)
+ predictionColumns :+= predictUDF(col($(featuresCol)))
+ .as($(predictionCol), outputSchema($(predictionCol)).metadata)
+ }
+
+ if (isDefined(varianceCol) && $(varianceCol).nonEmpty) {
+ val predictVarianceUDF = udf { features: Vector => predictVariance(features) }
+ predictionColNames :+= $(varianceCol)
+ predictionColumns :+= predictVarianceUDF(col($(featuresCol)))
+ .as($(varianceCol), outputSchema($(varianceCol)).metadata)
+ }
+
+ if ($(leafCol).nonEmpty) {
+ val leafUDF = udf { features: Vector => predictLeaf(features) }
+ predictionColNames :+= $(leafCol)
+ predictionColumns :+= leafUDF(col($(featuresCol)))
+ .as($(leafCol), outputSchema($(leafCol)).metadata)
+ }
+
+ if (predictionColNames.nonEmpty) {
+ dataset.withColumns(predictionColNames, predictionColumns)
+ } else {
+ this.logWarning(s"$uid: DecisionTreeRegressionModel.transform() does nothing" +
+ " because no output columns were set.")
+ dataset.toDF()
+ }
+ }
+
+ @Since("1.4.0")
+ override def copy(extra: ParamMap): DecisionTreeRegressionModel = {
+ copyValues(new DecisionTreeRegressionModel(uid, rootNode, numFeatures), extra).setParent(parent)
+ }
+
+ @Since("1.4.0")
+ override def toString: String = {
+ s"DecisionTreeRegressionModel: uid=$uid, depth=$depth, numNodes=$numNodes, " +
+ s"numFeatures=$numFeatures"
+ }
+
+ /**
+ * Estimate of the importance of each feature.
+ *
+ * This generalizes the idea of "Gini" importance to other losses,
+ * following the explanation of Gini importance from "Random Forests" documentation
+ * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ *
+ * This feature importance is calculated as follows:
+ * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
+ * where gain is scaled by the number of instances passing through node
+ * - Normalize importances for tree to sum to 1.
+ *
+ * @note Feature importance for single decision trees can have high variance due to
+ * correlated predictor variables. Consider using a [[RandomForestRegressor]]
+ * to determine feature importance instead.
+ */
+ @Since("2.0.0")
+ lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(this, numFeatures)
+
+ /** Convert to spark.mllib DecisionTreeModel (losing some information) */
+ override private[spark] def toOld: OldDecisionTreeModel = {
+ new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Regression)
+ }
+
+ @Since("2.0.0")
+ override def write: MLWriter =
+ new DecisionTreeRegressionModel.DecisionTreeRegressionModelWriter(this)
+}
+
+@Since("2.0.0")
+object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[DecisionTreeRegressionModel] =
+ new DecisionTreeRegressionModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): DecisionTreeRegressionModel = super.load(path)
+
+ private[DecisionTreeRegressionModel]
+ class DecisionTreeRegressionModelWriter(instance: DecisionTreeRegressionModel)
+ extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ val extraMetadata: JObject = Map(
+ "numFeatures" -> instance.numFeatures)
+ DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
+ val (nodeData, _) = NodeData.build(instance.rootNode, 0)
+ val dataPath = new Path(path, "data").toString
+ sparkSession.createDataFrame(nodeData).write.parquet(dataPath)
+ }
+ }
+
+ private class DecisionTreeRegressionModelReader
+ extends MLReader[DecisionTreeRegressionModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[DecisionTreeRegressionModel].getName
+
+ override def load(path: String): DecisionTreeRegressionModel = {
+ implicit val format = DefaultFormats
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
+ val root = loadTreeNodes(path, metadata, sparkSession)
+ val model = new DecisionTreeRegressionModel(metadata.uid, root, numFeatures)
+ metadata.getAndSetParams(model)
+ model
+ }
+ }
+
+ /** Convert a model from the old API */
+ private[ml] def fromOld(
+ oldModel: OldDecisionTreeModel,
+ parent: DecisionTreeRegressor,
+ categoricalFeatures: Map[Int, Int],
+ numFeatures: Int = -1): DecisionTreeRegressionModel = {
+ require(oldModel.algo == OldAlgo.Regression,
+ s"Cannot convert non-regression DecisionTreeModel (old API) to" +
+ s" DecisionTreeRegressionModel (new API). Algo is: ${oldModel.algo}")
+ val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
+ val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtr")
+ new DecisionTreeRegressionModel(uid, rootNode, numFeatures)
+ }
+}
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/ml-accelerator/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
new file mode 100644
index 0000000..15725d3
--- /dev/null
+++ b/ml-accelerator/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -0,0 +1,428 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import com.github.fommil.netlib.BLAS.{getInstance => blas}
+import org.json4s.{DefaultFormats, JObject}
+import org.json4s.JsonDSL._
+
+import org.apache.spark.annotation.Since
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.tree._
+import org.apache.spark.ml.tree.impl.GradientBoostedTrees
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
+import org.apache.spark.ml.util.Instrumentation.instrumented
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
+import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
+import org.apache.spark.sql.{Column, DataFrame, Dataset}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Gradient-Boosted Trees (GBTs)
+ * learning algorithm for regression.
+ * It supports both continuous and categorical features.
+ *
+ * The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999.
+ *
+ * Notes on Gradient Boosting vs. TreeBoost:
+ * - This implementation is for Stochastic Gradient Boosting, not for TreeBoost.
+ * - Both algorithms learn tree ensembles by minimizing loss functions.
+ * - TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes
+ * based on the loss function, whereas the original gradient boosting method does not.
+ * - When the loss is SquaredError, these methods give the same result, but they could differ
+ * for other loss functions.
+ * - We expect to implement TreeBoost in the future:
+ * [https://issues.apache.org/jira/browse/SPARK-4240]
+ */
+@Since("1.4.0")
+class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
+ extends Regressor[Vector, GBTRegressor, GBTRegressionModel]
+ with GBTRegressorParams with DefaultParamsWritable with Logging {
+
+ @Since("1.4.0")
+ def this() = this(Identifiable.randomUID("gbtr"))
+
+ // Override parameter setters from parent trait for Java API compatibility.
+
+ // Parameters from TreeRegressorParams:
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setMaxDepth(value: Int): this.type = set(maxDepth, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setMaxBins(value: Int): this.type = set(maxBins, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
+
+ /** @group setParam */
+ @Since("3.0.0")
+ def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
+
+ /** @group expertSetParam */
+ @Since("1.4.0")
+ def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)
+
+ /** @group expertSetParam */
+ @Since("1.4.0")
+ def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)
+
+ /**
+ * Specifies how often to checkpoint the cached node IDs.
+ * E.g. 10 means that the cache will get checkpointed every 10 iterations.
+ * This is only used if cacheNodeIds is true and if the checkpoint directory is set in
+ * [[org.apache.spark.SparkContext]].
+ * Must be at least 1.
+ * (default = 10)
+ * @group setParam
+ */
+ @Since("1.4.0")
+ def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
+
+ /**
+ * The impurity setting is ignored for GBT models.
+ * Individual trees are built using impurity "Variance."
+ *
+ * @group setParam
+ */
+ @Since("1.4.0")
+ def setImpurity(value: String): this.type = {
+ logWarning("GBTRegressor.setImpurity should NOT be used")
+ this
+ }
+
+ // Parameters from TreeEnsembleParams:
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setSeed(value: Long): this.type = set(seed, value)
+
+ // Parameters from GBTParams:
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setMaxIter(value: Int): this.type = set(maxIter, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setStepSize(value: Double): this.type = set(stepSize, value)
+
+ // Parameters from GBTRegressorParams:
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setLossType(value: String): this.type = set(lossType, value)
+
+ /** @group setParam */
+ @Since("2.3.0")
+ def setFeatureSubsetStrategy(value: String): this.type =
+ set(featureSubsetStrategy, value)
+
+ /** @group setParam */
+ @Since("2.4.0")
+ def setValidationIndicatorCol(value: String): this.type = {
+ set(validationIndicatorCol, value)
+ }
+
+ /**
+ * Sets the value of param [[weightCol]].
+ * If this is not set or empty, we treat all instance weights as 1.0.
+ * By default the weightCol is not set, so all instances have weight 1.0.
+ *
+ * @group setParam
+ */
+ @Since("3.0.0")
+ def setWeightCol(value: String): this.type = set(weightCol, value)
+
+ override protected def train(dataset: Dataset[_]): GBTRegressionModel = instrumented { instr =>
+ val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty
+
+ val (trainDataset, validationDataset) = if (withValidation) {
+ (extractInstances(dataset.filter(not(col($(validationIndicatorCol))))),
+ extractInstances(dataset.filter(col($(validationIndicatorCol)))))
+ } else {
+ (extractInstances(dataset), null)
+ }
+
+ instr.logPipelineStage(this)
+ instr.logDataset(dataset)
+ instr.logParams(this, labelCol, featuresCol, predictionCol, leafCol, weightCol, impurity,
+ lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
+ minWeightFractionPerNode, seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval,
+ featureSubsetStrategy, validationIndicatorCol, validationTol)
+
+ val categoricalFeatures = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
+ val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
+ val (doUseAcc, setUseAccFlag) = super.getDoUseAcc
+ val (baseLearners, learnerWeights) = if (withValidation) {
+ if (setUseAccFlag) {
+ GradientBoostedTrees.runWithValidationX(trainDataset, validationDataset, boostingStrategy,
+ $(seed), $(featureSubsetStrategy), doUseAcc, Some(instr))
+ } else {
+ GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy,
+ $(seed), $(featureSubsetStrategy), Some(instr))
+ }
+ } else {
+ if (setUseAccFlag) {
+ GradientBoostedTrees.runX(trainDataset, boostingStrategy,
+ $(seed), $(featureSubsetStrategy), doUseAcc, Some(instr))
+ } else {
+ GradientBoostedTrees.run(trainDataset, boostingStrategy,
+ $(seed), $(featureSubsetStrategy), Some(instr))
+ }
+ }
+ baseLearners.foreach(copyValues(_))
+
+ val numFeatures = baseLearners.head.numFeatures
+ instr.logNumFeatures(numFeatures)
+
+ new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures)
+ }
+
+ @Since("1.4.0")
+ override def copy(extra: ParamMap): GBTRegressor = defaultCopy(extra)
+}
+
+@Since("1.4.0")
+object GBTRegressor extends DefaultParamsReadable[GBTRegressor] {
+
+ /** Accessor for supported loss settings: squared (L2), absolute (L1) */
+ @Since("1.4.0")
+ final val supportedLossTypes: Array[String] = GBTRegressorParams.supportedLossTypes
+
+ @Since("2.0.0")
+ override def load(path: String): GBTRegressor = super.load(path)
+}
+
+/**
+ * Gradient-Boosted Trees (GBTs)
+ * model for regression.
+ * It supports both continuous and categorical features.
+ * @param _trees Decision trees in the ensemble.
+ * @param _treeWeights Weights for the decision trees in the ensemble.
+ */
+@Since("1.4.0")
+class GBTRegressionModel private[ml](
+ override val uid: String,
+ private val _trees: Array[DecisionTreeRegressionModel],
+ private val _treeWeights: Array[Double],
+ override val numFeatures: Int)
+ extends RegressionModel[Vector, GBTRegressionModel]
+ with GBTRegressorParams with TreeEnsembleModel[DecisionTreeRegressionModel]
+ with MLWritable with Serializable {
+
+ require(_trees.nonEmpty, "GBTRegressionModel requires at least 1 tree.")
+ require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" +
+ s" non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")
+
+ /**
+ * Construct a GBTRegressionModel
+ * @param _trees Decision trees in the ensemble.
+ * @param _treeWeights Weights for the decision trees in the ensemble.
+ */
+ @Since("1.4.0")
+ def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) =
+ this(uid, _trees, _treeWeights, -1)
+
+ @Since("1.4.0")
+ override def trees: Array[DecisionTreeRegressionModel] = _trees
+
+ /**
+ * Number of trees in ensemble
+ */
+ @Since("2.0.0")
+ val getNumTrees: Int = trees.length
+
+ @Since("1.4.0")
+ override def treeWeights: Array[Double] = _treeWeights
+
+ @Since("1.4.0")
+ override def transformSchema(schema: StructType): StructType = {
+ var outputSchema = super.transformSchema(schema)
+ if ($(leafCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateField(outputSchema, getLeafField($(leafCol)))
+ }
+ outputSchema
+ }
+
+ override def transform(dataset: Dataset[_]): DataFrame = {
+ val outputSchema = transformSchema(dataset.schema, logging = true)
+
+ var predictionColNames = Seq.empty[String]
+ var predictionColumns = Seq.empty[Column]
+
+ val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)
+
+ if ($(predictionCol).nonEmpty) {
+ val predictUDF = udf { features: Vector => bcastModel.value.predict(features) }
+ predictionColNames :+= $(predictionCol)
+ predictionColumns :+= predictUDF(col($(featuresCol)))
+ .as($(featuresCol), outputSchema($(featuresCol)).metadata)
+ }
+
+ if ($(leafCol).nonEmpty) {
+ val leafUDF = udf { features: Vector => bcastModel.value.predictLeaf(features) }
+ predictionColNames :+= $(leafCol)
+ predictionColumns :+= leafUDF(col($(featuresCol)))
+ .as($(leafCol), outputSchema($(leafCol)).metadata)
+ }
+
+ if (predictionColNames.nonEmpty) {
+ dataset.withColumns(predictionColNames, predictionColumns)
+ } else {
+ this.logWarning(s"$uid: GBTRegressionModel.transform() does nothing" +
+ " because no output columns were set.")
+ dataset.toDF()
+ }
+ }
+
+ override def predict(features: Vector): Double = {
+ // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
+ // Classifies by thresholding sum of weighted tree predictions
+ val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
+ blas.ddot(getNumTrees, treePredictions, 1, _treeWeights, 1)
+ }
+
+ @Since("1.4.0")
+ override def copy(extra: ParamMap): GBTRegressionModel = {
+ copyValues(new GBTRegressionModel(uid, _trees, _treeWeights, numFeatures),
+ extra).setParent(parent)
+ }
+
+ @Since("1.4.0")
+ override def toString: String = {
+ s"GBTRegressionModel: uid=$uid, numTrees=$getNumTrees, numFeatures=$numFeatures"
+ }
+
+ /**
+ * Estimate of the importance of each feature.
+ *
+ * Each feature's importance is the average of its importance across all trees in the ensemble
+ * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+ * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+ * and follows the implementation from scikit-learn.
+ *
+ * @see `DecisionTreeRegressionModel.featureImportances`
+ */
+ @Since("2.0.0")
+ lazy val featureImportances: Vector =
+ TreeEnsembleModel.featureImportances(trees, numFeatures, perTreeNormalization = false)
+
+ /** (private[ml]) Convert to a model in the old API */
+ private[ml] def toOld: OldGBTModel = {
+ new OldGBTModel(OldAlgo.Regression, _trees.map(_.toOld), _treeWeights)
+ }
+
+ /**
+ * Method to compute error or loss for every iteration of gradient boosting.
+ *
+ * @param dataset Dataset for validation.
+ * @param loss The loss function used to compute error. Supported options: squared, absolute
+ */
+ @Since("2.4.0")
+ def evaluateEachIteration(dataset: Dataset[_], loss: String): Array[Double] = {
+ val data = extractInstances(dataset)
+ GradientBoostedTrees.evaluateEachIteration(data, trees, treeWeights,
+ convertToOldLossType(loss), OldAlgo.Regression)
+ }
+
+ @Since("2.0.0")
+ override def write: MLWriter = new GBTRegressionModel.GBTRegressionModelWriter(this)
+}
+
+@Since("2.0.0")
+object GBTRegressionModel extends MLReadable[GBTRegressionModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[GBTRegressionModel] = new GBTRegressionModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): GBTRegressionModel = super.load(path)
+
+ private[GBTRegressionModel]
+ class GBTRegressionModelWriter(instance: GBTRegressionModel) extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ val extraMetadata: JObject = Map(
+ "numFeatures" -> instance.numFeatures,
+ "numTrees" -> instance.getNumTrees)
+ EnsembleModelReadWrite.saveImpl(instance, path, sparkSession, extraMetadata)
+ }
+ }
+
+ private class GBTRegressionModelReader extends MLReader[GBTRegressionModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[GBTRegressionModel].getName
+ private val treeClassName = classOf[DecisionTreeRegressionModel].getName
+
+ override def load(path: String): GBTRegressionModel = {
+ implicit val format = DefaultFormats
+ val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
+ EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
+
+ val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
+ val numTrees = (metadata.metadata \ "numTrees").extract[Int]
+
+ val trees = treesData.map {
+ case (treeMetadata, root) =>
+ val tree = new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
+ treeMetadata.getAndSetParams(tree)
+ tree
+ }
+
+ require(numTrees == trees.length, s"GBTRegressionModel.load expected $numTrees" +
+ s" trees based on metadata but found ${trees.length} trees.")
+
+ val model = new GBTRegressionModel(metadata.uid, trees, treeWeights, numFeatures)
+ metadata.getAndSetParams(model)
+ model
+ }
+ }
+
+ /** Convert a model from the old API */
+ private[ml] def fromOld(
+ oldModel: OldGBTModel,
+ parent: GBTRegressor,
+ categoricalFeatures: Map[Int, Int],
+ numFeatures: Int = -1): GBTRegressionModel = {
+ require(oldModel.algo == OldAlgo.Regression, "Cannot convert GradientBoostedTreesModel" +
+ s" with algo=${oldModel.algo} (old API) to GBTRegressionModel (new API).")
+ val newTrees = oldModel.trees.map { tree =>
+ // parent for each tree is null since there is no good way to set this.
+ DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
+ }
+ val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtr")
+ new GBTRegressionModel(uid, newTrees, oldModel.treeWeights, numFeatures)
+ }
+}
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/ml-accelerator/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
new file mode 100644
index 0000000..bb74c56
--- /dev/null
+++ b/ml-accelerator/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -0,0 +1,359 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import org.json4s.{DefaultFormats, JObject}
+import org.json4s.JsonDSL._
+
+import org.apache.spark.annotation.Since
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.tree._
+import org.apache.spark.ml.tree.impl.RandomForest
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
+import org.apache.spark.ml.util.Instrumentation.instrumented
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
+import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
+import org.apache.spark.sql.{Column, DataFrame, Dataset}
+import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Random Forest
+ * learning algorithm for regression.
+ * It supports both continuous and categorical features.
+ */
+@Since("1.4.0")
+class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
+ extends Regressor[Vector, RandomForestRegressor, RandomForestRegressionModel]
+ with RandomForestRegressorParams with DefaultParamsWritable {
+
+ @Since("1.4.0")
+ def this() = this(Identifiable.randomUID("rfr"))
+
+ // Override parameter setters from parent trait for Java API compatibility.
+
+ // Parameters from TreeRegressorParams:
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setMaxDepth(value: Int): this.type = set(maxDepth, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setMaxBins(value: Int): this.type = set(maxBins, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
+
+ /** @group setParam */
+ @Since("3.0.0")
+ def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
+
+ /** @group expertSetParam */
+ @Since("1.4.0")
+ def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)
+
+ /** @group expertSetParam */
+ @Since("1.4.0")
+ def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)
+
+ /**
+ * Specifies how often to checkpoint the cached node IDs.
+ * E.g. 10 means that the cache will get checkpointed every 10 iterations.
+ * This is only used if cacheNodeIds is true and if the checkpoint directory is set in
+ * [[org.apache.spark.SparkContext]].
+ * Must be at least 1.
+ * (default = 10)
+ * @group setParam
+ */
+ @Since("1.4.0")
+ def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setImpurity(value: String): this.type = set(impurity, value)
+
+ // Parameters from TreeEnsembleParams:
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setSeed(value: Long): this.type = set(seed, value)
+
+ // Parameters from RandomForestParams:
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setNumTrees(value: Int): this.type = set(numTrees, value)
+
+ /** @group setParam */
+ @Since("3.0.0")
+ def setBootstrap(value: Boolean): this.type = set(bootstrap, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setFeatureSubsetStrategy(value: String): this.type =
+ set(featureSubsetStrategy, value)
+
+ /**
+ * Sets the value of param [[weightCol]].
+ * If this is not set or empty, we treat all instance weights as 1.0.
+ * By default the weightCol is not set, so all instances have weight 1.0.
+ *
+ * @group setParam
+ */
+ @Since("3.0.0")
+ def setWeightCol(value: String): this.type = set(weightCol, value)
+
+ override protected def train(
+ dataset: Dataset[_]): RandomForestRegressionModel = instrumented { instr =>
+ val categoricalFeatures: Map[Int, Int] =
+ MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
+
+ val instances = extractInstances(dataset)
+ val strategy =
+ super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)
+ strategy.bootstrap = $(bootstrap)
+
+ instr.logPipelineStage(this)
+ instr.logDataset(instances)
+ instr.logParams(this, labelCol, featuresCol, weightCol, predictionCol, leafCol, impurity,
+ numTrees, featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain,
+ minInstancesPerNode, minWeightFractionPerNode, seed, subsamplingRate, cacheNodeIds,
+ checkpointInterval, bootstrap)
+
+ val trees = RandomForest
+ .run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
+ .map(_.asInstanceOf[DecisionTreeRegressionModel])
+ trees.foreach(copyValues(_))
+
+ val numFeatures = trees.head.numFeatures
+ instr.logNamedValue(Instrumentation.loggerTags.numFeatures, numFeatures)
+ new RandomForestRegressionModel(uid, trees, numFeatures)
+ }
+
+ @Since("1.4.0")
+ override def copy(extra: ParamMap): RandomForestRegressor = defaultCopy(extra)
+}
+
+@Since("1.4.0")
+object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor]{
+ /** Accessor for supported impurity settings: variance */
+ @Since("1.4.0")
+ final val supportedImpurities: Array[String] = HasVarianceImpurity.supportedImpurities
+
+ /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */
+ @Since("1.4.0")
+ final val supportedFeatureSubsetStrategies: Array[String] =
+ TreeEnsembleParams.supportedFeatureSubsetStrategies
+
+ @Since("2.0.0")
+ override def load(path: String): RandomForestRegressor = super.load(path)
+
+}
+
+/**
+ * Random Forest model for regression.
+ * It supports both continuous and categorical features.
+ *
+ * @param _trees Decision trees in the ensemble.
+ * @param numFeatures Number of features used by this model
+ */
+@Since("1.4.0")
+class RandomForestRegressionModel private[ml] (
+ override val uid: String,
+ private val _trees: Array[DecisionTreeRegressionModel],
+ override val numFeatures: Int)
+ extends RegressionModel[Vector, RandomForestRegressionModel]
+ with RandomForestRegressorParams with TreeEnsembleModel[DecisionTreeRegressionModel]
+ with MLWritable with Serializable {
+
+ require(_trees.nonEmpty, "RandomForestRegressionModel requires at least 1 tree.")
+
+ /**
+ * Construct a random forest regression model, with all trees weighted equally.
+ *
+ * @param trees Component trees
+ */
+ private[ml] def this(trees: Array[DecisionTreeRegressionModel], numFeatures: Int) =
+ this(Identifiable.randomUID("rfr"), trees, numFeatures)
+
+ @Since("1.4.0")
+ override def trees: Array[DecisionTreeRegressionModel] = _trees
+
+ // Note: We may add support for weights (based on tree performance) later on.
+ private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0)
+
+ @Since("1.4.0")
+ override def treeWeights: Array[Double] = _treeWeights
+
+ @Since("1.4.0")
+ override def transformSchema(schema: StructType): StructType = {
+ var outputSchema = super.transformSchema(schema)
+ if ($(leafCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateField(outputSchema, getLeafField($(leafCol)))
+ }
+ outputSchema
+ }
+
+ override def transform(dataset: Dataset[_]): DataFrame = {
+ val outputSchema = transformSchema(dataset.schema, logging = true)
+
+ var predictionColNames = Seq.empty[String]
+ var predictionColumns = Seq.empty[Column]
+
+ val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)
+
+ if ($(predictionCol).nonEmpty) {
+ val predictUDF = udf { features: Vector => bcastModel.value.predict(features) }
+ predictionColNames :+= $(predictionCol)
+ predictionColumns :+= predictUDF(col($(featuresCol)))
+ .as($(predictionCol), outputSchema($(predictionCol)).metadata)
+ }
+
+ if ($(leafCol).nonEmpty) {
+ val leafUDF = udf { features: Vector => bcastModel.value.predictLeaf(features) }
+ predictionColNames :+= $(leafCol)
+ predictionColumns :+= leafUDF(col($(featuresCol)))
+ .as($(leafCol), outputSchema($(leafCol)).metadata)
+ }
+
+ if (predictionColNames.nonEmpty) {
+ dataset.withColumns(predictionColNames, predictionColumns)
+ } else {
+ this.logWarning(s"$uid: RandomForestRegressionModel.transform() does nothing" +
+ " because no output columns were set.")
+ dataset.toDF()
+ }
+ }
+
+ override def predict(features: Vector): Double = {
+ // TODO: When we add a generic Bagging class, handle transform there. SPARK-7128
+ // Predict average of tree predictions.
+ // Ignore the weights since all are 1.0 for now.
+ _trees.map(_.rootNode.predictImpl(features).prediction).sum / getNumTrees
+ }
+
+ @Since("1.4.0")
+ override def copy(extra: ParamMap): RandomForestRegressionModel = {
+ copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra).setParent(parent)
+ }
+
+ @Since("1.4.0")
+ override def toString: String = {
+ s"RandomForestRegressionModel: uid=$uid, numTrees=$getNumTrees, numFeatures=$numFeatures"
+ }
+
+ /**
+ * Estimate of the importance of each feature.
+ *
+ * Each feature's importance is the average of its importance across all trees in the ensemble
+ * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+ * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+ * and follows the implementation from scikit-learn.
+ *
+ * @see `DecisionTreeRegressionModel.featureImportances`
+ */
+ @Since("1.5.0")
+ lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)
+
+ /** (private[ml]) Convert to a model in the old API */
+ private[ml] def toOld: OldRandomForestModel = {
+ new OldRandomForestModel(OldAlgo.Regression, _trees.map(_.toOld))
+ }
+
+ @Since("2.0.0")
+ override def write: MLWriter =
+ new RandomForestRegressionModel.RandomForestRegressionModelWriter(this)
+}
+
+@Since("2.0.0")
+object RandomForestRegressionModel extends MLReadable[RandomForestRegressionModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[RandomForestRegressionModel] = new RandomForestRegressionModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): RandomForestRegressionModel = super.load(path)
+
+ private[RandomForestRegressionModel]
+ class RandomForestRegressionModelWriter(instance: RandomForestRegressionModel)
+ extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ val extraMetadata: JObject = Map(
+ "numFeatures" -> instance.numFeatures,
+ "numTrees" -> instance.getNumTrees)
+ EnsembleModelReadWrite.saveImpl(instance, path, sparkSession, extraMetadata)
+ }
+ }
+
+ private class RandomForestRegressionModelReader extends MLReader[RandomForestRegressionModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[RandomForestRegressionModel].getName
+ private val treeClassName = classOf[DecisionTreeRegressionModel].getName
+
+ override def load(path: String): RandomForestRegressionModel = {
+ implicit val format = DefaultFormats
+ val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
+ EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
+ val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
+ val numTrees = (metadata.metadata \ "numTrees").extract[Int]
+
+ val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) =>
+ val tree =
+ new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
+ treeMetadata.getAndSetParams(tree)
+ tree
+ }
+ require(numTrees == trees.length, s"RandomForestRegressionModel.load expected $numTrees" +
+ s" trees based on metadata but found ${trees.length} trees.")
+
+ val model = new RandomForestRegressionModel(metadata.uid, trees, numFeatures)
+ metadata.getAndSetParams(model)
+ model
+ }
+ }
+
+ /** Convert a model from the old API */
+ private[ml] def fromOld(
+ oldModel: OldRandomForestModel,
+ parent: RandomForestRegressor,
+ categoricalFeatures: Map[Int, Int],
+ numFeatures: Int = -1): RandomForestRegressionModel = {
+ require(oldModel.algo == OldAlgo.Regression, "Cannot convert RandomForestModel" +
+ s" with algo=${oldModel.algo} (old API) to RandomForestRegressionModel (new API).")
+ val newTrees = oldModel.trees.map { tree =>
+ // parent for each tree is null since there is no good way to set this.
+ DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
+ }
+ val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfr")
+ new RandomForestRegressionModel(uid, newTrees, numFeatures)
+ }
+}
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/impl/DecisionForest.scala b/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/impl/DecisionForest.scala
new file mode 100644
index 0000000..0496c6c
--- /dev/null
+++ b/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/impl/DecisionForest.scala
@@ -0,0 +1,1360 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import scala.collection.mutable
+import scala.util.Random
+
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.classification.DecisionTreeClassificationModel
+import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.impl.Utils
+import org.apache.spark.ml.regression.DecisionTreeRegressionModel
+import org.apache.spark.ml.tree._
+import org.apache.spark.ml.tree.impl.RandomForest.NodeIndexInfo
+import org.apache.spark.ml.util.Instrumentation
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
+import org.apache.spark.mllib.tree.model.ImpurityStats
+import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.util.PeriodicRDDCheckpointer
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.collection.OpenHashMap
+import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
+
+
+/**
+ * ALGORITHM
+ *
+ * This is a sketch of the algorithm to help new developers.
+ *
+ * The algorithm partitions data by instances (rows).
+ * On each iteration, the algorithm splits a set of nodes. In order to choose the best split
+ * for a given node, sufficient statistics are collected from the distributed data.
+ * For each node, the statistics are collected to some worker node, and that worker selects
+ * the best split.
+ *
+ * This setup requires discretization of continuous features. This binning is done in the
+ * findSplits() method during initialization, after which each continuous feature becomes
+ * an ordered discretized feature with at most maxBins possible values.
+ *
+ * The main loop in the algorithm operates on a queue of nodes (nodeStack). These nodes
+ * lie at the periphery of the tree being trained. If multiple trees are being trained at once,
+ * then this queue contains nodes from all of them. Each iteration works roughly as follows:
+ * On the master node:
+ * - Some number of nodes are pulled off of the queue (based on the amount of memory
+ * required for their sufficient statistics).
+ * - For random forests, if featureSubsetStrategy is not "all," then a subset of candidate
+ * features are chosen for each node. See method selectNodesToSplit().
+ * On worker nodes, via method findBestSplits():
+ * - The worker makes one pass over its subset of instances.
+ * - For each (tree, node, feature, split) tuple, the worker collects statistics about
+ * splitting. Note that the set of (tree, node) pairs is limited to the nodes selected
+ * from the queue for this iteration. The set of features considered can also be limited
+ * based on featureSubsetStrategy.
+ * - For each node, the statistics for that node are aggregated to a particular worker
+ * via reduceByKey(). The designated worker chooses the best (feature, split) pair,
+ * or chooses to stop splitting if the stopping criteria are met.
+ * On the master node:
+ * - The master collects all decisions about splitting nodes and updates the model.
+ * - The updated model is passed to the workers on the next iteration.
+ * This process continues until the node queue is empty.
+ *
+ * Most of the methods in this implementation support the statistics aggregation, which is
+ * the heaviest part of the computation. In general, this implementation is bound by either
+ * the cost of statistics computation on workers or by communicating the sufficient statistics.
+ */
+private[spark] object DecisionForest extends Logging with Serializable {
+
+ /**
+ * Train a random forest.
+ *
+ * @param input Training data: RDD of `LabeledPoint`
+ * @return an unweighted set of trees
+ */
+ def run(
+ input: RDD[LabeledPoint],
+ strategy: OldStrategy,
+ numTrees: Int,
+ featureSubsetStrategy: String,
+ seed: Long): Array[DecisionTreeModel] = {
+ val instances = input.map { case LabeledPoint(label, features) =>
+ Instance(label, 1.0, features.asML)
+ }
+ run(instances, strategy, numTrees, featureSubsetStrategy, seed, None)
+ }
+
+ // scalastyle:off
+ /**
+ * Train a random forest with metadata and splits. This method is mainly for GBT,
+ * in which bagged input can be reused among trees.
+ *
+ * @param baggedInput bagged training data: RDD of `BaggedPoint`
+ * @param metadata Learning and dataset metadata for DecisionTree.
+ * @return an unweighted set of trees
+ */
+ def runBagged(
+ baggedInput: RDD[BaggedPoint[TreePointY]],
+ metadata: DecisionTreeMetadata,
+ bcSplits: Broadcast[Array[Array[Split]]],
+ strategy: OldStrategy,
+ numTrees: Int,
+ featureSubsetStrategy: String,
+ seed: Long,
+ instr: Option[Instrumentation],
+ prune: Boolean = true, // exposed for testing only, real trees are always pruned
+ parentUID: Option[String] = None,
+ extraParams: Option[DFExtraParams] = None): Array[DecisionTreeModel] = {
+ // scalastyle:on
+ val timer = new TimeTracker()
+ timer.start("total")
+
+ val sc = baggedInput.sparkContext
+
+ instr match {
+ case Some(instrumentation) =>
+ instrumentation.logNumFeatures(metadata.numFeatures)
+ instrumentation.logNumClasses(metadata.numClasses)
+ instrumentation.logNumExamples(metadata.numExamples)
+ instrumentation.logSumOfWeights(metadata.weightedNumExamples)
+ case None =>
+ logInfo(s"numFeatures: ${metadata.numFeatures}")
+ logInfo(s"numClasses: ${metadata.numClasses}")
+ logInfo(s"numExamples: ${metadata.numExamples}")
+ logInfo(s"weightedNumExamples: ${metadata.weightedNumExamples}")
+ }
+
+ timer.start("init")
+
+ // depth of the decision tree
+ val maxDepth = strategy.maxDepth
+ require(maxDepth <= 30,
+ s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.")
+
+ // Max memory usage for aggregates
+ // TODO: Calculate memory usage more precisely.
+ val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L
+ logDebug(s"max memory usage for aggregates = $maxMemoryUsage bytes.")
+
+ /*
+ * The main idea here is to perform group-wise training of the decision tree nodes thus
+ * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup).
+ * Each data sample is handled by a particular node (or it reaches a leaf and is not used
+ * in lower levels).
+ */
+
+ var nodeIds: RDD[Array[Int]] = null
+ var nodeIdCheckpointer: PeriodicRDDCheckpointer[Array[Int]] = null
+ if (strategy.useNodeIdCache) {
+ // Create an RDD of node Id cache.
+ // At first, all the rows belong to the root nodes (node Id == 1).
+ nodeIds = baggedInput.map { _ => Array.fill(numTrees)(1) }
+ nodeIdCheckpointer = new PeriodicRDDCheckpointer[Array[Int]](
+ strategy.getCheckpointInterval, sc, StorageLevel.MEMORY_AND_DISK)
+ nodeIdCheckpointer.update(nodeIds)
+ }
+
+ /*
+ Stack of nodes to train: (treeIndex, node)
+ The reason this is a stack is that we train many trees at once, but we want to focus on
+ completing trees, rather than training all simultaneously. If we are splitting nodes from
+ 1 tree, then the new nodes to split will be put at the top of this stack, so we will continue
+ training the same tree in the next iteration. This focus allows us to send fewer trees to
+ workers on each iteration; see topNodesForGroup below.
+ */
+ val nodeStack = new mutable.ListBuffer[(Int, LearningNode)]
+
+ val rng = new Random()
+ rng.setSeed(seed)
+
+ // Allocate and queue root nodes.
+ val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1))
+ for (treeIndex <- 0 until numTrees) {
+ nodeStack.prepend((treeIndex, topNodes(treeIndex)))
+ }
+
+ timer.stop("init")
+
+ while (nodeStack.nonEmpty) {
+ // Collect some nodes to split, and choose features for each node (if subsampling).
+ // Each group of nodes may come from one or multiple trees, and at multiple levels.
+ val (nodesForGroup, treeToNodeToIndexInfo) =
+ DecisionForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng)
+ // Sanity check (should never occur):
+ assert(nodesForGroup.nonEmpty,
+ s"DecisionForest selected empty nodesForGroup. Error for unknown reason.")
+
+ // Only send trees to worker if they contain nodes being split this iteration.
+ val topNodesForGroup: Map[Int, LearningNode] =
+ nodesForGroup.keys.map(treeIdx => treeIdx -> topNodes(treeIdx)).toMap
+
+ // Choose node splits, and enqueue new nodes as needed.
+ timer.start("findBestSplits")
+ val bestSplit = DecisionForest.findBestSplits(baggedInput, metadata, topNodesForGroup,
+ nodesForGroup, treeToNodeToIndexInfo, bcSplits, nodeStack, timer, nodeIds,
+ outputBestSplits = strategy.useNodeIdCache, extraParams)
+ if (strategy.useNodeIdCache) {
+ nodeIds = updateNodeIds(baggedInput, nodeIds, bcSplits, bestSplit)
+ nodeIdCheckpointer.update(nodeIds)
+ }
+
+ timer.stop("findBestSplits")
+ }
+
+ timer.stop("total")
+
+ logInfo("Internal timing for DecisionTree:")
+ logInfo(s"$timer")
+
+ if (strategy.useNodeIdCache) {
+ // Delete any remaining checkpoints used for node Id cache.
+ nodeIdCheckpointer.unpersistDataSet()
+ nodeIdCheckpointer.deleteAllCheckpoints()
+ }
+
+ val numFeatures = metadata.numFeatures
+
+ parentUID match {
+ case Some(uid) =>
+ if (strategy.algo == OldAlgo.Classification) {
+ topNodes.map { rootNode =>
+ new DecisionTreeClassificationModel(uid, rootNode.toNode(prune), numFeatures,
+ strategy.getNumClasses)
+ }
+ } else {
+ topNodes.map { rootNode =>
+ new DecisionTreeRegressionModel(uid, rootNode.toNode(prune), numFeatures)
+ }
+ }
+ case None =>
+ if (strategy.algo == OldAlgo.Classification) {
+ topNodes.map { rootNode =>
+ new DecisionTreeClassificationModel(rootNode.toNode(prune), numFeatures,
+ strategy.getNumClasses)
+ }
+ } else {
+ topNodes.map(rootNode =>
+ new DecisionTreeRegressionModel(rootNode.toNode(prune), numFeatures))
+ }
+ }
+ }
+
+ /**
+ * Train a random forest.
+ *
+ * @param input Training data: RDD of `Instance`
+ * @return an unweighted set of trees
+ */
+ def run(
+ input: RDD[Instance],
+ strategy: OldStrategy,
+ numTrees: Int,
+ featureSubsetStrategy: String,
+ seed: Long,
+ instr: Option[Instrumentation],
+ prune: Boolean = true, // exposed for testing only, real trees are always pruned
+ parentUID: Option[String] = None,
+ exParams: Option[DFExtraParams] = None): Array[DecisionTreeModel] = {
+ val extraParams = if (exParams.isEmpty) {
+ DTUtils.parseExtraParams(input, strategy)
+ } else {
+ exParams.get
+ }
+ val timer = new TimeTracker()
+
+ timer.start("build metadata")
+ val metadata = DecisionTreeMetadata
+ .buildMetadata(input.retag(classOf[Instance]), strategy, numTrees, featureSubsetStrategy)
+ timer.stop("build metadata")
+
+ val retaggedInput = input.retag(classOf[Instance])
+
+ // Find the splits and the corresponding bins (interval between the splits) using a sample
+ // of the input data.
+ timer.start("findSplits")
+ val splits = findSplits(retaggedInput, metadata, seed)
+ timer.stop("findSplits")
+ logDebug("numBins: feature: number of bins")
+ logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
+ s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
+ }.mkString("\n"))
+
+ // Bin feature values (TreePoint representation).
+ // Cache input RDD for speedup during multiple passes.
+ val treeInput = TreePointY.convertToTreeRDD(retaggedInput, splits, metadata)
+
+ val bcSplits = input.sparkContext.broadcast(splits)
+ val baggedInputOri = BaggedPoint
+ .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, strategy.bootstrap,
+ (tp: TreePointY) => tp.weight, seed = seed)
+ .setName("bagged tree points")
+ val baggedInput = DTUtils.transformBaggedRDD(baggedInputOri, extraParams)
+
+ val trees = runBagged(baggedInput = baggedInput, metadata = metadata, bcSplits = bcSplits,
+ strategy = strategy, numTrees = numTrees, featureSubsetStrategy = featureSubsetStrategy,
+ seed = seed, instr = instr, prune = prune, parentUID = parentUID,
+ extraParams = Some(extraParams))
+
+ baggedInput.unpersist()
+ bcSplits.destroy()
+
+ trees
+ }
+
+ /**
+ * Update node indices by newly found splits.
+ */
+ private def updateNodeIds(
+ input: RDD[BaggedPoint[TreePointY]],
+ nodeIds: RDD[Array[Int]],
+ bcSplits: Broadcast[Array[Array[Split]]],
+ bestSplits: Array[Map[Int, Split]]): RDD[Array[Int]] = {
+ require(nodeIds != null && bestSplits != null)
+ input.zip(nodeIds).map { case (point, ids) =>
+ var treeId = 0
+ while (treeId < bestSplits.length) {
+ val bestSplitsInTree = bestSplits(treeId)
+ if (bestSplitsInTree != null) {
+ val nodeId = ids(treeId)
+ bestSplitsInTree.get(nodeId).foreach { bestSplit =>
+ val featureId = bestSplit.featureIndex
+ val bin = point.datum.binnedFeatures(featureId)
+ val newNodeId = if (bestSplit.shouldGoLeft(bin, bcSplits.value(featureId))) {
+ LearningNode.leftChildIndex(nodeId)
+ } else {
+ LearningNode.rightChildIndex(nodeId)
+ }
+ ids(treeId) = newNodeId
+ }
+ }
+ treeId += 1
+ }
+ ids
+ }
+ }
+
+ /**
+ * Helper for binSeqOp, for data which can contain a mix of ordered and unordered features.
+ *
+ * For ordered features, a single bin is updated.
+ * For unordered features, bins correspond to subsets of categories; either the left or right bin
+ * for each subset is updated.
+ *
+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
+ * each (feature, bin).
+ * @param treePoint Data point being aggregated.
+ * @param splits Possible splits indexed (numFeatures)(numSplits)
+ * @param unorderedFeatures Set of indices of unordered features.
+ * @param numSamples Number of times this instance occurs in the sample.
+ * @param sampleWeight Weight (importance) of instance in dataset.
+ */
+ private def mixedBinSeqOp(
+ agg: DTStatsAggregator,
+ treePoint: TreePointY,
+ splits: Array[Array[Split]],
+ unorderedFeatures: Set[Int],
+ numSamples: Int,
+ sampleWeight: Double,
+ featuresForNode: Option[Array[Int]]): Unit = {
+ val numFeaturesPerNode = if (featuresForNode.nonEmpty) {
+ // Use subsampled features
+ featuresForNode.get.length
+ } else {
+ // Use all features
+ agg.metadata.numFeatures
+ }
+ // Iterate over features.
+ var featureIndexIdx = 0
+ while (featureIndexIdx < numFeaturesPerNode) {
+ val featureIndex = if (featuresForNode.nonEmpty) {
+ featuresForNode.get.apply(featureIndexIdx)
+ } else {
+ featureIndexIdx
+ }
+ if (unorderedFeatures.contains(featureIndex)) {
+ // Unordered feature
+ val featureValue = treePoint.binnedFeatures(featureIndex)
+ val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx)
+ // Update the left or right bin for each split.
+ val numSplits = agg.metadata.numSplits(featureIndex)
+ val featureSplits = splits(featureIndex)
+ var splitIndex = 0
+ while (splitIndex < numSplits) {
+ if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) {
+ agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, numSamples,
+ sampleWeight)
+ }
+ splitIndex += 1
+ }
+ } else {
+ // Ordered feature
+ val binIndex = treePoint.binnedFeatures(featureIndex)
+ agg.update(featureIndexIdx, binIndex, treePoint.label, numSamples, sampleWeight)
+ }
+ featureIndexIdx += 1
+ }
+ }
+
+ /**
+ * Helper for binSeqOp, for regression and for classification with only ordered features.
+ *
+ * For each feature, the sufficient statistics of one bin are updated.
+ *
+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
+ * each (feature, bin).
+ * @param treePoint Data point being aggregated.
+ * @param numSamples Number of times this instance occurs in the sample.
+ * @param sampleWeight Weight (importance) of instance in dataset.
+ */
+ private def orderedBinSeqOp(
+ agg: DTStatsAggregator,
+ treePoint: TreePointY,
+ numSamples: Int,
+ sampleWeight: Double,
+ featuresForNode: Option[Array[Int]]): Unit = {
+ val label = treePoint.label
+
+ // Iterate over features.
+ if (featuresForNode.nonEmpty) {
+ // Use subsampled features
+ var featureIndexIdx = 0
+ while (featureIndexIdx < featuresForNode.get.length) {
+ val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx))
+ agg.update(featureIndexIdx, binIndex, label, numSamples, sampleWeight)
+ featureIndexIdx += 1
+ }
+ } else {
+ // Use all features
+ val numFeatures = agg.metadata.numFeatures
+ var featureIndex = 0
+ while (featureIndex < numFeatures) {
+ val binIndex = treePoint.binnedFeatures(featureIndex)
+ agg.update(featureIndex, binIndex, label, numSamples, sampleWeight)
+ featureIndex += 1
+ }
+ }
+ }
+
+ // scalastyle:off
+ /**
+ * Given a group of nodes, this finds the best split for each node.
+ *
+ * @param input Training data: RDD of [[TreePoint]]
+ * @param metadata Learning and dataset metadata
+ * @param topNodesForGroup For each tree in group, tree index -> root node.
+ * Used for matching instances with nodes.
+ * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree
+ * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo,
+ * where nodeIndexInfo stores the index in the group and the
+ * feature subsets (if using feature subsets).
+ * @param bcSplits possible splits for all features, indexed (numFeatures)(numSplits)
+ * @param nodeStack Queue of nodes to split, with values (treeIndex, node).
+ * Updated with new non-leaf nodes which are created.
+ * @param nodeIds an RDD of Array[Int] where each value in the array is the data
+ * point's node Id for a corresponding tree. This is used to prevent
+ * the need to pass the entire tree to the executors during the node
+ * stat aggregation phase.
+ */
+ private[tree] def findBestSplits(
+ input: RDD[BaggedPoint[TreePointY]],
+ metadata: DecisionTreeMetadata,
+ topNodesForGroup: Map[Int, LearningNode],
+ nodesForGroup: Map[Int, Array[LearningNode]],
+ treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]],
+ bcSplits: Broadcast[Array[Array[Split]]],
+ nodeStack: mutable.ListBuffer[(Int, LearningNode)],
+ timer: TimeTracker = new TimeTracker,
+ nodeIds: RDD[Array[Int]] = null,
+ outputBestSplits: Boolean = false,
+ extraParams: Option[DFExtraParams] = None): Array[Map[Int, Split]] = {
+ // scalastyle:on
+
+ /*
+ * The high-level descriptions of the best split optimizations are noted here.
+ *
+ * *Group-wise training*
+ * We perform bin calculations for groups of nodes to reduce the number of
+ * passes over the data. Each iteration requires more computation and storage,
+ * but saves several iterations over the data.
+ *
+ * *Bin-wise computation*
+ * We use a bin-wise best split computation strategy instead of a straightforward best split
+ * computation strategy. Instead of analyzing each sample for contribution to the left/right
+ * child node impurity of every split, we first categorize each feature of a sample into a
+ * bin. We exploit this structure to calculate aggregates for bins and then use these aggregates
+ * to calculate information gain for each split.
+ *
+ * *Aggregation over partitions*
+ * Instead of performing a flatMap/reduceByKey operation, we exploit the fact that we know
+ * the number of splits in advance. Thus, we store the aggregates (at the appropriate
+ * indices) in a single array for all bins and rely upon the RDD aggregate method to
+ * drastically reduce the communication overhead.
+ */
+
+ val useNodeIdCache = nodeIds != null
+
+ // numNodes: Number of nodes in this group
+ val numNodes = nodesForGroup.values.map(_.length).sum
+ logDebug(s"numNodes = $numNodes")
+ logDebug(s"numFeatures = ${metadata.numFeatures}")
+ logDebug(s"numClasses = ${metadata.numClasses}")
+ logDebug(s"isMulticlass = ${metadata.isMulticlass}")
+ logDebug(s"isMulticlassWithCategoricalFeatures = " +
+ s"${metadata.isMulticlassWithCategoricalFeatures}")
+ logDebug(s"using nodeIdCache = $useNodeIdCache")
+
+ val groupInfo = DTUtils.getGroupInfo(numNodes, treeToNodeToIndexInfo,
+ extraParams.getOrElse(null), nodesForGroup)
+
+ /*
+ * Performs a sequential aggregation over a partition for a particular tree and node.
+ *
+ * For each feature, the aggregate sufficient statistics are updated for the relevant
+ * bins.
+ *
+ * @param treeIndex Index of the tree that we want to perform aggregation for.
+ * @param nodeInfo The node info for the tree node.
+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics
+ * for each (node, feature, bin).
+ * @param baggedPoint Data point being aggregated.
+ */
+ def nodeBinSeqOp(
+ treeIndex: Int,
+ nodeInfo: NodeIndexInfo,
+ agg: Array[DTStatsAggregator],
+ baggedPoint: BaggedPoint[TreePointY],
+ splits: Array[Array[Split]],
+ sampleId: Short = 0): Unit = {
+ if (DTUtils.isValidNodeInfo(nodeInfo, agg, groupInfo, baggedPoint, sampleId)) {
+ val aggNodeIndex = nodeInfo.nodeIndexInGroup
+ val featuresForNode = nodeInfo.featureSubset
+ val numSamples = baggedPoint.subsampleCounts(treeIndex)
+ val sampleWeight = baggedPoint.sampleWeight
+ if (metadata.unorderedFeatures.isEmpty) {
+ orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, numSamples, sampleWeight,
+ featuresForNode)
+ } else {
+ mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
+ metadata.unorderedFeatures, numSamples, sampleWeight, featuresForNode)
+ }
+ agg(aggNodeIndex).updateParent(baggedPoint.datum.label, numSamples, sampleWeight)
+ }
+ }
+
+ /*
+ * Performs a sequential aggregation over a partition.
+ *
+ * Each data point contributes to one node. For each feature,
+ * the aggregate sufficient statistics are updated for the relevant bins.
+ *
+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
+ * each (node, feature, bin).
+ * @param baggedPoint Data point being aggregated.
+ * @return agg
+ */
+ def binSeqOp(
+ agg: Array[DTStatsAggregator],
+ baggedPoint: BaggedPoint[TreePointY],
+ splits: Array[Array[Split]],
+ sampleId: Short): Array[DTStatsAggregator] = {
+ treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
+ if (DTUtils.isSubSampled(baggedPoint, groupInfo, treeIndex, sampleId)) {
+ val nodeIndex =
+ topNodesForGroup(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits)
+ nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null),
+ agg, baggedPoint, splits, sampleId)
+ }
+ }
+ agg
+ }
+
+ /**
+ * Do the same thing as binSeqOp, but with nodeIdCache.
+ */
+ def binSeqOpWithNodeIdCache(
+ agg: Array[DTStatsAggregator],
+ dataPoint: (BaggedPoint[TreePointY], Array[Int]),
+ splits: Array[Array[Split]]): Array[DTStatsAggregator] = {
+ treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
+ val baggedPoint = dataPoint._1
+ val nodeIdCache = dataPoint._2
+ val nodeIndex = nodeIdCache(treeIndex)
+ nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null),
+ agg, baggedPoint, splits)
+ }
+ agg
+ }
+
+ /**
+ * Get node index in group --> features indices map,
+ * which is a short cut to find feature indices for a node given node index in group.
+ */
+ def getNodeToFeatures(
+ treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]): Option[Map[Int, Array[Int]]] = {
+ if (!metadata.subsamplingFeatures) {
+ None
+ } else {
+ val mutableNodeToFeatures = new mutable.HashMap[Int, Array[Int]]()
+ treeToNodeToIndexInfo.values.foreach { nodeIdToNodeInfo =>
+ nodeIdToNodeInfo.values.foreach { nodeIndexInfo =>
+ assert(nodeIndexInfo.featureSubset.isDefined)
+ mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) = nodeIndexInfo.featureSubset.get
+ }
+ }
+ Some(mutableNodeToFeatures.toMap)
+ }
+ }
+
+ // array of nodes to train indexed by node index in group
+ val nodes = new Array[LearningNode](numNodes)
+ nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
+ nodesForTree.foreach { node =>
+ nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node
+ }
+ }
+
+ // Calculate best splits for all nodes in the group
+ timer.start("chooseSplits")
+
+ // In each partition, iterate all instances and compute aggregate stats for each node,
+ // yield a (nodeIndex, nodeAggregateStats) pair for each node.
+ // After a `reduceByKey` operation,
+ // stats of a node will be shuffled to a particular partition and be combined together,
+ // then best splits for nodes are found there.
+ // Finally, only best Splits for nodes are collected to driver to construct decision tree.
+ val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
+ val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)
+
+ val partitionAggregates = if (useNodeIdCache) {
+
+ input.zip(nodeIds).mapPartitions { points =>
+ // Construct a nodeStatsAggregators array to hold node aggregate stats,
+ // each node will have a nodeStatsAggregator
+ val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
+ val featuresForNode = nodeToFeaturesBc.value.map { nodeToFeatures =>
+ nodeToFeatures(nodeIndex)
+ }
+ new DTStatsAggregator(metadata, featuresForNode)
+ }
+
+ // iterator all instances in current partition and update aggregate stats
+ points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _, bcSplits.value))
+
+ // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
+ // which can be combined with other partition using `reduceByKey`
+ nodeStatsAggregators.iterator.zipWithIndex.map(_.swap)
+ }
+ } else {
+ input.mapPartitions { points =>
+ val (firstPointOption, nodeStatsAggregators) =
+ DTUtils.initNodeStatsAgg(numNodes, nodeToFeaturesBc, metadata, points, groupInfo)
+ if (firstPointOption.isEmpty) {
+ Iterator.empty
+ } else {
+ val splits = bcSplits.value
+ val firstPoint = firstPointOption.get
+ val sampleId = firstPoint.sampleId
+ binSeqOp(nodeStatsAggregators, firstPoint, splits, sampleId)
+
+ // iterator all instances in current partition and update aggregate stats
+ points.foreach(binSeqOp(nodeStatsAggregators, _, splits, sampleId))
+
+ // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
+ // which can be combined with other partition using `reduceByKey`
+ nodeStatsAggregators.iterator.zipWithIndex
+ .filter(v => RFUtils.isValidAgg(v._1)).map(_.swap)
+ }
+ }
+ }
+
+ val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b)).map {
+ case (nodeIndex, aggStats) =>
+ val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
+ Some(nodeToFeatures(nodeIndex))
+ }
+
+ // find best split for each node
+ val (split: Split, stats: ImpurityStats) =
+ binsToBestSplit(aggStats, bcSplits.value, featuresForNode, nodes(nodeIndex))
+ (nodeIndex, (split, stats))
+ }.collectAsMap()
+ nodeToFeaturesBc.destroy()
+
+ timer.stop("chooseSplits")
+
+ val bestSplits = if (outputBestSplits) {
+ Array.ofDim[mutable.Map[Int, Split]](metadata.numTrees)
+ } else {
+ null
+ }
+
+ // Iterate over all nodes in this group.
+ nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
+ nodesForTree.foreach { node =>
+ val nodeIndex = node.id
+ val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex)
+ val aggNodeIndex = nodeInfo.nodeIndexInGroup
+ val (split: Split, stats: ImpurityStats) =
+ nodeToBestSplits(aggNodeIndex)
+ logDebug(s"best split = $split")
+
+ // Extract info for this node. Create children if not leaf.
+ val isLeaf =
+ (stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) == metadata.maxDepth)
+ node.isLeaf = isLeaf
+ node.stats = stats
+ logDebug(s"Node = $node")
+
+ if (!isLeaf) {
+ node.split = Some(split)
+ val childIsLeaf = (LearningNode.indexToLevel(nodeIndex) + 1) == metadata.maxDepth
+ val leftChildIsLeaf = childIsLeaf || (math.abs(stats.leftImpurity) < Utils.EPSILON)
+ val rightChildIsLeaf = childIsLeaf || (math.abs(stats.rightImpurity) < Utils.EPSILON)
+ node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex),
+ leftChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator)))
+ node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex),
+ rightChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.rightImpurityCalculator)))
+
+ if (outputBestSplits) {
+ val bestSplitsInTree = bestSplits(treeIndex)
+ if (bestSplitsInTree == null) {
+ bestSplits(treeIndex) = mutable.Map[Int, Split](nodeIndex -> split)
+ } else {
+ bestSplitsInTree.update(nodeIndex, split)
+ }
+ }
+
+ // enqueue left child and right child if they are not leaves
+ if (!leftChildIsLeaf) {
+ nodeStack.prepend((treeIndex, node.leftChild.get))
+ }
+ if (!rightChildIsLeaf) {
+ nodeStack.prepend((treeIndex, node.rightChild.get))
+ }
+
+ logDebug(s"leftChildIndex = ${node.leftChild.get.id}" +
+ s", impurity = ${stats.leftImpurity}")
+ logDebug(s"rightChildIndex = ${node.rightChild.get.id}" +
+ s", impurity = ${stats.rightImpurity}")
+ }
+ }
+ }
+
+ if (outputBestSplits) {
+ bestSplits.map { m => if (m == null) null else m.toMap }
+ } else {
+ null
+ }
+ }
+
+ /**
+ * Calculate the impurity statistics for a given (feature, split) based upon left/right
+ * aggregates.
+ *
+ * @param stats the recycle impurity statistics for this feature's all splits,
+ * only 'impurity' and 'impurityCalculator' are valid between each iteration
+ * @param leftImpurityCalculator left node aggregates for this (feature, split)
+ * @param rightImpurityCalculator right node aggregate for this (feature, split)
+ * @param metadata learning and dataset metadata for DecisionTree
+ * @return Impurity statistics for this (feature, split)
+ */
+ private def calculateImpurityStats(
+ stats: ImpurityStats,
+ leftImpurityCalculator: ImpurityCalculator,
+ rightImpurityCalculator: ImpurityCalculator,
+ metadata: DecisionTreeMetadata): ImpurityStats = {
+
+ val parentImpurityCalculator: ImpurityCalculator = if (stats == null) {
+ leftImpurityCalculator.copy.add(rightImpurityCalculator)
+ } else {
+ stats.impurityCalculator
+ }
+
+ val impurity: Double = if (stats == null) {
+ parentImpurityCalculator.calculate()
+ } else {
+ stats.impurity
+ }
+
+ val leftRawCount = leftImpurityCalculator.rawCount
+ val rightRawCount = rightImpurityCalculator.rawCount
+ val leftCount = leftImpurityCalculator.count
+ val rightCount = rightImpurityCalculator.count
+
+ val totalCount = leftCount + rightCount
+
+ val violatesMinInstancesPerNode = (leftRawCount < metadata.minInstancesPerNode) ||
+ (rightRawCount < metadata.minInstancesPerNode)
+ val violatesMinWeightPerNode = (leftCount < metadata.minWeightPerNode) ||
+ (rightCount < metadata.minWeightPerNode)
+ // If left child or right child doesn't satisfy minimum weight per node or minimum
+ // instances per node, then this split is invalid, return invalid information gain stats.
+ if (violatesMinInstancesPerNode || violatesMinWeightPerNode) {
+ return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
+ }
+
+ val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
+ val rightImpurity = rightImpurityCalculator.calculate()
+
+ val leftWeight = leftCount / totalCount.toDouble
+ val rightWeight = rightCount / totalCount.toDouble
+
+ val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
+
+ // if information gain doesn't satisfy minimum information gain,
+ // then this split is invalid, return invalid information gain stats.
+ if (gain < metadata.minInfoGain) {
+ return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
+ }
+
+ new ImpurityStats(gain, impurity, parentImpurityCalculator,
+ leftImpurityCalculator, rightImpurityCalculator)
+ }
+
+ /**
+ * Find the best split for a node.
+ *
+ * @param binAggregates Bin statistics.
+ * @return tuple for best split: (Split, information gain, prediction at node)
+ */
+ private[tree] def binsToBestSplit(
+ binAggregates: DTStatsAggregator,
+ splits: Array[Array[Split]],
+ featuresForNode: Option[Array[Int]],
+ node: LearningNode): (Split, ImpurityStats) = {
+
+ // Calculate InformationGain and ImpurityStats if current node is top node
+ val level = LearningNode.indexToLevel(node.id)
+ var gainAndImpurityStats: ImpurityStats = if (level == 0) {
+ null
+ } else {
+ node.stats
+ }
+
+ val validFeatureSplits =
+ Iterator.range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
+ featuresForNode.map(features => (featureIndexIdx, features(featureIndexIdx)))
+ .getOrElse((featureIndexIdx, featureIndexIdx))
+ }.withFilter { case (_, featureIndex) =>
+ binAggregates.metadata.numSplits(featureIndex) != 0
+ }
+
+ // For each (feature, split), calculate the gain, and select the best (feature, split).
+ val splitsAndImpurityInfo =
+ validFeatureSplits.map { case (featureIndexIdx, featureIndex) =>
+ val numSplits = binAggregates.metadata.numSplits(featureIndex)
+ if (binAggregates.metadata.isContinuous(featureIndex)) {
+ // Cumulative sum (scanLeft) of bin statistics.
+ // Afterwards, binAggregates for a bin is the sum of aggregates for
+ // that bin + all preceding bins.
+ val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
+ var splitIndex = 0
+ while (splitIndex < numSplits) {
+ binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
+ splitIndex += 1
+ }
+ // Find best split.
+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
+ Range(0, numSplits).map { splitIdx =>
+ val leftChildStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
+ val rightChildStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
+ rightChildStats.subtract(leftChildStats)
+ gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
+ leftChildStats, rightChildStats, binAggregates.metadata)
+ (splitIdx, gainAndImpurityStats)
+ }.maxBy(_._2.gain)
+ (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
+ } else if (binAggregates.metadata.isUnordered(featureIndex)) {
+ // Unordered categorical feature
+ val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx)
+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
+ Range(0, numSplits).map { splitIndex =>
+ val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
+ val rightChildStats = binAggregates.getParentImpurityCalculator()
+ .subtract(leftChildStats)
+ gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
+ leftChildStats, rightChildStats, binAggregates.metadata)
+ (splitIndex, gainAndImpurityStats)
+ }.maxBy(_._2.gain)
+ (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
+ } else {
+ // Ordered categorical feature
+ val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
+ val numCategories = binAggregates.metadata.numBins(featureIndex)
+
+ /* Each bin is one category (feature value).
+ * The bins are ordered based on centroidForCategories, and this ordering determines which
+ * splits are considered. (With K categories, we consider K - 1 possible splits.)
+ *
+ * centroidForCategories is a list: (category, centroid)
+ */
+ val centroidForCategories = Range(0, numCategories).map { featureValue =>
+ val categoryStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+ val centroid = if (categoryStats.count != 0) {
+ if (binAggregates.metadata.isMulticlass) {
+ // multiclass classification
+ // For categorical variables in multiclass classification,
+ // the bins are ordered by the impurity of their corresponding labels.
+ categoryStats.calculate()
+ } else if (binAggregates.metadata.isClassification) {
+ // binary classification
+ // For categorical variables in binary classification,
+ // the bins are ordered by the count of class 1.
+ categoryStats.stats(1)
+ } else {
+ // regression
+ // For categorical variables in regression and binary classification,
+ // the bins are ordered by the prediction.
+ categoryStats.predict
+ }
+ } else {
+ Double.MaxValue
+ }
+ (featureValue, centroid)
+ }
+
+ logDebug(s"Centroids for categorical variable: " +
+ s"${centroidForCategories.mkString(",")}")
+
+ // bins sorted by centroids
+ val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
+
+ logDebug(s"Sorted centroids for categorical variable = " +
+ s"${categoriesSortedByCentroid.mkString(",")}")
+
+ // Cumulative sum (scanLeft) of bin statistics.
+ // Afterwards, binAggregates for a bin is the sum of aggregates for
+ // that bin + all preceding bins.
+ var splitIndex = 0
+ while (splitIndex < numSplits) {
+ val currentCategory = categoriesSortedByCentroid(splitIndex)._1
+ val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
+ binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
+ splitIndex += 1
+ }
+ // lastCategory = index of bin with total aggregates for this (node, feature)
+ val lastCategory = categoriesSortedByCentroid.last._1
+ // Find best split.
+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
+ Range(0, numSplits).map { splitIndex =>
+ val featureValue = categoriesSortedByCentroid(splitIndex)._1
+ val leftChildStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+ val rightChildStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
+ rightChildStats.subtract(leftChildStats)
+ gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
+ leftChildStats, rightChildStats, binAggregates.metadata)
+ (splitIndex, gainAndImpurityStats)
+ }.maxBy(_._2.gain)
+ val categoriesForSplit =
+ categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
+ val bestFeatureSplit =
+ new CategoricalSplit(featureIndex, categoriesForSplit.toArray, numCategories)
+ (bestFeatureSplit, bestFeatureGainStats)
+ }
+ }
+
+ val (bestSplit, bestSplitStats) =
+ if (splitsAndImpurityInfo.isEmpty) {
+ // If no valid splits for features, then this split is invalid,
+ // return invalid information gain stats. Take any split and continue.
+ // Splits is empty, so arbitrarily choose to split on any threshold
+ val dummyFeatureIndex = featuresForNode.map(_.head).getOrElse(0)
+ val parentImpurityCalculator = binAggregates.getParentImpurityCalculator()
+ if (binAggregates.metadata.isContinuous(dummyFeatureIndex)) {
+ (new ContinuousSplit(dummyFeatureIndex, 0),
+ ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator))
+ } else {
+ val numCategories = binAggregates.metadata.featureArity(dummyFeatureIndex)
+ (new CategoricalSplit(dummyFeatureIndex, Array(), numCategories),
+ ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator))
+ }
+ } else {
+ splitsAndImpurityInfo.maxBy(_._2.gain)
+ }
+ (bestSplit, bestSplitStats)
+ }
+
+ /**
+ * Returns splits for decision tree calculation.
+ * Continuous and categorical features are handled differently.
+ *
+ * Continuous features:
+ * For each feature, there are numBins - 1 possible splits representing the possible binary
+ * decisions at each node in the tree.
+ * This finds locations (feature values) for splits using a subsample of the data.
+ *
+ * Categorical features:
+ * For each feature, there is 1 bin per split.
+ * Splits and bins are handled in 2 ways:
+ * (a) "unordered features"
+ * For multiclass classification with a low-arity feature
+ * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
+ * the feature is split based on subsets of categories.
+ * (b) "ordered features"
+ * For regression and binary classification,
+ * and for multiclass classification with a high-arity feature,
+ * there is one bin per category.
+ *
+ * @param input Training data: RDD of [[Instance]]
+ * @param metadata Learning and dataset metadata
+ * @param seed random seed
+ * @return Splits, an Array of [[Split]]
+ * of size (numFeatures, numSplits)
+ */
+ protected[tree] def findSplits(
+ input: RDD[Instance],
+ metadata: DecisionTreeMetadata,
+ seed: Long): Array[Array[Split]] = {
+
+ logDebug(s"isMulticlass = ${metadata.isMulticlass}")
+
+ val numFeatures = metadata.numFeatures
+
+ // Sample the input only if there are continuous features.
+ val continuousFeatures = Range(0, numFeatures).filter(metadata.isContinuous)
+ val sampledInput = if (continuousFeatures.nonEmpty) {
+ val fraction = samplesFractionForFindSplits(metadata)
+ logDebug(s"fraction of data used for calculating quantiles = $fraction")
+ if (fraction < 1) {
+ input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt())
+ } else {
+ input
+ }
+ } else {
+ input.sparkContext.emptyRDD[Instance]
+ }
+
+ findSplitsBySorting(sampledInput, metadata, continuousFeatures)
+ }
+
+ private def findSplitsBySorting(
+ input: RDD[Instance],
+ metadata: DecisionTreeMetadata,
+ continuousFeatures: IndexedSeq[Int]): Array[Array[Split]] = {
+
+ val continuousSplits = if (continuousFeatures.nonEmpty) {
+ // reduce the parallelism for split computations when there are less
+ // continuous features than input partitions. this prevents tasks from
+ // being spun up that will definitely do no work.
+ val numPartitions = math.min(continuousFeatures.length, input.partitions.length)
+
+ input.flatMap { point =>
+ continuousFeatures.iterator
+ .map(idx => (idx, (point.features(idx), point.weight)))
+ .filter(_._2._1 != 0.0)
+ }.aggregateByKey((new OpenHashMap[Double, Double], 0L), numPartitions)(
+ seqOp = { case ((map, c), (v, w)) =>
+ map.changeValue(v, w, _ + w)
+ (map, c + 1L)
+ },
+ combOp = { case ((map1, c1), (map2, c2)) =>
+ map2.foreach { case (v, w) =>
+ map1.changeValue(v, w, _ + w)
+ }
+ (map1, c1 + c2)
+ }
+ ).map { case (idx, (map, c)) =>
+ val thresholds = findSplitsForContinuousFeature(map.toMap, c, metadata, idx)
+ val splits: Array[Split] = thresholds.map(thresh => new ContinuousSplit(idx, thresh))
+ logDebug(s"featureIndex = $idx, numSplits = ${splits.length}")
+ (idx, splits)
+ }.collectAsMap()
+ } else Map.empty[Int, Array[Split]]
+
+ val numFeatures = metadata.numFeatures
+ val splits: Array[Array[Split]] = Array.tabulate(numFeatures) {
+ case i if metadata.isContinuous(i) =>
+ // some features may contain only zero, so continuousSplits will not have a record
+ val split = continuousSplits.getOrElse(i, Array.empty[Split])
+ metadata.setNumSplits(i, split.length)
+ split
+
+ case i if metadata.isCategorical(i) && metadata.isUnordered(i) =>
+ // Unordered features
+ // 2^(maxFeatureValue - 1) - 1 combinations
+ val featureArity = metadata.featureArity(i)
+ Array.tabulate[Split](metadata.numSplits(i)) { splitIndex =>
+ val categories = extractMultiClassCategories(splitIndex + 1, featureArity)
+ new CategoricalSplit(i, categories.toArray, featureArity)
+ }
+
+ case i if metadata.isCategorical(i) =>
+ // Ordered features
+ // Splits are constructed as needed during training.
+ Array.empty[Split]
+ }
+ splits
+ }
+
+ /**
+ * Nested method to extract list of eligible categories given an index. It extracts the
+ * position of ones in a binary representation of the input. If binary
+ * representation of an number is 01101 (13), the output list should (3.0, 2.0,
+ * 0.0). The maxFeatureValue depict the number of rightmost digits that will be tested for ones.
+ */
+ private[tree] def extractMultiClassCategories(
+ input: Int,
+ maxFeatureValue: Int): List[Double] = {
+ var categories = List[Double]()
+ var j = 0
+ var bitShiftedInput = input
+ while (j < maxFeatureValue) {
+ if (bitShiftedInput % 2 != 0) {
+ // updating the list of categories.
+ categories = j.toDouble :: categories
+ }
+ // Right shift by one
+ bitShiftedInput = bitShiftedInput >> 1
+ j += 1
+ }
+ categories
+ }
+
+ /**
+ * Find splits for a continuous feature
+ * NOTE: Returned number of splits is set based on `featureSamples` and
+ * could be different from the specified `numSplits`.
+ * The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly.
+ *
+ * @param featureSamples feature values and sample weights of each sample
+ * @param metadata decision tree metadata
+ * NOTE: `metadata.numbins` will be changed accordingly
+ * if there are not enough splits to be found
+ * @param featureIndex feature index to find splits
+ * @return array of split thresholds
+ */
+ private[tree] def findSplitsForContinuousFeature(
+ featureSamples: Iterable[(Double, Double)],
+ metadata: DecisionTreeMetadata,
+ featureIndex: Int): Array[Double] = {
+ val valueWeights = new OpenHashMap[Double, Double]
+ var count = 0L
+ featureSamples.foreach { case (weight, value) =>
+ valueWeights.changeValue(value, weight, _ + weight)
+ count += 1L
+ }
+ findSplitsForContinuousFeature(valueWeights.toMap, count, metadata, featureIndex)
+ }
+
+ /**
+ * Find splits for a continuous feature
+ * NOTE: Returned number of splits is set based on `featureSamples` and
+ * could be different from the specified `numSplits`.
+ * The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly.
+ *
+ * @param partValueWeights non-zero distinct values and their weights
+ * @param metadata decision tree metadata
+ * NOTE: `metadata.numbins` will be changed accordingly
+ * if there are not enough splits to be found
+ * @param featureIndex feature index to find splits
+ * @return array of split thresholds
+ */
+ private[tree] def findSplitsForContinuousFeature(
+ partValueWeights: Map[Double, Double],
+ count: Long,
+ metadata: DecisionTreeMetadata,
+ featureIndex: Int): Array[Double] = {
+ require(metadata.isContinuous(featureIndex),
+ "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")
+
+ val splits = if (partValueWeights.isEmpty) {
+ Array.emptyDoubleArray
+ } else {
+ val numSplits = metadata.numSplits(featureIndex)
+
+ val partNumSamples = partValueWeights.values.sum
+
+ // Calculate the expected number of samples for finding splits
+ val weightedNumSamples = samplesFractionForFindSplits(metadata) *
+ metadata.weightedNumExamples
+ // scale tolerance by number of samples with constant factor
+ // Note: constant factor was tuned by running some tests where there were no zero
+ // feature values and validating we are never within tolerance
+ val tolerance = Utils.EPSILON * count * 100
+ // add expected zero value count and get complete statistics
+ val valueCountMap = if (weightedNumSamples - partNumSamples > tolerance) {
+ partValueWeights + (0.0 -> (weightedNumSamples - partNumSamples))
+ } else {
+ partValueWeights
+ }
+
+ // sort distinct values
+ val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
+
+ val possibleSplits = valueCounts.length - 1
+ if (possibleSplits == 0) {
+ // constant feature
+ Array.emptyDoubleArray
+ } else if (possibleSplits <= numSplits) {
+ // if possible splits is not enough or just enough, just return all possible splits
+ (1 to possibleSplits)
+ .map(index => (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0)
+ .toArray
+ } else {
+ // stride between splits
+ val stride: Double = weightedNumSamples / (numSplits + 1)
+ logDebug(s"stride = $stride")
+
+ // iterate `valueCount` to find splits
+ val splitsBuilder = mutable.ArrayBuilder.make[Double]
+ var index = 1
+ // currentCount: sum of counts of values that have been visited
+ var currentCount = valueCounts(0)._2
+ // targetCount: target value for `currentCount`.
+ // If `currentCount` is closest value to `targetCount`,
+ // then current value is a split threshold.
+ // After finding a split threshold, `targetCount` is added by stride.
+ var targetCount = stride
+ while (index < valueCounts.length) {
+ val previousCount = currentCount
+ currentCount += valueCounts(index)._2
+ val previousGap = math.abs(previousCount - targetCount)
+ val currentGap = math.abs(currentCount - targetCount)
+ // If adding count of current value to currentCount
+ // makes the gap between currentCount and targetCount smaller,
+ // previous value is a split threshold.
+ if (previousGap < currentGap) {
+ splitsBuilder += (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0
+ targetCount += stride
+ }
+ index += 1
+ }
+
+ splitsBuilder.result()
+ }
+ }
+ splits
+ }
+
+
+ /**
+ * Pull nodes off of the queue, and collect a group of nodes to be split on this iteration.
+ * This tracks the memory usage for aggregates and stops adding nodes when too much memory
+ * will be needed; this allows an adaptive number of nodes since different nodes may require
+ * different amounts of memory (if featureSubsetStrategy is not "all").
+ *
+ * @param nodeStack Queue of nodes to split.
+ * @param maxMemoryUsage Bound on size of aggregate statistics.
+ * @return (nodesForGroup, treeToNodeToIndexInfo).
+ * nodesForGroup holds the nodes to split: treeIndex --> nodes in tree.
+ *
+ * treeToNodeToIndexInfo holds indices selected features for each node:
+ * treeIndex --> (global) node index --> (node index in group, feature indices).
+ * The (global) node index is the index in the tree; the node index in group is the
+ * index in [0, numNodesInGroup) of the node in this group.
+ * The feature indices are None if not subsampling features.
+ */
+ private[tree] def selectNodesToSplit(
+ nodeStack: mutable.ListBuffer[(Int, LearningNode)],
+ maxMemoryUsage: Long,
+ metadata: DecisionTreeMetadata,
+ rng: Random): (Map[Int, Array[LearningNode]], Map[Int, Map[Int, NodeIndexInfo]]) = {
+ // Collect some nodes to split:
+ // nodesForGroup(treeIndex) = nodes to split
+ val mutableNodesForGroup = new mutable.HashMap[Int, mutable.ArrayBuffer[LearningNode]]()
+ val mutableTreeToNodeToIndexInfo =
+ new mutable.HashMap[Int, mutable.HashMap[Int, NodeIndexInfo]]()
+ var memUsage: Long = 0L
+ var numNodesInGroup = 0
+ // If maxMemoryInMB is set very small, we want to still try to split 1 node,
+ // so we allow one iteration if memUsage == 0.
+ var groupDone = false
+ while (nodeStack.nonEmpty && !groupDone) {
+ val (treeIndex, node) = nodeStack.head
+ // Choose subset of features for node (if subsampling).
+ val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
+ Some(SamplingUtils.reservoirSampleAndCount(Range(0,
+ metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong())._1)
+ } else {
+ None
+ }
+ // Check if enough memory remains to add this node to the group.
+ val nodeMemUsage = DecisionForest.aggregateSizeForNode(metadata, featureSubset) * 8L
+ if (memUsage + nodeMemUsage <= maxMemoryUsage || memUsage == 0) {
+ nodeStack.remove(0)
+ mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[LearningNode]()) +=
+ node
+ mutableTreeToNodeToIndexInfo
+ .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id)
+ = new NodeIndexInfo(numNodesInGroup, featureSubset)
+ numNodesInGroup += 1
+ memUsage += nodeMemUsage
+ } else {
+ groupDone = true
+ }
+ }
+ if (memUsage > maxMemoryUsage) {
+ // If maxMemoryUsage is 0, we should still allow splitting 1 node.
+ logWarning(s"Tree learning is using approximately $memUsage bytes per iteration, which" +
+ s" exceeds requested limit maxMemoryUsage=$maxMemoryUsage. This allows splitting" +
+ s" $numNodesInGroup nodes in this iteration.")
+ }
+ logWarning(f"[group] actualMemUsage: ${memUsage/(1024d*1024d)}%.2f MB," +
+ f" maxMemoryUsage: ${maxMemoryUsage/(1024d*1024d)}%.2f MB.")
+ // Convert mutable maps to immutable ones.
+ val nodesForGroup: Map[Int, Array[LearningNode]] =
+ mutableNodesForGroup.mapValues(_.toArray).toMap
+ val treeToNodeToIndexInfo = mutableTreeToNodeToIndexInfo.mapValues(_.toMap).toMap
+ (nodesForGroup, treeToNodeToIndexInfo)
+ }
+
+ /**
+ * Get the number of values to be stored for this node in the bin aggregates.
+ *
+ * @param featureSubset Indices of features which may be split at this node.
+ * If None, then use all features.
+ */
+ private def aggregateSizeForNode(
+ metadata: DecisionTreeMetadata,
+ featureSubset: Option[Array[Int]]): Long = {
+ val totalBins = if (featureSubset.nonEmpty) {
+ featureSubset.get.map(featureIndex => metadata.numBins(featureIndex).toLong).sum
+ } else {
+ metadata.numBins.map(_.toLong).sum
+ }
+ if (metadata.isClassification) {
+ metadata.numClasses * totalBins
+ } else {
+ 3 * totalBins
+ }
+ }
+
+ /**
+ * Calculate the subsample fraction for finding splits
+ *
+ * @param metadata decision tree metadata
+ * @return subsample fraction
+ */
+ private def samplesFractionForFindSplits(
+ metadata: DecisionTreeMetadata): Double = {
+ // Calculate the number of samples for approximate quantile calculation.
+ val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
+ if (requiredSamples < metadata.numExamples) {
+ requiredSamples.toDouble / metadata.numExamples
+ } else {
+ 1.0
+ }
+ }
+}
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeBucket.scala b/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeBucket.scala
new file mode 100644
index 0000000..b9b461c
--- /dev/null
+++ b/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeBucket.scala
@@ -0,0 +1,1361 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import scala.collection.mutable
+import scala.util.Random
+
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.classification.DecisionTreeClassificationModel
+import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.impl.Utils
+import org.apache.spark.ml.regression.DecisionTreeRegressionModel
+import org.apache.spark.ml.tree._
+import org.apache.spark.ml.tree.impl.RandomForest.NodeIndexInfo
+import org.apache.spark.ml.util.Instrumentation
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
+import org.apache.spark.mllib.tree.model.ImpurityStats
+import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.util.PeriodicRDDCheckpointer
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.collection.OpenHashMap
+import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
+
+
+/**
+ * ALGORITHM
+ *
+ * This is a sketch of the algorithm to help new developers.
+ *
+ * The algorithm partitions data by instances (rows).
+ * On each iteration, the algorithm splits a set of nodes. In order to choose the best split
+ * for a given node, sufficient statistics are collected from the distributed data.
+ * For each node, the statistics are collected to some worker node, and that worker selects
+ * the best split.
+ *
+ * This setup requires discretization of continuous features. This binning is done in the
+ * findSplits() method during initialization, after which each continuous feature becomes
+ * an ordered discretized feature with at most maxBins possible values.
+ *
+ * The main loop in the algorithm operates on a queue of nodes (nodeStack). These nodes
+ * lie at the periphery of the tree being trained. If multiple trees are being trained at once,
+ * then this queue contains nodes from all of them. Each iteration works roughly as follows:
+ * On the master node:
+ * - Some number of nodes are pulled off of the queue (based on the amount of memory
+ * required for their sufficient statistics).
+ * - For random forests, if featureSubsetStrategy is not "all," then a subset of candidate
+ * features are chosen for each node. See method selectNodesToSplit().
+ * On worker nodes, via method findBestSplits():
+ * - The worker makes one pass over its subset of instances.
+ * - For each (tree, node, feature, split) tuple, the worker collects statistics about
+ * splitting. Note that the set of (tree, node) pairs is limited to the nodes selected
+ * from the queue for this iteration. The set of features considered can also be limited
+ * based on featureSubsetStrategy.
+ * - For each node, the statistics for that node are aggregated to a particular worker
+ * via reduceByKey(). The designated worker chooses the best (feature, split) pair,
+ * or chooses to stop splitting if the stopping criteria are met.
+ * On the master node:
+ * - The master collects all decisions about splitting nodes and updates the model.
+ * - The updated model is passed to the workers on the next iteration.
+ * This process continues until the node queue is empty.
+ *
+ * Most of the methods in this implementation support the statistics aggregation, which is
+ * the heaviest part of the computation. In general, this implementation is bound by either
+ * the cost of statistics computation on workers or by communicating the sufficient statistics.
+ */
+private[spark] object DecisionTreeBucket extends Logging with Serializable {
+
+ /**
+ * Train a random forest.
+ *
+ * @param input Training data: RDD of `LabeledPoint`
+ * @return an unweighted set of trees
+ */
+ def run(
+ input: RDD[LabeledPoint],
+ strategy: OldStrategy,
+ seed: Long): Array[DecisionTreeModel] = {
+ val instances = input.map { case LabeledPoint(label, features) =>
+ Instance(label, 1.0, features.asML)
+ }
+ run(instances, strategy, seed, None)
+ }
+
+ // scalastyle:off
+ /**
+ * Train a random forest with metadata and splits. This method is mainly for GBT,
+ * in which bagged input can be reused among trees.
+ *
+ * @param baggedInput bagged training data: RDD of `BaggedPoint`
+ * @param metadata Learning and dataset metadata for DecisionTree.
+ * @return an unweighted set of trees
+ */
+ def runBagged(
+ baggedInput: RDD[BaggedPoint[TreePointX]],
+ metadata: DecisionTreeMetadata,
+ bcSplits: Broadcast[Array[Array[Split]]],
+ strategy: OldStrategy,
+ numTrees: Int,
+ featureSubsetStrategy: String,
+ seed: Long,
+ instr: Option[Instrumentation],
+ prune: Boolean = true, // exposed for testing only, real trees are always pruned
+ parentUID: Option[String] = None,
+ extraParams: Option[RFExtraParams] = None): Array[DecisionTreeModel] = {
+ // scalastyle:on
+ val timer = new TimeTracker()
+ timer.start("total")
+
+ val sc = baggedInput.sparkContext
+
+ instr match {
+ case Some(instrumentation) =>
+ instrumentation.logNumFeatures(metadata.numFeatures)
+ instrumentation.logNumClasses(metadata.numClasses)
+ instrumentation.logNumExamples(metadata.numExamples)
+ instrumentation.logSumOfWeights(metadata.weightedNumExamples)
+ case None =>
+ logInfo(s"numFeatures: ${metadata.numFeatures}")
+ logInfo(s"numClasses: ${metadata.numClasses}")
+ logInfo(s"numExamples: ${metadata.numExamples}")
+ logInfo(s"weightedNumExamples: ${metadata.weightedNumExamples}")
+ }
+
+ timer.start("init")
+
+ // depth of the decision tree
+ val maxDepth = strategy.maxDepth
+ require(maxDepth <= 30,
+ s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.")
+
+ // Max memory usage for aggregates
+ // TODO: Calculate memory usage more precisely.
+ val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L
+ logDebug(s"max memory usage for aggregates = $maxMemoryUsage bytes.")
+
+ /*
+ * The main idea here is to perform group-wise training of the decision tree nodes thus
+ * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup).
+ * Each data sample is handled by a particular node (or it reaches a leaf and is not used
+ * in lower levels).
+ */
+
+ var nodeIds: RDD[Array[Int]] = null
+ var nodeIdCheckpointer: PeriodicRDDCheckpointer[Array[Int]] = null
+ if (strategy.useNodeIdCache) {
+ // Create an RDD of node Id cache.
+ // At first, all the rows belong to the root nodes (node Id == 1).
+ nodeIds = baggedInput.map { _ => Array.fill(numTrees)(1) }
+ nodeIdCheckpointer = new PeriodicRDDCheckpointer[Array[Int]](
+ strategy.getCheckpointInterval, sc, StorageLevel.MEMORY_AND_DISK)
+ nodeIdCheckpointer.update(nodeIds)
+ }
+
+ /*
+ Stack of nodes to train: (treeIndex, node)
+ The reason this is a stack is that we train many trees at once, but we want to focus on
+ completing trees, rather than training all simultaneously. If we are splitting nodes from
+ 1 tree, then the new nodes to split will be put at the top of this stack, so we will continue
+ training the same tree in the next iteration. This focus allows us to send fewer trees to
+ workers on each iteration; see topNodesForGroup below.
+ */
+ val nodeStack = new mutable.ListBuffer[(Int, LearningNode)]
+
+ val rng = new Random()
+ rng.setSeed(seed)
+
+ // Allocate and queue root nodes.
+ val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1))
+ for (treeIndex <- 0 until numTrees) {
+ nodeStack.prepend((treeIndex, topNodes(treeIndex)))
+ }
+
+ timer.stop("init")
+
+ while (nodeStack.nonEmpty) {
+ // Collect some nodes to split, and choose features for each node (if subsampling).
+ // Each group of nodes may come from one or multiple trees, and at multiple levels.
+ val (nodesForGroup, treeToNodeToIndexInfo) =
+ DecisionTreeBucket.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng)
+ // Sanity check (should never occur):
+ assert(nodesForGroup.nonEmpty,
+ s"DecisionTreeBucket selected empty nodesForGroup. Error for unknown reason.")
+
+ // Only send trees to worker if they contain nodes being split this iteration.
+ val topNodesForGroup: Map[Int, LearningNode] =
+ nodesForGroup.keys.map(treeIdx => treeIdx -> topNodes(treeIdx)).toMap
+
+ // Choose node splits, and enqueue new nodes as needed.
+ timer.start("findBestSplits")
+ val bestSplit = DecisionTreeBucket.findBestSplits(baggedInput, metadata, topNodesForGroup,
+ nodesForGroup, treeToNodeToIndexInfo, bcSplits, nodeStack, timer, nodeIds,
+ outputBestSplits = strategy.useNodeIdCache, extraParams)
+ if (strategy.useNodeIdCache) {
+ nodeIds = updateNodeIds(baggedInput, nodeIds, bcSplits, bestSplit)
+ nodeIdCheckpointer.update(nodeIds)
+ }
+
+ timer.stop("findBestSplits")
+ }
+
+ timer.stop("total")
+
+ logInfo("Internal timing for DecisionTree:")
+ logInfo(s"$timer")
+
+ if (strategy.useNodeIdCache) {
+ // Delete any remaining checkpoints used for node Id cache.
+ nodeIdCheckpointer.unpersistDataSet()
+ nodeIdCheckpointer.deleteAllCheckpoints()
+ }
+
+ val numFeatures = metadata.numFeatures
+
+ parentUID match {
+ case Some(uid) =>
+ if (strategy.algo == OldAlgo.Classification) {
+ topNodes.map { rootNode =>
+ new DecisionTreeClassificationModel(uid, rootNode.toNode(prune), numFeatures,
+ strategy.getNumClasses)
+ }
+ } else {
+ topNodes.map { rootNode =>
+ new DecisionTreeRegressionModel(uid, rootNode.toNode(prune), numFeatures)
+ }
+ }
+ case None =>
+ if (strategy.algo == OldAlgo.Classification) {
+ topNodes.map { rootNode =>
+ new DecisionTreeClassificationModel(rootNode.toNode(prune), numFeatures,
+ strategy.getNumClasses)
+ }
+ } else {
+ topNodes.map(rootNode =>
+ new DecisionTreeRegressionModel(rootNode.toNode(prune), numFeatures))
+ }
+ }
+ }
+
+ /**
+ * Train a random forest.
+ *
+ * @param input Training data: RDD of `Instance`
+ * @return an unweighted set of trees
+ */
+ def run(
+ input: RDD[Instance],
+ strategy: OldStrategy,
+ seed: Long,
+ instr: Option[Instrumentation],
+ prune: Boolean = true, // exposed for testing only, real trees are always pruned
+ parentUID: Option[String] = None,
+ exParams: Option[RFExtraParams] = None): Array[DecisionTreeModel] = {
+ val extraParams = if (exParams.isEmpty) {
+ RFUtils.parseExtraParams(input, strategy)
+ } else {
+ exParams.get
+ }
+ val timer = new TimeTracker()
+
+ timer.start("build metadata")
+ val featureSubsetStrategy = "1"
+ var numTrees: Int = 0
+ val metadata = DecisionTreeMetadata
+ .buildMetadata(input.retag(classOf[Instance]), strategy, numTrees, featureSubsetStrategy)
+ numTrees = metadata.numTrees
+ timer.stop("build metadata")
+
+ val binnedFeaturesType = BinnedFeaturesDataType.withName(extraParams.featuresDataType)
+ val retaggedInput = input.retag(classOf[Instance])
+
+ // Find the splits and the corresponding bins (interval between the splits) using a sample
+ // of the input data.
+ timer.start("findSplits")
+ val splits = findSplits(retaggedInput, metadata, seed)
+ timer.stop("findSplits")
+ logDebug("numBins: feature: number of bins")
+ logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
+ s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
+ }.mkString("\n"))
+
+ // Bin feature values (TreePointX representation).
+ // Cache input RDD for speedup during multiple passes.
+ val treeInput = TreePointX.convertToTreeRDD(retaggedInput, splits, metadata, binnedFeaturesType)
+
+ val bcSplits = input.sparkContext.broadcast(splits)
+ val baggedInputOri = BaggedPoint
+ .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, strategy.bootstrap,
+ (tp: TreePointX) => tp.weight, seed = seed, oneFeaturePerTree = metadata.oneFeaturePerTree)
+ .setName("bagged tree points")
+ val baggedInput = RFUtils.transformBaggedRDD(baggedInputOri, extraParams)
+
+ val trees = runBagged(baggedInput = baggedInput, metadata = metadata, bcSplits = bcSplits,
+ strategy = strategy, numTrees = numTrees, featureSubsetStrategy = featureSubsetStrategy,
+ seed = seed, instr = instr, prune = prune, parentUID = parentUID,
+ extraParams = Some(extraParams))
+
+ baggedInput.unpersist()
+ bcSplits.destroy()
+
+ trees
+ }
+
+ /**
+ * Update node indices by newly found splits.
+ */
+ private def updateNodeIds(
+ input: RDD[BaggedPoint[TreePointX]],
+ nodeIds: RDD[Array[Int]],
+ bcSplits: Broadcast[Array[Array[Split]]],
+ bestSplits: Array[Map[Int, Split]]): RDD[Array[Int]] = {
+ require(nodeIds != null && bestSplits != null)
+ input.zip(nodeIds).map { case (point, ids) =>
+ var treeId = 0
+ while (treeId < bestSplits.length) {
+ val bestSplitsInTree = bestSplits(treeId)
+ if (bestSplitsInTree != null) {
+ val nodeId = ids(treeId)
+ bestSplitsInTree.get(nodeId).foreach { bestSplit =>
+ val featureId = bestSplit.featureIndex
+ val bin = point.datum.binnedFeatures.get(featureId)
+ val newNodeId = if (bestSplit.shouldGoLeft(bin, bcSplits.value(featureId))) {
+ LearningNode.leftChildIndex(nodeId)
+ } else {
+ LearningNode.rightChildIndex(nodeId)
+ }
+ ids(treeId) = newNodeId
+ }
+ }
+ treeId += 1
+ }
+ ids
+ }
+ }
+
+ /**
+ * Helper for binSeqOp, for data which can contain a mix of ordered and unordered features.
+ *
+ * For ordered features, a single bin is updated.
+ * For unordered features, bins correspond to subsets of categories; either the left or right bin
+ * for each subset is updated.
+ *
+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
+ * each (feature, bin).
+ * @param treePoint Data point being aggregated.
+ * @param splits Possible splits indexed (numFeatures)(numSplits)
+ * @param unorderedFeatures Set of indices of unordered features.
+ * @param numSamples Number of times this instance occurs in the sample.
+ * @param sampleWeight Weight (importance) of instance in dataset.
+ */
+ private def mixedBinSeqOp(
+ agg: DTStatsAggregator,
+ treePoint: TreePointX,
+ splits: Array[Array[Split]],
+ unorderedFeatures: Set[Int],
+ numSamples: Int,
+ sampleWeight: Double,
+ featuresForNode: Option[Array[Int]]): Unit = {
+ val numFeaturesPerNode = if (featuresForNode.nonEmpty) {
+ // Use subsampled features
+ featuresForNode.get.length
+ } else {
+ // Use all features
+ agg.metadata.numFeatures
+ }
+ // Iterate over features.
+ var featureIndexIdx = 0
+ while (featureIndexIdx < numFeaturesPerNode) {
+ val featureIndex = if (featuresForNode.nonEmpty) {
+ featuresForNode.get.apply(featureIndexIdx)
+ } else {
+ featureIndexIdx
+ }
+ if (unorderedFeatures.contains(featureIndex)) {
+ // Unordered feature
+ val featureValue = treePoint.binnedFeatures.get(featureIndex)
+ val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx)
+ // Update the left or right bin for each split.
+ val numSplits = agg.metadata.numSplits(featureIndex)
+ val featureSplits = splits(featureIndex)
+ var splitIndex = 0
+ while (splitIndex < numSplits) {
+ if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) {
+ agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, numSamples,
+ sampleWeight)
+ }
+ splitIndex += 1
+ }
+ } else {
+ // Ordered feature
+ val binIndex = treePoint.binnedFeatures.get(featureIndex)
+ agg.update(featureIndexIdx, binIndex, treePoint.label, numSamples, sampleWeight)
+ }
+ featureIndexIdx += 1
+ }
+ }
+
+ /**
+ * Helper for binSeqOp, for regression and for classification with only ordered features.
+ *
+ * For each feature, the sufficient statistics of one bin are updated.
+ *
+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
+ * each (feature, bin).
+ * @param treePoint Data point being aggregated.
+ * @param numSamples Number of times this instance occurs in the sample.
+ * @param sampleWeight Weight (importance) of instance in dataset.
+ */
+ private def orderedBinSeqOp(
+ agg: DTStatsAggregator,
+ treePoint: TreePointX,
+ numSamples: Int,
+ sampleWeight: Double,
+ featuresForNode: Option[Array[Int]]): Unit = {
+ val label = treePoint.label
+
+ // Iterate over features.
+ if (featuresForNode.nonEmpty) {
+ // Use subsampled features
+ var featureIndexIdx = 0
+ while (featureIndexIdx < featuresForNode.get.length) {
+ val binIndex = treePoint.binnedFeatures.get(featuresForNode.get.apply(featureIndexIdx))
+ agg.update(featureIndexIdx, binIndex, label, numSamples, sampleWeight)
+ featureIndexIdx += 1
+ }
+ } else {
+ // Use all features
+ val numFeatures = agg.metadata.numFeatures
+ var featureIndex = 0
+ while (featureIndex < numFeatures) {
+ val binIndex = treePoint.binnedFeatures.get(featureIndex)
+ agg.update(featureIndex, binIndex, label, numSamples, sampleWeight)
+ featureIndex += 1
+ }
+ }
+ }
+
+ // scalastyle:off
+ /**
+ * Given a group of nodes, this finds the best split for each node.
+ *
+ * @param input Training data: RDD of [[TreePointX]]
+ * @param metadata Learning and dataset metadata
+ * @param topNodesForGroup For each tree in group, tree index -> root node.
+ * Used for matching instances with nodes.
+ * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree
+ * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo,
+ * where nodeIndexInfo stores the index in the group and the
+ * feature subsets (if using feature subsets).
+ * @param bcSplits possible splits for all features, indexed (numFeatures)(numSplits)
+ * @param nodeStack Queue of nodes to split, with values (treeIndex, node).
+ * Updated with new non-leaf nodes which are created.
+ * @param nodeIds an RDD of Array[Int] where each value in the array is the data
+ * point's node Id for a corresponding tree. This is used to prevent
+ * the need to pass the entire tree to the executors during the node
+ * stat aggregation phase.
+ */
+ private[tree] def findBestSplits(
+ input: RDD[BaggedPoint[TreePointX]],
+ metadata: DecisionTreeMetadata,
+ topNodesForGroup: Map[Int, LearningNode],
+ nodesForGroup: Map[Int, Array[LearningNode]],
+ treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]],
+ bcSplits: Broadcast[Array[Array[Split]]],
+ nodeStack: mutable.ListBuffer[(Int, LearningNode)],
+ timer: TimeTracker = new TimeTracker,
+ nodeIds: RDD[Array[Int]] = null,
+ outputBestSplits: Boolean = false,
+ extraParams: Option[RFExtraParams] = None): Array[Map[Int, Split]] = {
+ // scalastyle:on
+
+ /*
+ * The high-level descriptions of the best split optimizations are noted here.
+ *
+ * *Group-wise training*
+ * We perform bin calculations for groups of nodes to reduce the number of
+ * passes over the data. Each iteration requires more computation and storage,
+ * but saves several iterations over the data.
+ *
+ * *Bin-wise computation*
+ * We use a bin-wise best split computation strategy instead of a straightforward best split
+ * computation strategy. Instead of analyzing each sample for contribution to the left/right
+ * child node impurity of every split, we first categorize each feature of a sample into a
+ * bin. We exploit this structure to calculate aggregates for bins and then use these aggregates
+ * to calculate information gain for each split.
+ *
+ * *Aggregation over partitions*
+ * Instead of performing a flatMap/reduceByKey operation, we exploit the fact that we know
+ * the number of splits in advance. Thus, we store the aggregates (at the appropriate
+ * indices) in a single array for all bins and rely upon the RDD aggregate method to
+ * drastically reduce the communication overhead.
+ */
+
+ val useNodeIdCache = nodeIds != null
+
+ // numNodes: Number of nodes in this group
+ val numNodes = nodesForGroup.values.map(_.length).sum
+ logDebug(s"numNodes = $numNodes")
+ logDebug(s"numFeatures = ${metadata.numFeatures}")
+ logDebug(s"numClasses = ${metadata.numClasses}")
+ logDebug(s"isMulticlass = ${metadata.isMulticlass}")
+ logDebug(s"isMulticlassWithCategoricalFeatures = " +
+ s"${metadata.isMulticlassWithCategoricalFeatures}")
+ logDebug(s"using nodeIdCache = $useNodeIdCache")
+
+ val groupInfo = RFUtils.getGroupInfo(numNodes, treeToNodeToIndexInfo, extraParams)
+
+ /*
+ * Performs a sequential aggregation over a partition for a particular tree and node.
+ *
+ * For each feature, the aggregate sufficient statistics are updated for the relevant
+ * bins.
+ *
+ * @param treeIndex Index of the tree that we want to perform aggregation for.
+ * @param nodeInfo The node info for the tree node.
+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics
+ * for each (node, feature, bin).
+ * @param baggedPoint Data point being aggregated.
+ */
+ def nodeBinSeqOp(
+ treeIndex: Int,
+ nodeInfo: NodeIndexInfo,
+ agg: Array[DTStatsAggregator],
+ baggedPoint: BaggedPoint[TreePointX],
+ splits: Array[Array[Split]]): Unit = {
+ if (RFUtils.isValidNodeInfo(nodeInfo, agg)) {
+ val aggNodeIndex = nodeInfo.nodeIndexInGroup
+ val featuresForNode = nodeInfo.featureSubset
+ val numSamples = 1
+ val sampleWeight = baggedPoint.sampleWeight
+ if (metadata.unorderedFeatures.isEmpty) {
+ orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, numSamples, sampleWeight,
+ featuresForNode)
+ } else {
+ mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
+ metadata.unorderedFeatures, numSamples, sampleWeight, featuresForNode)
+ }
+ agg(aggNodeIndex).updateParent(baggedPoint.datum.label, numSamples, sampleWeight)
+ }
+ }
+
+ /*
+ * Performs a sequential aggregation over a partition.
+ *
+ * Each data point contributes to one node. For each feature,
+ * the aggregate sufficient statistics are updated for the relevant bins.
+ *
+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
+ * each (node, feature, bin).
+ * @param baggedPoint Data point being aggregated.
+ * @return agg
+ */
+ def binSeqOp(
+ agg: Array[DTStatsAggregator],
+ baggedPoint: BaggedPoint[TreePointX],
+ splits: Array[Array[Split]],
+ sampleId: Short): Array[DTStatsAggregator] = {
+ treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
+ if (RFUtils.isValidSample(baggedPoint, groupInfo, treeIndex, sampleId)) {
+ val nodeIndex =
+ topNodesForGroup(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits)
+ nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null),
+ agg, baggedPoint, splits)
+ }
+ }
+ agg
+ }
+
+ /**
+ * Do the same thing as binSeqOp, but with nodeIdCache.
+ */
+ def binSeqOpWithNodeIdCache(
+ agg: Array[DTStatsAggregator],
+ dataPoint: (BaggedPoint[TreePointX], Array[Int]),
+ splits: Array[Array[Split]]): Array[DTStatsAggregator] = {
+ treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
+ val baggedPoint = dataPoint._1
+ val nodeIdCache = dataPoint._2
+ val nodeIndex = nodeIdCache(treeIndex)
+ nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null),
+ agg, baggedPoint, splits)
+ }
+ agg
+ }
+
+ /**
+ * Get node index in group --> features indices map,
+ * which is a short cut to find feature indices for a node given node index in group.
+ */
+ def getNodeToFeatures(
+ treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]): Option[Map[Int, Array[Int]]] = {
+ if (!metadata.subsamplingFeatures) {
+ None
+ } else {
+ val mutableNodeToFeatures = new mutable.HashMap[Int, Array[Int]]()
+ treeToNodeToIndexInfo.values.foreach { nodeIdToNodeInfo =>
+ nodeIdToNodeInfo.values.foreach { nodeIndexInfo =>
+ assert(nodeIndexInfo.featureSubset.isDefined)
+ mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) = nodeIndexInfo.featureSubset.get
+ }
+ }
+ Some(mutableNodeToFeatures.toMap)
+ }
+ }
+
+ // array of nodes to train indexed by node index in group
+ val nodes = new Array[LearningNode](numNodes)
+ nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
+ nodesForTree.foreach { node =>
+ nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node
+ }
+ }
+
+ // Calculate best splits for all nodes in the group
+ timer.start("chooseSplits")
+
+ // In each partition, iterate all instances and compute aggregate stats for each node,
+ // yield a (nodeIndex, nodeAggregateStats) pair for each node.
+ // After a `reduceByKey` operation,
+ // stats of a node will be shuffled to a particular partition and be combined together,
+ // then best splits for nodes are found there.
+ // Finally, only best Splits for nodes are collected to driver to construct decision tree.
+ val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
+ val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)
+
+ val partitionAggregates = if (useNodeIdCache) {
+
+ input.zip(nodeIds).mapPartitions { points =>
+ // Construct a nodeStatsAggregators array to hold node aggregate stats,
+ // each node will have a nodeStatsAggregator
+ val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
+ val featuresForNode = nodeToFeaturesBc.value.map { nodeToFeatures =>
+ nodeToFeatures(nodeIndex)
+ }
+ new DTStatsAggregator(metadata, featuresForNode)
+ }
+
+ // iterator all instances in current partition and update aggregate stats
+ points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _, bcSplits.value))
+
+ // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
+ // which can be combined with other partition using `reduceByKey`
+ nodeStatsAggregators.iterator.zipWithIndex.map(_.swap)
+ }
+ } else {
+ input.mapPartitions { points =>
+ val (firstPointOption, nodeStatsAggregators) =
+ RFUtils.initNodeStatsAgg(numNodes, nodeToFeaturesBc, metadata, points, groupInfo)
+ if (firstPointOption.isEmpty) {
+ Iterator.empty
+ } else {
+ val splits = bcSplits.value
+ val firstPoint = firstPointOption.get
+ val sampleId = firstPoint.sampleId
+ binSeqOp(nodeStatsAggregators, firstPoint, splits, sampleId)
+
+ // iterator all instances in current partition and update aggregate stats
+ points.foreach(binSeqOp(nodeStatsAggregators, _, splits, sampleId))
+
+ // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
+ // which can be combined with other partition using `reduceByKey`
+ nodeStatsAggregators.iterator.zipWithIndex
+ .filter(v => RFUtils.isValidAgg(v._1)).map(_.swap)
+ }
+ }
+ }
+
+ val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b)).map {
+ case (nodeIndex, aggStats) =>
+ val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
+ Some(nodeToFeatures(nodeIndex))
+ }
+
+ // find best split for each node
+ val (split: Split, stats: ImpurityStats) =
+ binsToBestSplit(aggStats, bcSplits.value, featuresForNode, nodes(nodeIndex))
+ (nodeIndex, (split, stats))
+ }.collectAsMap()
+ nodeToFeaturesBc.destroy()
+
+ timer.stop("chooseSplits")
+
+ val bestSplits = if (outputBestSplits) {
+ Array.ofDim[mutable.Map[Int, Split]](metadata.numTrees)
+ } else {
+ null
+ }
+
+ // Iterate over all nodes in this group.
+ nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
+ nodesForTree.foreach { node =>
+ val nodeIndex = node.id
+ val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex)
+ val aggNodeIndex = nodeInfo.nodeIndexInGroup
+ val (split: Split, stats: ImpurityStats) =
+ nodeToBestSplits(aggNodeIndex)
+ logDebug(s"best split = $split")
+
+ // Extract info for this node. Create children if not leaf.
+ val isLeaf =
+ (stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) == metadata.maxDepth)
+ node.isLeaf = isLeaf
+ node.stats = stats
+ logDebug(s"Node = $node")
+
+ if (!isLeaf) {
+ node.split = Some(split)
+ val childIsLeaf = (LearningNode.indexToLevel(nodeIndex) + 1) == metadata.maxDepth
+ val leftChildIsLeaf = childIsLeaf || (math.abs(stats.leftImpurity) < Utils.EPSILON)
+ val rightChildIsLeaf = childIsLeaf || (math.abs(stats.rightImpurity) < Utils.EPSILON)
+ node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex),
+ leftChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator)))
+ node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex),
+ rightChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.rightImpurityCalculator)))
+
+ if (outputBestSplits) {
+ val bestSplitsInTree = bestSplits(treeIndex)
+ if (bestSplitsInTree == null) {
+ bestSplits(treeIndex) = mutable.Map[Int, Split](nodeIndex -> split)
+ } else {
+ bestSplitsInTree.update(nodeIndex, split)
+ }
+ }
+
+ // enqueue left child and right child if they are not leaves
+ if (!leftChildIsLeaf) {
+ nodeStack.prepend((treeIndex, node.leftChild.get))
+ }
+ if (!rightChildIsLeaf) {
+ nodeStack.prepend((treeIndex, node.rightChild.get))
+ }
+
+ logDebug(s"leftChildIndex = ${node.leftChild.get.id}" +
+ s", impurity = ${stats.leftImpurity}")
+ logDebug(s"rightChildIndex = ${node.rightChild.get.id}" +
+ s", impurity = ${stats.rightImpurity}")
+ }
+ }
+ }
+
+ if (outputBestSplits) {
+ bestSplits.map { m => if (m == null) null else m.toMap }
+ } else {
+ null
+ }
+ }
+
+ /**
+ * Calculate the impurity statistics for a given (feature, split) based upon left/right
+ * aggregates.
+ *
+ * @param stats the recycle impurity statistics for this feature's all splits,
+ * only 'impurity' and 'impurityCalculator' are valid between each iteration
+ * @param leftImpurityCalculator left node aggregates for this (feature, split)
+ * @param rightImpurityCalculator right node aggregate for this (feature, split)
+ * @param metadata learning and dataset metadata for DecisionTree
+ * @return Impurity statistics for this (feature, split)
+ */
+ private def calculateImpurityStats(
+ stats: ImpurityStats,
+ leftImpurityCalculator: ImpurityCalculator,
+ rightImpurityCalculator: ImpurityCalculator,
+ metadata: DecisionTreeMetadata): ImpurityStats = {
+
+ val parentImpurityCalculator: ImpurityCalculator = if (stats == null) {
+ leftImpurityCalculator.copy.add(rightImpurityCalculator)
+ } else {
+ stats.impurityCalculator
+ }
+
+ val impurity: Double = if (stats == null) {
+ parentImpurityCalculator.calculate()
+ } else {
+ stats.impurity
+ }
+
+ val leftRawCount = leftImpurityCalculator.rawCount
+ val rightRawCount = rightImpurityCalculator.rawCount
+ val leftCount = leftImpurityCalculator.count
+ val rightCount = rightImpurityCalculator.count
+
+ val totalCount = leftCount + rightCount
+
+ val violatesMinInstancesPerNode = (leftRawCount < metadata.minInstancesPerNode) ||
+ (rightRawCount < metadata.minInstancesPerNode)
+ val violatesMinWeightPerNode = (leftCount < metadata.minWeightPerNode) ||
+ (rightCount < metadata.minWeightPerNode)
+ // If left child or right child doesn't satisfy minimum weight per node or minimum
+ // instances per node, then this split is invalid, return invalid information gain stats.
+ if (violatesMinInstancesPerNode || violatesMinWeightPerNode) {
+ return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
+ }
+
+ val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
+ val rightImpurity = rightImpurityCalculator.calculate()
+
+ val leftWeight = leftCount / totalCount.toDouble
+ val rightWeight = rightCount / totalCount.toDouble
+
+ val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
+
+ // if information gain doesn't satisfy minimum information gain,
+ // then this split is invalid, return invalid information gain stats.
+ if (gain < metadata.minInfoGain) {
+ return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
+ }
+
+ new ImpurityStats(gain, impurity, parentImpurityCalculator,
+ leftImpurityCalculator, rightImpurityCalculator)
+ }
+
+ /**
+ * Find the best split for a node.
+ *
+ * @param binAggregates Bin statistics.
+ * @return tuple for best split: (Split, information gain, prediction at node)
+ */
+ private[tree] def binsToBestSplit(
+ binAggregates: DTStatsAggregator,
+ splits: Array[Array[Split]],
+ featuresForNode: Option[Array[Int]],
+ node: LearningNode): (Split, ImpurityStats) = {
+
+ // Calculate InformationGain and ImpurityStats if current node is top node
+ val level = LearningNode.indexToLevel(node.id)
+ var gainAndImpurityStats: ImpurityStats = if (level == 0) {
+ null
+ } else {
+ node.stats
+ }
+
+ val validFeatureSplits =
+ Iterator.range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
+ featuresForNode.map(features => (featureIndexIdx, features(featureIndexIdx)))
+ .getOrElse((featureIndexIdx, featureIndexIdx))
+ }.withFilter { case (_, featureIndex) =>
+ binAggregates.metadata.numSplits(featureIndex) != 0
+ }
+
+ // For each (feature, split), calculate the gain, and select the best (feature, split).
+ val splitsAndImpurityInfo =
+ validFeatureSplits.map { case (featureIndexIdx, featureIndex) =>
+ val numSplits = binAggregates.metadata.numSplits(featureIndex)
+ if (binAggregates.metadata.isContinuous(featureIndex)) {
+ // Cumulative sum (scanLeft) of bin statistics.
+ // Afterwards, binAggregates for a bin is the sum of aggregates for
+ // that bin + all preceding bins.
+ val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
+ var splitIndex = 0
+ while (splitIndex < numSplits) {
+ binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
+ splitIndex += 1
+ }
+ // Find best split.
+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
+ Range(0, numSplits).map { splitIdx =>
+ val leftChildStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
+ val rightChildStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
+ rightChildStats.subtract(leftChildStats)
+ gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
+ leftChildStats, rightChildStats, binAggregates.metadata)
+ (splitIdx, gainAndImpurityStats)
+ }.maxBy(_._2.gain)
+ (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
+ } else if (binAggregates.metadata.isUnordered(featureIndex)) {
+ // Unordered categorical feature
+ val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx)
+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
+ Range(0, numSplits).map { splitIndex =>
+ val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
+ val rightChildStats = binAggregates.getParentImpurityCalculator()
+ .subtract(leftChildStats)
+ gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
+ leftChildStats, rightChildStats, binAggregates.metadata)
+ (splitIndex, gainAndImpurityStats)
+ }.maxBy(_._2.gain)
+ (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
+ } else {
+ // Ordered categorical feature
+ val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
+ val numCategories = binAggregates.metadata.numBins(featureIndex)
+
+ /* Each bin is one category (feature value).
+ * The bins are ordered based on centroidForCategories, and this ordering determines which
+ * splits are considered. (With K categories, we consider K - 1 possible splits.)
+ *
+ * centroidForCategories is a list: (category, centroid)
+ */
+ val centroidForCategories = Range(0, numCategories).map { featureValue =>
+ val categoryStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+ val centroid = if (categoryStats.count != 0) {
+ if (binAggregates.metadata.isMulticlass) {
+ // multiclass classification
+ // For categorical variables in multiclass classification,
+ // the bins are ordered by the impurity of their corresponding labels.
+ categoryStats.calculate()
+ } else if (binAggregates.metadata.isClassification) {
+ // binary classification
+ // For categorical variables in binary classification,
+ // the bins are ordered by the count of class 1.
+ categoryStats.stats(1)
+ } else {
+ // regression
+ // For categorical variables in regression and binary classification,
+ // the bins are ordered by the prediction.
+ categoryStats.predict
+ }
+ } else {
+ Double.MaxValue
+ }
+ (featureValue, centroid)
+ }
+
+ logDebug(s"Centroids for categorical variable: " +
+ s"${centroidForCategories.mkString(",")}")
+
+ // bins sorted by centroids
+ val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
+
+ logDebug(s"Sorted centroids for categorical variable = " +
+ s"${categoriesSortedByCentroid.mkString(",")}")
+
+ // Cumulative sum (scanLeft) of bin statistics.
+ // Afterwards, binAggregates for a bin is the sum of aggregates for
+ // that bin + all preceding bins.
+ var splitIndex = 0
+ while (splitIndex < numSplits) {
+ val currentCategory = categoriesSortedByCentroid(splitIndex)._1
+ val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
+ binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
+ splitIndex += 1
+ }
+ // lastCategory = index of bin with total aggregates for this (node, feature)
+ val lastCategory = categoriesSortedByCentroid.last._1
+ // Find best split.
+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
+ Range(0, numSplits).map { splitIndex =>
+ val featureValue = categoriesSortedByCentroid(splitIndex)._1
+ val leftChildStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+ val rightChildStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
+ rightChildStats.subtract(leftChildStats)
+ gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
+ leftChildStats, rightChildStats, binAggregates.metadata)
+ (splitIndex, gainAndImpurityStats)
+ }.maxBy(_._2.gain)
+ val categoriesForSplit =
+ categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
+ val bestFeatureSplit =
+ new CategoricalSplit(featureIndex, categoriesForSplit.toArray, numCategories)
+ (bestFeatureSplit, bestFeatureGainStats)
+ }
+ }
+
+ val (bestSplit, bestSplitStats) =
+ if (splitsAndImpurityInfo.isEmpty) {
+ // If no valid splits for features, then this split is invalid,
+ // return invalid information gain stats. Take any split and continue.
+ // Splits is empty, so arbitrarily choose to split on any threshold
+ val dummyFeatureIndex = featuresForNode.map(_.head).getOrElse(0)
+ val parentImpurityCalculator = binAggregates.getParentImpurityCalculator()
+ if (binAggregates.metadata.isContinuous(dummyFeatureIndex)) {
+ (new ContinuousSplit(dummyFeatureIndex, 0),
+ ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator))
+ } else {
+ val numCategories = binAggregates.metadata.featureArity(dummyFeatureIndex)
+ (new CategoricalSplit(dummyFeatureIndex, Array(), numCategories),
+ ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator))
+ }
+ } else {
+ splitsAndImpurityInfo.maxBy(_._2.gain)
+ }
+ (bestSplit, bestSplitStats)
+ }
+
+ /**
+ * Returns splits for decision tree calculation.
+ * Continuous and categorical features are handled differently.
+ *
+ * Continuous features:
+ * For each feature, there are numBins - 1 possible splits representing the possible binary
+ * decisions at each node in the tree.
+ * This finds locations (feature values) for splits using a subsample of the data.
+ *
+ * Categorical features:
+ * For each feature, there is 1 bin per split.
+ * Splits and bins are handled in 2 ways:
+ * (a) "unordered features"
+ * For multiclass classification with a low-arity feature
+ * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
+ * the feature is split based on subsets of categories.
+ * (b) "ordered features"
+ * For regression and binary classification,
+ * and for multiclass classification with a high-arity feature,
+ * there is one bin per category.
+ *
+ * @param input Training data: RDD of [[Instance]]
+ * @param metadata Learning and dataset metadata
+ * @param seed random seed
+ * @return Splits, an Array of [[Split]]
+ * of size (numFeatures, numSplits)
+ */
+ protected[tree] def findSplits(
+ input: RDD[Instance],
+ metadata: DecisionTreeMetadata,
+ seed: Long): Array[Array[Split]] = {
+
+ logDebug(s"isMulticlass = ${metadata.isMulticlass}")
+
+ val numFeatures = metadata.numFeatures
+
+ // Sample the input only if there are continuous features.
+ val continuousFeatures = Range(0, numFeatures).filter(metadata.isContinuous)
+ val sampledInput = if (continuousFeatures.nonEmpty) {
+ val fraction = samplesFractionForFindSplits(metadata)
+ logDebug(s"fraction of data used for calculating quantiles = $fraction")
+ if (fraction < 1) {
+ input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt())
+ } else {
+ input
+ }
+ } else {
+ input.sparkContext.emptyRDD[Instance]
+ }
+
+ findSplitsBySorting(sampledInput, metadata, continuousFeatures)
+ }
+
+ private def findSplitsBySorting(
+ input: RDD[Instance],
+ metadata: DecisionTreeMetadata,
+ continuousFeatures: IndexedSeq[Int]): Array[Array[Split]] = {
+
+ val continuousSplits = if (continuousFeatures.nonEmpty) {
+ // reduce the parallelism for split computations when there are less
+ // continuous features than input partitions. this prevents tasks from
+ // being spun up that will definitely do no work.
+ val numPartitions = math.min(continuousFeatures.length, input.partitions.length)
+
+ input.flatMap { point =>
+ continuousFeatures.iterator
+ .map(idx => (idx, (point.features(idx), point.weight)))
+ .filter(_._2._1 != 0.0)
+ }.aggregateByKey((new OpenHashMap[Double, Double], 0L), numPartitions)(
+ seqOp = { case ((map, c), (v, w)) =>
+ map.changeValue(v, w, _ + w)
+ (map, c + 1L)
+ },
+ combOp = { case ((map1, c1), (map2, c2)) =>
+ map2.foreach { case (v, w) =>
+ map1.changeValue(v, w, _ + w)
+ }
+ (map1, c1 + c2)
+ }
+ ).map { case (idx, (map, c)) =>
+ val thresholds = findSplitsForContinuousFeature(map.toMap, c, metadata, idx)
+ val splits: Array[Split] = thresholds.map(thresh => new ContinuousSplit(idx, thresh))
+ logDebug(s"featureIndex = $idx, numSplits = ${splits.length}")
+ (idx, splits)
+ }.collectAsMap()
+ } else Map.empty[Int, Array[Split]]
+
+ val numFeatures = metadata.numFeatures
+ val splits: Array[Array[Split]] = Array.tabulate(numFeatures) {
+ case i if metadata.isContinuous(i) =>
+ // some features may contain only zero, so continuousSplits will not have a record
+ val split = continuousSplits.getOrElse(i, Array.empty[Split])
+ metadata.setNumSplits(i, split.length)
+ split
+
+ case i if metadata.isCategorical(i) && metadata.isUnordered(i) =>
+ // Unordered features
+ // 2^(maxFeatureValue - 1) - 1 combinations
+ val featureArity = metadata.featureArity(i)
+ Array.tabulate[Split](metadata.numSplits(i)) { splitIndex =>
+ val categories = extractMultiClassCategories(splitIndex + 1, featureArity)
+ new CategoricalSplit(i, categories.toArray, featureArity)
+ }
+
+ case i if metadata.isCategorical(i) =>
+ // Ordered features
+ // Splits are constructed as needed during training.
+ Array.empty[Split]
+ }
+ splits
+ }
+
+ /**
+ * Nested method to extract list of eligible categories given an index. It extracts the
+ * position of ones in a binary representation of the input. If binary
+ * representation of an number is 01101 (13), the output list should (3.0, 2.0,
+ * 0.0). The maxFeatureValue depict the number of rightmost digits that will be tested for ones.
+ */
+ private[tree] def extractMultiClassCategories(
+ input: Int,
+ maxFeatureValue: Int): List[Double] = {
+ var categories = List[Double]()
+ var j = 0
+ var bitShiftedInput = input
+ while (j < maxFeatureValue) {
+ if (bitShiftedInput % 2 != 0) {
+ // updating the list of categories.
+ categories = j.toDouble :: categories
+ }
+ // Right shift by one
+ bitShiftedInput = bitShiftedInput >> 1
+ j += 1
+ }
+ categories
+ }
+
+ /**
+ * Find splits for a continuous feature
+ * NOTE: Returned number of splits is set based on `featureSamples` and
+ * could be different from the specified `numSplits`.
+ * The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly.
+ *
+ * @param featureSamples feature values and sample weights of each sample
+ * @param metadata decision tree metadata
+ * NOTE: `metadata.numbins` will be changed accordingly
+ * if there are not enough splits to be found
+ * @param featureIndex feature index to find splits
+ * @return array of split thresholds
+ */
+ private[tree] def findSplitsForContinuousFeature(
+ featureSamples: Iterable[(Double, Double)],
+ metadata: DecisionTreeMetadata,
+ featureIndex: Int): Array[Double] = {
+ val valueWeights = new OpenHashMap[Double, Double]
+ var count = 0L
+ featureSamples.foreach { case (weight, value) =>
+ valueWeights.changeValue(value, weight, _ + weight)
+ count += 1L
+ }
+ findSplitsForContinuousFeature(valueWeights.toMap, count, metadata, featureIndex)
+ }
+
+ /**
+ * Find splits for a continuous feature
+ * NOTE: Returned number of splits is set based on `featureSamples` and
+ * could be different from the specified `numSplits`.
+ * The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly.
+ *
+ * @param partValueWeights non-zero distinct values and their weights
+ * @param metadata decision tree metadata
+ * NOTE: `metadata.numbins` will be changed accordingly
+ * if there are not enough splits to be found
+ * @param featureIndex feature index to find splits
+ * @return array of split thresholds
+ */
+ private[tree] def findSplitsForContinuousFeature(
+ partValueWeights: Map[Double, Double],
+ count: Long,
+ metadata: DecisionTreeMetadata,
+ featureIndex: Int): Array[Double] = {
+ require(metadata.isContinuous(featureIndex),
+ "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")
+
+ val splits = if (partValueWeights.isEmpty) {
+ Array.emptyDoubleArray
+ } else {
+ val numSplits = metadata.numSplits(featureIndex)
+
+ val partNumSamples = partValueWeights.values.sum
+
+ // Calculate the expected number of samples for finding splits
+ val weightedNumSamples = samplesFractionForFindSplits(metadata) *
+ metadata.weightedNumExamples
+ // scale tolerance by number of samples with constant factor
+ // Note: constant factor was tuned by running some tests where there were no zero
+ // feature values and validating we are never within tolerance
+ val tolerance = Utils.EPSILON * count * 100
+ // add expected zero value count and get complete statistics
+ val valueCountMap = if (weightedNumSamples - partNumSamples > tolerance) {
+ partValueWeights + (0.0 -> (weightedNumSamples - partNumSamples))
+ } else {
+ partValueWeights
+ }
+
+ // sort distinct values
+ val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
+
+ val possibleSplits = valueCounts.length - 1
+ if (possibleSplits == 0) {
+ // constant feature
+ Array.emptyDoubleArray
+ } else if (possibleSplits <= numSplits) {
+ // if possible splits is not enough or just enough, just return all possible splits
+ (1 to possibleSplits)
+ .map(index => (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0)
+ .toArray
+ } else {
+ // stride between splits
+ val stride: Double = weightedNumSamples / (numSplits + 1)
+ logDebug(s"stride = $stride")
+
+ // iterate `valueCount` to find splits
+ val splitsBuilder = mutable.ArrayBuilder.make[Double]
+ var index = 1
+ // currentCount: sum of counts of values that have been visited
+ var currentCount = valueCounts(0)._2
+ // targetCount: target value for `currentCount`.
+ // If `currentCount` is closest value to `targetCount`,
+ // then current value is a split threshold.
+ // After finding a split threshold, `targetCount` is added by stride.
+ var targetCount = stride
+ while (index < valueCounts.length) {
+ val previousCount = currentCount
+ currentCount += valueCounts(index)._2
+ val previousGap = math.abs(previousCount - targetCount)
+ val currentGap = math.abs(currentCount - targetCount)
+ // If adding count of current value to currentCount
+ // makes the gap between currentCount and targetCount smaller,
+ // previous value is a split threshold.
+ if (previousGap < currentGap) {
+ splitsBuilder += (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0
+ targetCount += stride
+ }
+ index += 1
+ }
+
+ splitsBuilder.result()
+ }
+ }
+ splits
+ }
+
+ /**
+ * Pull nodes off of the queue, and collect a group of nodes to be split on this iteration.
+ * This tracks the memory usage for aggregates and stops adding nodes when too much memory
+ * will be needed; this allows an adaptive number of nodes since different nodes may require
+ * different amounts of memory (if featureSubsetStrategy is not "all").
+ *
+ * @param nodeStack Queue of nodes to split.
+ * @param maxMemoryUsage Bound on size of aggregate statistics.
+ * @return (nodesForGroup, treeToNodeToIndexInfo).
+ * nodesForGroup holds the nodes to split: treeIndex --> nodes in tree.
+ *
+ * treeToNodeToIndexInfo holds indices selected features for each node:
+ * treeIndex --> (global) node index --> (node index in group, feature indices).
+ * The (global) node index is the index in the tree; the node index in group is the
+ * index in [0, numNodesInGroup) of the node in this group.
+ * The feature indices are None if not subsampling features.
+ */
+ private[tree] def selectNodesToSplit(
+ nodeStack: mutable.ListBuffer[(Int, LearningNode)],
+ maxMemoryUsage: Long,
+ metadata: DecisionTreeMetadata,
+ rng: Random): (Map[Int, Array[LearningNode]], Map[Int, Map[Int, NodeIndexInfo]]) = {
+ // Collect some nodes to split:
+ // nodesForGroup(treeIndex) = nodes to split
+ val mutableNodesForGroup = new mutable.HashMap[Int, mutable.ArrayBuffer[LearningNode]]()
+ val mutableTreeToNodeToIndexInfo =
+ new mutable.HashMap[Int, mutable.HashMap[Int, NodeIndexInfo]]()
+ var memUsage: Long = 0L
+ var numNodesInGroup = 0
+ // If maxMemoryInMB is set very small, we want to still try to split 1 node,
+ // so we allow one iteration if memUsage == 0.
+ var groupDone = false
+ while (nodeStack.nonEmpty && !groupDone) {
+ val (treeIndex, node) = nodeStack.head
+ // Choose subset of features for node (if subsampling).
+ val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
+ if (metadata.oneFeaturePerTree) {
+ Some(Array(treeIndex))
+ } else {
+ Some(SamplingUtils.reservoirSampleAndCount(Range(0,
+ metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong())._1)
+ }
+ } else {
+ None
+ }
+ // Check if enough memory remains to add this node to the group.
+ val nodeMemUsage = DecisionTreeBucket.aggregateSizeForNode(metadata, featureSubset) * 8L
+ if (memUsage + nodeMemUsage <= maxMemoryUsage || memUsage == 0) {
+ nodeStack.remove(0)
+ mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[LearningNode]()) +=
+ node
+ mutableTreeToNodeToIndexInfo
+ .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id)
+ = new NodeIndexInfo(numNodesInGroup, featureSubset)
+ numNodesInGroup += 1
+ memUsage += nodeMemUsage
+ } else {
+ groupDone = true
+ }
+ }
+ if (memUsage > maxMemoryUsage) {
+ // If maxMemoryUsage is 0, we should still allow splitting 1 node.
+ logWarning(s"Tree learning is using approximately $memUsage bytes per iteration, which" +
+ s" exceeds requested limit maxMemoryUsage=$maxMemoryUsage. This allows splitting" +
+ s" $numNodesInGroup nodes in this iteration.")
+ }
+ logWarning(f"[group] actualMemUsage: ${memUsage/(1024d*1024d)}%.2f MB," +
+ f" maxMemoryUsage: ${maxMemoryUsage/(1024d*1024d)}%.2f MB.")
+ // Convert mutable maps to immutable ones.
+ val nodesForGroup: Map[Int, Array[LearningNode]] =
+ mutableNodesForGroup.mapValues(_.toArray).toMap
+ val treeToNodeToIndexInfo = mutableTreeToNodeToIndexInfo.mapValues(_.toMap).toMap
+ (nodesForGroup, treeToNodeToIndexInfo)
+ }
+
+ /**
+ * Get the number of values to be stored for this node in the bin aggregates.
+ *
+ * @param featureSubset Indices of features which may be split at this node.
+ * If None, then use all features.
+ */
+ private def aggregateSizeForNode(
+ metadata: DecisionTreeMetadata,
+ featureSubset: Option[Array[Int]]): Long = {
+ val totalBins = if (featureSubset.nonEmpty) {
+ featureSubset.get.map(featureIndex => metadata.numBins(featureIndex).toLong).sum
+ } else {
+ metadata.numBins.map(_.toLong).sum
+ }
+ if (metadata.isClassification) {
+ metadata.numClasses * totalBins
+ } else {
+ 3 * totalBins
+ }
+ }
+
+ /**
+ * Calculate the subsample fraction for finding splits
+ *
+ * @param metadata decision tree metadata
+ * @return subsample fraction
+ */
+ private def samplesFractionForFindSplits(
+ metadata: DecisionTreeMetadata): Double = {
+ // Calculate the number of samples for approximate quantile calculation.
+ val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
+ if (requiredSamples < metadata.numExamples) {
+ requiredSamples.toDouble / metadata.numExamples
+ } else {
+ 1.0
+ }
+ }
+}
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
new file mode 100644
index 0000000..0937ef4
--- /dev/null
+++ b/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
@@ -0,0 +1,252 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import scala.collection.mutable
+import scala.util.Try
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.tree.TreeEnsembleParams
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
+import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.impurity.Impurity
+import org.apache.spark.rdd.RDD
+
+/**
+ * Learning and dataset metadata for DecisionTree.
+ *
+ * @param weightedNumExamples Weighted count of samples in the tree.
+ * @param numClasses For classification: labels can take values {0, ..., numClasses - 1}.
+ * For regression: fixed at 0 (no meaning).
+ * @param maxBins Maximum number of bins, for all features.
+ * @param featureArity Map: categorical feature index to arity.
+ * I.e., the feature takes values in {0, ..., arity - 1}.
+ * @param numBins Number of bins for each feature.
+ * @param minWeightFractionPerNode The minimum fraction of the total sample weight that must be
+ * present in a leaf node in order to be considered a valid split.
+ */
+private[spark] class DecisionTreeMetadata(
+ val numFeatures: Int,
+ val numExamples: Long,
+ val weightedNumExamples: Double,
+ val numClasses: Int,
+ val maxBins: Int,
+ val featureArity: Map[Int, Int],
+ val unorderedFeatures: Set[Int],
+ val numBins: Array[Int],
+ val impurity: Impurity,
+ val quantileStrategy: QuantileStrategy,
+ val maxDepth: Int,
+ val minInstancesPerNode: Int,
+ val minWeightFractionPerNode: Double,
+ val minInfoGain: Double,
+ val numTrees: Int,
+ val numFeaturesPerNode: Int,
+ val oneFeaturePerTree: Boolean = false) extends Serializable {
+
+ def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex)
+
+ def isClassification: Boolean = numClasses >= 2
+
+ def isMulticlass: Boolean = numClasses > 2
+
+ def isMulticlassWithCategoricalFeatures: Boolean = isMulticlass && (featureArity.size > 0)
+
+ def isCategorical(featureIndex: Int): Boolean = featureArity.contains(featureIndex)
+
+ def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex)
+
+ def minWeightPerNode: Double = minWeightFractionPerNode * weightedNumExamples
+
+ /**
+ * Number of splits for the given feature.
+ * For unordered features, there is 1 bin per split.
+ * For ordered features, there is 1 more bin than split.
+ */
+ def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) {
+ numBins(featureIndex)
+ } else {
+ numBins(featureIndex) - 1
+ }
+
+
+ /**
+ * Set number of splits for a continuous feature.
+ * For a continuous feature, number of bins is number of splits plus 1.
+ */
+ def setNumSplits(featureIndex: Int, numSplits: Int): Unit = {
+ require(isContinuous(featureIndex),
+ s"Only number of bin for a continuous feature can be set.")
+ numBins(featureIndex) = numSplits + 1
+ }
+
+ /**
+ * Indicates if feature subsampling is being used.
+ */
+ def subsamplingFeatures: Boolean = numFeatures != numFeaturesPerNode
+
+}
+
+private[spark] object DecisionTreeMetadata extends Logging {
+
+ /**
+ * Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters.
+ * This computes which categorical features will be ordered vs. unordered,
+ * as well as the number of splits and bins for each feature.
+ */
+ def buildMetadata(
+ input: RDD[Instance],
+ strategy: Strategy,
+ numTrees: Int,
+ featureSubsetStrategy: String): DecisionTreeMetadata = {
+
+ val numFeatures = input.map(_.features.size).take(1).headOption.getOrElse {
+ throw new IllegalArgumentException(s"DecisionTree requires size of input RDD > 0, " +
+ s"but was given by empty one.")
+ }
+ require(numFeatures > 0, s"DecisionTree requires number of features > 0, " +
+ s"but was given an empty features vector")
+ val (numExamples, weightSum) = input.aggregate((0L, 0.0))(
+ seqOp = (cw, instance) => (cw._1 + 1L, cw._2 + instance.weight),
+ combOp = (cw1, cw2) => (cw1._1 + cw2._1, cw1._2 + cw2._2)
+ )
+
+ val numClasses = strategy.algo match {
+ case Classification => strategy.numClasses
+ case Regression => 0
+ }
+
+ val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt
+ if (maxPossibleBins < strategy.maxBins) {
+ logWarning(s"DecisionTree reducing maxBins from ${strategy.maxBins} to $maxPossibleBins" +
+ s" (= number of training instances)")
+ }
+
+ // We check the number of bins here against maxPossibleBins.
+ // This needs to be checked here instead of in Strategy since maxPossibleBins can be modified
+ // based on the number of training examples.
+ if (strategy.categoricalFeaturesInfo.nonEmpty) {
+ val maxCategoriesPerFeature = strategy.categoricalFeaturesInfo.values.max
+ val maxCategory =
+ strategy.categoricalFeaturesInfo.find(_._2 == maxCategoriesPerFeature).get._1
+ require(maxCategoriesPerFeature <= maxPossibleBins,
+ s"DecisionTree requires maxBins (= $maxPossibleBins) to be at least as large as the " +
+ s"number of values in each categorical feature, but categorical feature $maxCategory " +
+ s"has $maxCategoriesPerFeature values. Consider removing this and other categorical " +
+ "features with a large number of values, or add more training examples.")
+ }
+
+ val unorderedFeatures = new mutable.HashSet[Int]()
+ val numBins = Array.fill[Int](numFeatures)(maxPossibleBins)
+ if (numClasses > 2) {
+ // Multiclass classification
+ val maxCategoriesForUnorderedFeature =
+ ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt
+ strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
+ // Hack: If a categorical feature has only 1 category, we treat it as continuous.
+ // TODO(SPARK-9957): Handle this properly by filtering out those features.
+ if (numCategories > 1) {
+ // Decide if some categorical features should be treated as unordered features,
+ // which require 2 * ((1 << numCategories - 1) - 1) bins.
+ // We do this check with log values to prevent overflows in case numCategories is large.
+ // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins
+ if (numCategories <= maxCategoriesForUnorderedFeature) {
+ unorderedFeatures.add(featureIndex)
+ numBins(featureIndex) = numUnorderedBins(numCategories)
+ } else {
+ numBins(featureIndex) = numCategories
+ }
+ }
+ }
+ } else {
+ // Binary classification or regression
+ strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
+ // If a categorical feature has only 1 category, we treat it as continuous: SPARK-9957
+ if (numCategories > 1) {
+ numBins(featureIndex) = numCategories
+ }
+ }
+ }
+
+ // Set number of features to use per node (for random forests).
+ val _featureSubsetStrategy = featureSubsetStrategy match {
+ case "auto" =>
+ if (numTrees == 1) {
+ "all"
+ } else {
+ if (strategy.algo == Classification) {
+ "sqrt"
+ } else {
+ "onethird"
+ }
+ }
+ case _ => featureSubsetStrategy
+ }
+
+ val numFeaturesPerNode: Int = _featureSubsetStrategy match {
+ case "all" => numFeatures
+ case "sqrt" => math.sqrt(numFeatures).ceil.toInt
+ case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt)
+ case "onethird" => (numFeatures / 3.0).ceil.toInt
+ case _ =>
+ Try(_featureSubsetStrategy.toInt).filter(_ > 0).toOption match {
+ case Some(value) => math.min(value, numFeatures)
+ case None =>
+ Try(_featureSubsetStrategy.toDouble).filter(_ > 0).filter(_ <= 1.0).toOption match {
+ case Some(value) => math.ceil(value * numFeatures).toInt
+ case _ => throw new IllegalArgumentException(s"Supported values:" +
+ s" ${TreeEnsembleParams.supportedFeatureSubsetStrategies.mkString(", ")}," +
+ s" (0.0-1.0], [1-n].")
+ }
+ }
+ }
+
+ val (newNumTrees, oneFeaturePerTree) = if (numTrees > 0) {
+ (numTrees, false)
+ } else {
+ (numFeatures, true)
+
+ }
+ new DecisionTreeMetadata(numFeatures, numExamples, weightSum, numClasses,
+ numBins.max, strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
+ strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth,
+ strategy.minInstancesPerNode, strategy.minWeightFractionPerNode, strategy.minInfoGain,
+ newNumTrees, numFeaturesPerNode, oneFeaturePerTree)
+ }
+
+ /**
+ * Version of [[DecisionTreeMetadata#buildMetadata]] for DecisionTree.
+ */
+ def buildMetadata(
+ input: RDD[Instance],
+ strategy: Strategy): DecisionTreeMetadata = {
+ buildMetadata(input, strategy, numTrees = 1, featureSubsetStrategy = "all")
+ }
+
+ /**
+ * Given the arity of a categorical feature (arity = number of categories),
+ * return the number of bins for the feature if it is to be treated as an unordered feature.
+ * There is 1 split for every partitioning of categories into 2 disjoint, non-empty sets;
+ * there are math.pow(2, arity - 1) - 1 such splits.
+ * Each split has 2 corresponding bins.
+ */
+ def numUnorderedBins(arity: Int): Int = (1 << arity - 1) - 1
+
+}
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
new file mode 100644
index 0000000..688bcc1
--- /dev/null
+++ b/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
@@ -0,0 +1,761 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import it.unimi.dsi.fastutil.doubles.DoubleArrayList
+
+import org.apache.spark.SparkContext
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.regression.DecisionTreeRegressionModel
+import org.apache.spark.ml.tree.Split
+import org.apache.spark.ml.util.Instrumentation
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
+import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => OldBoostingStrategy}
+import org.apache.spark.mllib.tree.impurity.{Variance => OldVariance}
+import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.util.PeriodicRDDCheckpointer
+import org.apache.spark.storage.StorageLevel
+
+
+private[spark] object GradientBoostedTrees extends Logging {
+
+ /**
+ * Method to train a gradient boosting model
+ * @param input Training dataset: RDD of `Instance`.
+ * @param seed Random seed.
+ * @return tuple of ensemble models and weights:
+ * (array of decision tree models, array of model weights)
+ */
+ def run(
+ input: RDD[Instance],
+ boostingStrategy: OldBoostingStrategy,
+ seed: Long,
+ featureSubsetStrategy: String,
+ instr: Option[Instrumentation] = None):
+ (Array[DecisionTreeRegressionModel], Array[Double]) = {
+ val doUseAcc = getDoUseAccFromSparkConf(input.sparkContext)
+ runX(input, boostingStrategy, seed, featureSubsetStrategy, doUseAcc, instr)
+ }
+
+//scalastyle:off
+ /** Run with extended parameters */
+ def runX(
+ input: RDD[Instance],
+ boostingStrategy: OldBoostingStrategy,
+ seed: Long,
+ featureSubsetStrategy: String,
+ doUseAcc: Boolean,
+ instr: Option[Instrumentation] = None):
+ (Array[DecisionTreeRegressionModel], Array[Double]) = {
+ val algo = boostingStrategy.treeStrategy.algo
+ algo match {
+ case OldAlgo.Regression =>
+ if (doUseAcc) {
+ GradientBoostedTrees.boostX(input, input, boostingStrategy, validate = false,
+ seed, featureSubsetStrategy, instr)
+ } else {
+ GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false,
+ seed, featureSubsetStrategy, instr)
+ }
+ case OldAlgo.Classification =>
+ // Map labels to -1, +1 so binary classification can be treated as regression.
+ val remappedInput = input.map(x => Instance((x.label * 2) - 1, x.weight, x.features))
+ if (doUseAcc) {
+ GradientBoostedTrees.boostX(remappedInput, remappedInput, boostingStrategy, validate = false,
+ seed, featureSubsetStrategy, instr)
+ } else {
+ GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false,
+ seed, featureSubsetStrategy, instr)
+ }
+ case _ =>
+ throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.")
+ }
+ }
+
+ /**
+ * Method to validate a gradient boosting model
+ * @param input Training dataset: RDD of `Instance`.
+ * @param validationInput Validation dataset.
+ * This dataset should be different from the training dataset,
+ * but it should follow the same distribution.
+ * E.g., these two datasets could be created from an original dataset
+ * by using `org.apache.spark.rdd.RDD.randomSplit()`
+ * @param seed Random seed.
+ * @return tuple of ensemble models and weights:
+ * (array of decision tree models, array of model weights)
+ */
+ def runWithValidation(
+ input: RDD[Instance],
+ validationInput: RDD[Instance],
+ boostingStrategy: OldBoostingStrategy,
+ seed: Long,
+ featureSubsetStrategy: String,
+ instr: Option[Instrumentation] = None):
+ (Array[DecisionTreeRegressionModel], Array[Double]) = {
+ val doUseAcc = getDoUseAccFromSparkConf(input.sparkContext)
+ runWithValidationX(input, validationInput, boostingStrategy, seed, featureSubsetStrategy,
+ doUseAcc, instr)
+ }
+
+ /** Run with validation dataset and extended parameters */
+ def runWithValidationX(
+ input: RDD[Instance],
+ validationInput: RDD[Instance],
+ boostingStrategy: OldBoostingStrategy,
+ seed: Long,
+ featureSubsetStrategy: String,
+ doUseAcc: Boolean,
+ instr: Option[Instrumentation] = None):
+ (Array[DecisionTreeRegressionModel], Array[Double]) = {
+ val algo = boostingStrategy.treeStrategy.algo
+ algo match {
+ case OldAlgo.Regression =>
+ if (doUseAcc) {
+ GradientBoostedTrees.boostX(input, validationInput, boostingStrategy,
+ validate = true, seed, featureSubsetStrategy)
+ } else {
+ GradientBoostedTrees.boost(input, validationInput, boostingStrategy,
+ validate = true, seed, featureSubsetStrategy, instr)
+ }
+ case OldAlgo.Classification =>
+ // Map labels to -1, +1 so binary classification can be treated as regression.
+ val remappedInput = input.map(
+ x => Instance((x.label * 2) - 1, x.weight, x.features))
+ val remappedValidationInput = validationInput.map(
+ x => Instance((x.label * 2) - 1, x.weight, x.features))
+ if (doUseAcc) {
+ GradientBoostedTrees.boostX(remappedInput, remappedValidationInput, boostingStrategy,
+ validate = true, seed, featureSubsetStrategy)
+ } else {
+ GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy,
+ validate = true, seed, featureSubsetStrategy, instr)
+ }
+ case _ =>
+ throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
+ }
+ }
+
+ private val extraParamKey = "spark.boostkit.ml.gbdt.doUseAcc"
+ private val doUseAccDefault = true
+
+ private def getDoUseAccFromSparkConf(sc: SparkContext): Boolean = {
+ val doUseAcctStr = sc.conf.getOption(extraParamKey)
+ if (doUseAcctStr.nonEmpty) {
+ try {
+ doUseAcctStr.get.toBoolean
+ } catch {
+ case ex: Exception =>
+ throw new IllegalArgumentException(s"Parse boostkit parameter" +
+ s"($extraParamKey) failed, Error reason: ${ex.getMessage}")
+ }
+ } else {
+ doUseAccDefault
+ }
+ }
+
+ /**
+ * Compute the initial predictions and errors for a dataset for the first
+ * iteration of gradient boosting.
+ * @param data: training data.
+ * @param initTreeWeight: learning rate assigned to the first tree.
+ * @param initTree: first DecisionTreeModel.
+ * @param loss: evaluation metric.
+ * @return an RDD with each element being a zip of the prediction and error
+ * corresponding to every sample.
+ */
+ def computeInitialPredictionAndError(
+ data: RDD[TreePoint],
+ initTreeWeight: Double,
+ initTree: DecisionTreeRegressionModel,
+ loss: OldLoss,
+ bcSplits: Broadcast[Array[Array[Split]]]): RDD[(Double, Double)] = {
+ data.map { treePoint =>
+ val pred = updatePrediction(treePoint, 0.0, initTree, initTreeWeight, bcSplits.value)
+ val error = loss.computeError(pred, treePoint.label)
+ (pred, error)
+ }
+ }
+
+ /**
+ * Update a zipped predictionError RDD
+ * (as obtained with computeInitialPredictionAndError)
+ * @param data: training data.
+ * @param predictionAndError: predictionError RDD
+ * @param treeWeight: Learning rate.
+ * @param tree: Tree using which the prediction and error should be updated.
+ * @param loss: evaluation metric.
+ * @return an RDD with each element being a zip of the prediction and error
+ * corresponding to each sample.
+ */
+ def updatePredictionError(
+ data: RDD[TreePoint],
+ predictionAndError: RDD[(Double, Double)],
+ treeWeight: Double,
+ tree: DecisionTreeRegressionModel,
+ loss: OldLoss,
+ bcSplits: Broadcast[Array[Array[Split]]]): RDD[(Double, Double)] = {
+ data.zip(predictionAndError).map { case (treePoint, (pred, _)) =>
+ val newPred = updatePrediction(treePoint, pred, tree, treeWeight, bcSplits.value)
+ val newError = loss.computeError(newPred, treePoint.label)
+ (newPred, newError)
+ }
+ }
+
+ /**
+ * Add prediction from a new boosting iteration to an existing prediction.
+ *
+ * @param treePoint Binned vector of features representing a single data point.
+ * @param prediction The existing prediction.
+ * @param tree New Decision Tree model.
+ * @param weight Tree weight.
+ * @return Updated prediction.
+ */
+ def updatePrediction(
+ treePoint: TreePoint,
+ prediction: Double,
+ tree: DecisionTreeRegressionModel,
+ weight: Double,
+ splits: Array[Array[Split]]): Double = {
+ prediction +
+ tree.rootNode.predictBinned(treePoint.binnedFeatures, splits).prediction * weight
+ }
+
+ /**
+ * Add prediction from a new boosting iteration to an existing prediction.
+ *
+ * @param features Vector of features representing a single data point.
+ * @param prediction The existing prediction.
+ * @param tree New Decision Tree model.
+ * @param weight Tree weight.
+ * @return Updated prediction.
+ */
+ def updatePrediction(
+ features: Vector,
+ prediction: Double,
+ tree: DecisionTreeRegressionModel,
+ weight: Double): Double = {
+ prediction + tree.rootNode.predictImpl(features).prediction * weight
+ }
+
+ /**
+ * Method to calculate error of the base learner for the gradient boosting calculation.
+ * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
+ * purposes.
+ * @param data Training dataset: RDD of `Instance`.
+ * @param trees Boosted Decision Tree models
+ * @param treeWeights Learning rates at each boosting iteration.
+ * @param loss evaluation metric.
+ * @return Measure of model error on data
+ */
+ def computeWeightedError(
+ data: RDD[Instance],
+ trees: Array[DecisionTreeRegressionModel],
+ treeWeights: Array[Double],
+ loss: OldLoss): Double = {
+ val (errSum, weightSum) = data.map { case Instance(label, weight, features) =>
+ val predicted = trees.zip(treeWeights).foldLeft(0.0) { case (acc, (model, weight)) =>
+ updatePrediction(features, acc, model, weight)
+ }
+ (loss.computeError(predicted, label) * weight, weight)
+ }.treeReduce { case ((err1, weight1), (err2, weight2)) =>
+ (err1 + err2, weight1 + weight2)
+ }
+ errSum / weightSum
+ }
+
+ /**
+ * Method to calculate error of the base learner for the gradient boosting calculation.
+ * @param data Training dataset: RDD of `TreePoint`.
+ * @param predError Prediction and error.
+ * @return Measure of model error on data
+ */
+ def computeWeightedError(
+ data: RDD[TreePoint],
+ predError: RDD[(Double, Double)]): Double = {
+ val (errSum, weightSum) = data.zip(predError).map {
+ case (treePoint, (_, err)) =>
+ (err * treePoint.weight, treePoint.weight)
+ }.treeReduce { case ((err1, weight1), (err2, weight2)) =>
+ (err1 + err2, weight1 + weight2)
+ }
+ errSum / weightSum
+ }
+
+ /**
+ * Method to compute error or loss for every iteration of gradient boosting.
+ *
+ * @param data RDD of `Instance`
+ * @param trees Boosted Decision Tree models
+ * @param treeWeights Learning rates at each boosting iteration.
+ * @param loss evaluation metric.
+ * @param algo algorithm for the ensemble, either Classification or Regression
+ * @return an array with index i having the losses or errors for the ensemble
+ * containing the first i+1 trees
+ */
+ def evaluateEachIteration(
+ data: RDD[Instance],
+ trees: Array[DecisionTreeRegressionModel],
+ treeWeights: Array[Double],
+ loss: OldLoss,
+ algo: OldAlgo.Value): Array[Double] = {
+ val remappedData = algo match {
+ case OldAlgo.Classification =>
+ data.map(x => Instance((x.label * 2) - 1, x.weight, x.features))
+ case _ => data
+ }
+
+ val numTrees = trees.length
+ val (errSum, weightSum) = remappedData.mapPartitions { iter =>
+ iter.map { case Instance(label, weight, features) =>
+ val pred = Array.tabulate(numTrees) { i =>
+ trees(i).rootNode.predictImpl(features)
+ .prediction * treeWeights(i)
+ }
+ val err = pred.scanLeft(0.0)(_ + _).drop(1)
+ .map(p => loss.computeError(p, label) * weight)
+ (err, weight)
+ }
+ }.treeReduce { case ((err1, weight1), (err2, weight2)) =>
+ (0 until numTrees).foreach(i => err1(i) += err2(i))
+ (err1, weight1 + weight2)
+ }
+
+ errSum.map(_ / weightSum)
+ }
+
+ /**
+ * Internal method for performing regression using trees as base learners.
+ * @param input training dataset
+ * @param validationInput validation dataset, ignored if validate is set to false.
+ * @param boostingStrategy boosting parameters
+ * @param validate whether or not to use the validation dataset.
+ * @param seed Random seed.
+ * @return tuple of ensemble models and weights:
+ * (array of decision tree models, array of model weights)
+ */
+ def boost(
+ input: RDD[Instance],
+ validationInput: RDD[Instance],
+ boostingStrategy: OldBoostingStrategy,
+ validate: Boolean,
+ seed: Long,
+ featureSubsetStrategy: String,
+ instr: Option[Instrumentation] = None):
+ (Array[DecisionTreeRegressionModel], Array[Double]) = {
+ val timer = new TimeTracker()
+ timer.start("total")
+ timer.start("init")
+
+ val sc = input.sparkContext
+
+ boostingStrategy.assertValid()
+
+ // Initialize gradient boosting parameters
+ val numIterations = boostingStrategy.numIterations
+ val baseLearners = new Array[DecisionTreeRegressionModel](numIterations)
+ val baseLearnerWeights = new Array[Double](numIterations)
+ val loss = boostingStrategy.loss
+ val learningRate = boostingStrategy.learningRate
+
+ // Prepare strategy for individual trees, which use regression with variance impurity.
+ val treeStrategy = boostingStrategy.treeStrategy.copy
+ val validationTol = boostingStrategy.validationTol
+ treeStrategy.algo = OldAlgo.Regression
+ treeStrategy.impurity = OldVariance
+ require(!treeStrategy.bootstrap, "GradientBoostedTrees does not need bootstrap sampling")
+ treeStrategy.assertValid()
+
+ // Prepare periodic checkpointers
+ // Note: this is checkpointing the unweighted training error
+ val predErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
+ treeStrategy.getCheckpointInterval, sc, StorageLevel.MEMORY_AND_DISK)
+
+ timer.stop("init")
+
+ logDebug("##########")
+ logDebug("Building tree 0")
+ logDebug("##########")
+
+ // Initialize tree
+ timer.start("building tree 0")
+ val retaggedInput = input.retag(classOf[Instance])
+ timer.start("buildMetadata")
+ val metadata = DecisionTreeMetadata.buildMetadata(retaggedInput, treeStrategy,
+ numTrees = 1, featureSubsetStrategy)
+ timer.stop("buildMetadata")
+
+ timer.start("findSplits")
+ val splits = RandomForest.findSplits(retaggedInput, metadata, seed)
+ timer.stop("findSplits")
+ val bcSplits = sc.broadcast(splits)
+
+ // Bin feature values (TreePoint representation).
+ // Cache input RDD for speedup during multiple passes.
+ val treePoints = TreePoint.convertToTreeRDD(
+ retaggedInput, splits, metadata)
+ .persist(StorageLevel.MEMORY_AND_DISK)
+ .setName("binned tree points")
+
+ val firstCounts = BaggedPoint
+ .convertToBaggedRDD(treePoints, treeStrategy.subsamplingRate, numSubsamples = 1,
+ treeStrategy.bootstrap, (tp: TreePoint) => tp.weight, seed = seed)
+ .map { bagged =>
+ require(bagged.subsampleCounts.length == 1)
+ require(bagged.sampleWeight == bagged.datum.weight)
+ bagged.subsampleCounts.head
+ }.persist(StorageLevel.MEMORY_AND_DISK)
+ .setName("firstCounts at iter=0")
+
+ val firstBagged = treePoints.zip(firstCounts)
+ .map { case (treePoint, count) =>
+ // according to current design, treePoint.weight == baggedPoint.sampleWeight
+ new BaggedPoint[TreePoint](treePoint, Array(count), treePoint.weight)
+ }
+
+ val firstTreeModel = RandomForestRaw.runBagged(baggedInput = firstBagged,
+ metadata = metadata, bcSplits = bcSplits, strategy = treeStrategy, numTrees = 1,
+ featureSubsetStrategy = featureSubsetStrategy, seed = seed, instr = instr,
+ parentUID = None)
+ .head.asInstanceOf[DecisionTreeRegressionModel]
+
+ firstCounts.unpersist()
+
+ val firstTreeWeight = 1.0
+ baseLearners(0) = firstTreeModel
+ baseLearnerWeights(0) = firstTreeWeight
+
+ var predError = computeInitialPredictionAndError(
+ treePoints, firstTreeWeight, firstTreeModel, loss, bcSplits)
+ predErrorCheckpointer.update(predError)
+ logDebug(s"error of gbt = ${computeWeightedError(treePoints, predError)}")
+
+ // Note: A model of type regression is used since we require raw prediction
+ timer.stop("building tree 0")
+
+ var validationTreePoints: RDD[TreePoint] = null
+ var validatePredError: RDD[(Double, Double)] = null
+ var validatePredErrorCheckpointer: PeriodicRDDCheckpointer[(Double, Double)] = null
+ var bestValidateError = 0.0
+ if (validate) {
+ timer.start("init validation")
+ validationTreePoints = TreePoint.convertToTreeRDD(
+ validationInput.retag(classOf[Instance]), splits, metadata)
+ .persist(StorageLevel.MEMORY_AND_DISK)
+ validatePredError = computeInitialPredictionAndError(
+ validationTreePoints, firstTreeWeight, firstTreeModel, loss, bcSplits)
+ validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
+ treeStrategy.getCheckpointInterval, sc, StorageLevel.MEMORY_AND_DISK)
+ validatePredErrorCheckpointer.update(validatePredError)
+ bestValidateError = computeWeightedError(validationTreePoints, validatePredError)
+ timer.stop("init validation")
+ }
+
+ var bestM = 1
+
+ var m = 1
+ var doneLearning = false
+ while (m < numIterations && !doneLearning) {
+ timer.start(s"building tree $m")
+ logDebug("###################################################")
+ logDebug("Gradient boosting tree iteration " + m)
+ logDebug("###################################################")
+
+ // (label: Double, count: Int)
+ val labelWithCounts = BaggedPoint
+ .convertToBaggedRDD(treePoints, treeStrategy.subsamplingRate, numSubsamples = 1,
+ treeStrategy.bootstrap, (tp: TreePoint) => tp.weight, seed = seed + m)
+ .zip(predError)
+ .map { case (bagged, (pred, _)) =>
+ require(bagged.subsampleCounts.length == 1)
+ require(bagged.sampleWeight == bagged.datum.weight)
+ // Update labels with pseudo-residuals
+ val newLabel = -loss.gradient(pred, bagged.datum.label)
+ (newLabel, bagged.subsampleCounts.head)
+ }.persist(StorageLevel.MEMORY_AND_DISK)
+ .setName(s"labelWithCounts at iter=$m")
+
+ val bagged = treePoints.zip(labelWithCounts)
+ .map { case (treePoint, (newLabel, count)) =>
+ val newTreePoint = new TreePoint(newLabel, treePoint.binnedFeatures, treePoint.weight)
+ // according to current design, treePoint.weight == baggedPoint.sampleWeight
+ new BaggedPoint[TreePoint](newTreePoint, Array(count), treePoint.weight)
+ }
+
+ val model = RandomForestRaw.runBagged(baggedInput = bagged,
+ metadata = metadata, bcSplits = bcSplits, strategy = treeStrategy,
+ numTrees = 1, featureSubsetStrategy = featureSubsetStrategy,
+ seed = seed + m, instr = None, parentUID = None)
+ .head.asInstanceOf[DecisionTreeRegressionModel]
+
+ labelWithCounts.unpersist()
+
+ timer.stop(s"building tree $m")
+ // Update partial model
+ baseLearners(m) = model
+ // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
+ // Technically, the weight should be optimized for the particular loss.
+ // However, the behavior should be reasonable, though not optimal.
+ baseLearnerWeights(m) = learningRate
+
+ predError = updatePredictionError(
+ treePoints, predError, baseLearnerWeights(m),
+ baseLearners(m), loss, bcSplits)
+ predErrorCheckpointer.update(predError)
+ logDebug(s"error of gbt = ${computeWeightedError(treePoints, predError)}")
+
+ if (validate) {
+ // Stop training early if
+ // 1. Reduction in error is less than the validationTol or
+ // 2. If the error increases, that is if the model is overfit.
+ // We want the model returned corresponding to the best validation error.
+
+ validatePredError = updatePredictionError(
+ validationTreePoints, validatePredError, baseLearnerWeights(m),
+ baseLearners(m), loss, bcSplits)
+ validatePredErrorCheckpointer.update(validatePredError)
+ val currentValidateError = computeWeightedError(validationTreePoints, validatePredError)
+ if (bestValidateError - currentValidateError < validationTol * Math.max(
+ currentValidateError, 0.01)) {
+ doneLearning = true
+ } else if (currentValidateError < bestValidateError) {
+ bestValidateError = currentValidateError
+ bestM = m + 1
+ }
+ }
+ m += 1
+ }
+
+ logInfo("Internal timing for DecisionTree:")
+ logInfo(s"$timer")
+
+ bcSplits.destroy()
+ treePoints.unpersist()
+ predErrorCheckpointer.unpersistDataSet()
+ predErrorCheckpointer.deleteAllCheckpoints()
+ if (validate) {
+ validationTreePoints.unpersist()
+ validatePredErrorCheckpointer.unpersistDataSet()
+ validatePredErrorCheckpointer.deleteAllCheckpoints()
+ }
+
+ if (validate) {
+ (baseLearners.slice(0, bestM), baseLearnerWeights.slice(0, bestM))
+ } else {
+ (baseLearners, baseLearnerWeights)
+ }
+ }
+
+ /**
+ * Internal method for performing regression using trees as base learners.
+ * @param input training dataset
+ * @param validationInput validation dataset, ignored if validate is set to false.
+ * @param boostingStrategy boosting parameters
+ * @param validate whether or not to use the validation dataset.
+ * @param seed Random seed.
+ * @return tuple of ensemble models and weights:
+ * (array of decision tree models, array of model weights)
+ */
+ def boostX(
+ input: RDD[Instance],
+ validationInput: RDD[Instance],
+ boostingStrategy: OldBoostingStrategy,
+ validate: Boolean,
+ seed: Long,
+ featureSubsetStrategy: String,
+ instr: Option[Instrumentation] = None):
+ (Array[DecisionTreeRegressionModel], Array[Double]) = {
+ val timer = new TimeTracker()
+ timer.start("total")
+ timer.start("init")
+
+ val sc = input.sparkContext
+
+ boostingStrategy.assertValid()
+
+ // Initialize gradient boosting parameters
+ val numIterations = boostingStrategy.numIterations
+ val baseLearners = new Array[DecisionTreeRegressionModel](numIterations)
+ val baseLearnerWeights = new Array[Double](numIterations)
+ val loss = boostingStrategy.loss
+ val learningRate = boostingStrategy.learningRate
+
+ // Prepare strategy for individual trees, which use regression with variance impurity.
+ val treeStrategy = boostingStrategy.treeStrategy.copy
+ val validationTol = boostingStrategy.validationTol
+ treeStrategy.algo = OldAlgo.Regression
+ treeStrategy.impurity = OldVariance
+ require(!treeStrategy.bootstrap, "GradientBoostedTrees does not need bootstrap sampling")
+ treeStrategy.assertValid()
+
+ // Prepare periodic checkpointers
+ // Note: this is checkpointing the unweighted training error
+ val predErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
+ treeStrategy.getCheckpointInterval, sc, StorageLevel.MEMORY_AND_DISK)
+
+ timer.stop("init")
+
+ logDebug("##########")
+ logDebug("Building tree 0")
+ logDebug("##########")
+
+ // Initialize tree
+ timer.start("building tree 0")
+ val retaggedInput = input.retag(classOf[Instance])
+ timer.start("buildMetadata")
+ val metadata = DecisionTreeMetadata.buildMetadata(retaggedInput, treeStrategy,
+ numTrees = 1, featureSubsetStrategy)
+ timer.stop("buildMetadata")
+
+ timer.start("findSplits")
+ val splits = RandomForest4GBDTX.findSplits(retaggedInput, metadata, seed)
+ timer.stop("findSplits")
+ val bcSplits = sc.broadcast(splits)
+
+ val (treeInput, processedInput, labelArrayBcTmp, weightArrayBcTmp, rawPartInfoBcTmp) =
+ GradientBoostedTreesUtil.dataProcessX(retaggedInput, splits, treeStrategy, metadata, timer,
+ seed)
+ val rawPartInfoBc = rawPartInfoBcTmp
+ var labelArrayBc = labelArrayBcTmp
+ val weightArrayBc = weightArrayBcTmp
+ val weightArray = treeInput.map{treePoint => treePoint.weight}.collect()
+ val useWeight = if (weightArray.toSet.size != 1) {
+ (true, weightArray(0))
+ } else {
+ (false, weightArray(0))
+ }
+
+ val numTrees = 1
+ val paramsTuple = (treeStrategy, numTrees, seed)
+ val firstTreeModel = RandomForest4GBDTX.runX(labelArrayBc, processedInput, metadata, splits,
+ paramsTuple, treeInput, rawPartInfoBc, weightArrayBc, useWeight)
+ .head.asInstanceOf[DecisionTreeRegressionModel]
+
+ val firstTreeWeight = 1.0
+ baseLearners(0) = firstTreeModel
+ baseLearnerWeights(0) = firstTreeWeight
+
+ var predError = computeInitialPredictionAndError(
+ treeInput, firstTreeWeight, firstTreeModel, loss, bcSplits)
+ predErrorCheckpointer.update(predError)
+ logDebug(s"error of gbt = ${computeWeightedError(treeInput, predError)}")
+
+ // Note: A model of type regression is used since we require raw prediction
+ timer.stop("building tree 0")
+
+ var validationTreePoints: RDD[TreePoint] = null
+ var validatePredError: RDD[(Double, Double)] = null
+ var validatePredErrorCheckpointer: PeriodicRDDCheckpointer[(Double, Double)] = null
+ var bestValidateError = 0.0
+ if (validate) {
+ timer.start("init validation")
+ validationTreePoints = TreePoint.convertToTreeRDD(
+ validationInput.retag(classOf[Instance]), splits, metadata)
+ .persist(StorageLevel.MEMORY_AND_DISK)
+ validatePredError = computeInitialPredictionAndError(
+ validationTreePoints, firstTreeWeight, firstTreeModel, loss, bcSplits)
+ validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
+ treeStrategy.getCheckpointInterval, sc, StorageLevel.MEMORY_AND_DISK)
+ validatePredErrorCheckpointer.update(validatePredError)
+ bestValidateError = computeWeightedError(validationTreePoints, validatePredError)
+ timer.stop("init validation")
+ }
+
+ var bestM = 1
+
+ var m = 1
+ var doneLearning = false
+ while (m < numIterations && !doneLearning) {
+ timer.start(s"building tree $m")
+ logDebug("###################################################")
+ logDebug("Gradient boosting tree iteration " + m)
+ logDebug("###################################################")
+
+ labelArrayBc = treeInput.sparkContext.broadcast(
+ DoubleArrayList.wrap(
+ predError.zip(treeInput).map { case ((pred, _), point) =>
+ -loss.gradient(pred, point.label)}.collect()
+ )
+ )
+
+ val paramsTuple = (treeStrategy, numTrees, seed + m)
+ val model = RandomForest4GBDTX.runX(labelArrayBc, processedInput, metadata, splits, paramsTuple,
+ treeInput, rawPartInfoBc, weightArrayBc, useWeight)
+ .head.asInstanceOf[DecisionTreeRegressionModel]
+
+ timer.stop(s"building tree $m")
+ // Update partial model
+ baseLearners(m) = model
+ // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
+ // Technically, the weight should be optimized for the particular loss.
+ // However, the behavior should be reasonable, though not optimal.
+ baseLearnerWeights(m) = learningRate
+
+ predError = updatePredictionError(
+ treeInput, predError, baseLearnerWeights(m),
+ baseLearners(m), loss, bcSplits)
+ predErrorCheckpointer.update(predError)
+ logDebug(s"error of gbt = ${computeWeightedError(treeInput, predError)}")
+
+ if (validate) {
+ // Stop training early if
+ // 1. Reduction in error is less than the validationTol or
+ // 2. If the error increases, that is if the model is overfit.
+ // We want the model returned corresponding to the best validation error.
+
+ validatePredError = updatePredictionError(
+ validationTreePoints, validatePredError, baseLearnerWeights(m),
+ baseLearners(m), loss, bcSplits)
+ validatePredErrorCheckpointer.update(validatePredError)
+ val currentValidateError = computeWeightedError(validationTreePoints, validatePredError)
+ if (bestValidateError - currentValidateError < validationTol * Math.max(
+ currentValidateError, 0.01)) {
+ doneLearning = true
+ } else if (currentValidateError < bestValidateError) {
+ bestValidateError = currentValidateError
+ bestM = m + 1
+ }
+ }
+ m += 1
+ }
+
+ timer.stop("total")
+
+ logInfo("Internal timing for DecisionTree:")
+ logInfo(s"$timer")
+
+ bcSplits.destroy()
+ treeInput.unpersist()
+ predErrorCheckpointer.unpersistDataSet()
+ predErrorCheckpointer.deleteAllCheckpoints()
+ if (validate) {
+ validationTreePoints.unpersist()
+ validatePredErrorCheckpointer.unpersistDataSet()
+ validatePredErrorCheckpointer.deleteAllCheckpoints()
+ }
+
+ if (validate) {
+ (baseLearners.slice(0, bestM), baseLearnerWeights.slice(0, bestM))
+ } else {
+ (baseLearners, baseLearnerWeights)
+ }
+ }
+}
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
new file mode 100644
index 0000000..27994d4
--- /dev/null
+++ b/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -0,0 +1,1361 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import scala.collection.mutable
+import scala.util.Random
+
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.classification.DecisionTreeClassificationModel
+import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.impl.Utils
+import org.apache.spark.ml.regression.DecisionTreeRegressionModel
+import org.apache.spark.ml.tree._
+import org.apache.spark.ml.util.Instrumentation
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
+import org.apache.spark.mllib.tree.model.ImpurityStats
+import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.util.PeriodicRDDCheckpointer
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.collection.OpenHashMap
+import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
+
+
+/**
+ * ALGORITHM
+ *
+ * This is a sketch of the algorithm to help new developers.
+ *
+ * The algorithm partitions data by instances (rows).
+ * On each iteration, the algorithm splits a set of nodes. In order to choose the best split
+ * for a given node, sufficient statistics are collected from the distributed data.
+ * For each node, the statistics are collected to some worker node, and that worker selects
+ * the best split.
+ *
+ * This setup requires discretization of continuous features. This binning is done in the
+ * findSplits() method during initialization, after which each continuous feature becomes
+ * an ordered discretized feature with at most maxBins possible values.
+ *
+ * The main loop in the algorithm operates on a queue of nodes (nodeStack). These nodes
+ * lie at the periphery of the tree being trained. If multiple trees are being trained at once,
+ * then this queue contains nodes from all of them. Each iteration works roughly as follows:
+ * On the master node:
+ * - Some number of nodes are pulled off of the queue (based on the amount of memory
+ * required for their sufficient statistics).
+ * - For random forests, if featureSubsetStrategy is not "all," then a subset of candidate
+ * features are chosen for each node. See method selectNodesToSplit().
+ * On worker nodes, via method findBestSplits():
+ * - The worker makes one pass over its subset of instances.
+ * - For each (tree, node, feature, split) tuple, the worker collects statistics about
+ * splitting. Note that the set of (tree, node) pairs is limited to the nodes selected
+ * from the queue for this iteration. The set of features considered can also be limited
+ * based on featureSubsetStrategy.
+ * - For each node, the statistics for that node are aggregated to a particular worker
+ * via reduceByKey(). The designated worker chooses the best (feature, split) pair,
+ * or chooses to stop splitting if the stopping criteria are met.
+ * On the master node:
+ * - The master collects all decisions about splitting nodes and updates the model.
+ * - The updated model is passed to the workers on the next iteration.
+ * This process continues until the node queue is empty.
+ *
+ * Most of the methods in this implementation support the statistics aggregation, which is
+ * the heaviest part of the computation. In general, this implementation is bound by either
+ * the cost of statistics computation on workers or by communicating the sufficient statistics.
+ */
+private[spark] object RandomForest extends Logging with Serializable {
+
+ /**
+ * Train a random forest.
+ *
+ * @param input Training data: RDD of `LabeledPoint`
+ * @return an unweighted set of trees
+ */
+ def run(
+ input: RDD[LabeledPoint],
+ strategy: OldStrategy,
+ numTrees: Int,
+ featureSubsetStrategy: String,
+ seed: Long): Array[DecisionTreeModel] = {
+ val instances = input.map { case LabeledPoint(label, features) =>
+ Instance(label, 1.0, features.asML)
+ }
+ run(instances, strategy, numTrees, featureSubsetStrategy, seed, None)
+ }
+
+ // scalastyle:off
+ /**
+ * Train a random forest with metadata and splits. This method is mainly for GBT,
+ * in which bagged input can be reused among trees.
+ *
+ * @param baggedInput bagged training data: RDD of `BaggedPoint`
+ * @param metadata Learning and dataset metadata for DecisionTree.
+ * @return an unweighted set of trees
+ */
+ def runBagged(
+ baggedInput: RDD[BaggedPoint[TreePointX]],
+ metadata: DecisionTreeMetadata,
+ bcSplits: Broadcast[Array[Array[Split]]],
+ strategy: OldStrategy,
+ numTrees: Int,
+ featureSubsetStrategy: String,
+ seed: Long,
+ instr: Option[Instrumentation],
+ prune: Boolean = true, // exposed for testing only, real trees are always pruned
+ parentUID: Option[String] = None,
+ extraParams: Option[RFExtraParams] = None): Array[DecisionTreeModel] = {
+ // scalastyle:on
+ val timer = new TimeTracker()
+ timer.start("total")
+
+ val sc = baggedInput.sparkContext
+
+ instr match {
+ case Some(instrumentation) =>
+ instrumentation.logNumFeatures(metadata.numFeatures)
+ instrumentation.logNumClasses(metadata.numClasses)
+ instrumentation.logNumExamples(metadata.numExamples)
+ instrumentation.logSumOfWeights(metadata.weightedNumExamples)
+ case None =>
+ logInfo(s"numFeatures: ${metadata.numFeatures}")
+ logInfo(s"numClasses: ${metadata.numClasses}")
+ logInfo(s"numExamples: ${metadata.numExamples}")
+ logInfo(s"weightedNumExamples: ${metadata.weightedNumExamples}")
+ }
+
+ timer.start("init")
+
+ // depth of the decision tree
+ val maxDepth = strategy.maxDepth
+ require(maxDepth <= 30,
+ s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.")
+
+ // Max memory usage for aggregates
+ // TODO: Calculate memory usage more precisely.
+ val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L
+ logDebug(s"max memory usage for aggregates = $maxMemoryUsage bytes.")
+
+ /*
+ * The main idea here is to perform group-wise training of the decision tree nodes thus
+ * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup).
+ * Each data sample is handled by a particular node (or it reaches a leaf and is not used
+ * in lower levels).
+ */
+
+ var nodeIds: RDD[Array[Int]] = null
+ var nodeIdCheckpointer: PeriodicRDDCheckpointer[Array[Int]] = null
+ if (strategy.useNodeIdCache) {
+ // Create an RDD of node Id cache.
+ // At first, all the rows belong to the root nodes (node Id == 1).
+ nodeIds = baggedInput.map { _ => Array.fill(numTrees)(1) }
+ nodeIdCheckpointer = new PeriodicRDDCheckpointer[Array[Int]](
+ strategy.getCheckpointInterval, sc, StorageLevel.MEMORY_AND_DISK)
+ nodeIdCheckpointer.update(nodeIds)
+ }
+
+ /*
+ Stack of nodes to train: (treeIndex, node)
+ The reason this is a stack is that we train many trees at once, but we want to focus on
+ completing trees, rather than training all simultaneously. If we are splitting nodes from
+ 1 tree, then the new nodes to split will be put at the top of this stack, so we will continue
+ training the same tree in the next iteration. This focus allows us to send fewer trees to
+ workers on each iteration; see topNodesForGroup below.
+ */
+ val nodeStack = new mutable.ListBuffer[(Int, LearningNode)]
+
+ val rng = new Random()
+ rng.setSeed(seed)
+
+ // Allocate and queue root nodes.
+ val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1))
+ for (treeIndex <- 0 until numTrees) {
+ nodeStack.prepend((treeIndex, topNodes(treeIndex)))
+ }
+
+ timer.stop("init")
+
+ while (nodeStack.nonEmpty) {
+ // Collect some nodes to split, and choose features for each node (if subsampling).
+ // Each group of nodes may come from one or multiple trees, and at multiple levels.
+ val (nodesForGroup, treeToNodeToIndexInfo) =
+ RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng)
+ // Sanity check (should never occur):
+ assert(nodesForGroup.nonEmpty,
+ s"RandomForest selected empty nodesForGroup. Error for unknown reason.")
+
+ // Only send trees to worker if they contain nodes being split this iteration.
+ val topNodesForGroup: Map[Int, LearningNode] =
+ nodesForGroup.keys.map(treeIdx => treeIdx -> topNodes(treeIdx)).toMap
+
+ // Choose node splits, and enqueue new nodes as needed.
+ timer.start("findBestSplits")
+ val bestSplit = RandomForest.findBestSplits(baggedInput, metadata, topNodesForGroup,
+ nodesForGroup, treeToNodeToIndexInfo, bcSplits, nodeStack, timer, nodeIds,
+ outputBestSplits = strategy.useNodeIdCache, extraParams)
+ if (strategy.useNodeIdCache) {
+ nodeIds = updateNodeIds(baggedInput, nodeIds, bcSplits, bestSplit)
+ nodeIdCheckpointer.update(nodeIds)
+ }
+
+ timer.stop("findBestSplits")
+ }
+
+ timer.stop("total")
+
+ logInfo("Internal timing for DecisionTree:")
+ logInfo(s"$timer")
+
+ if (strategy.useNodeIdCache) {
+ // Delete any remaining checkpoints used for node Id cache.
+ nodeIdCheckpointer.unpersistDataSet()
+ nodeIdCheckpointer.deleteAllCheckpoints()
+ }
+
+ val numFeatures = metadata.numFeatures
+
+ parentUID match {
+ case Some(uid) =>
+ if (strategy.algo == OldAlgo.Classification) {
+ topNodes.map { rootNode =>
+ new DecisionTreeClassificationModel(uid, rootNode.toNode(prune), numFeatures,
+ strategy.getNumClasses)
+ }
+ } else {
+ topNodes.map { rootNode =>
+ new DecisionTreeRegressionModel(uid, rootNode.toNode(prune), numFeatures)
+ }
+ }
+ case None =>
+ if (strategy.algo == OldAlgo.Classification) {
+ topNodes.map { rootNode =>
+ new DecisionTreeClassificationModel(rootNode.toNode(prune), numFeatures,
+ strategy.getNumClasses)
+ }
+ } else {
+ topNodes.map(rootNode =>
+ new DecisionTreeRegressionModel(rootNode.toNode(prune), numFeatures))
+ }
+ }
+ }
+
+ /**
+ * Train a random forest.
+ *
+ * @param input Training data: RDD of `Instance`
+ * @return an unweighted set of trees
+ */
+ def run(
+ input: RDD[Instance],
+ strategy: OldStrategy,
+ numTrees: Int,
+ featureSubsetStrategy: String,
+ seed: Long,
+ instr: Option[Instrumentation],
+ prune: Boolean = true, // exposed for testing only, real trees are always pruned
+ parentUID: Option[String] = None,
+ exParams: Option[RFExtraParams] = None): Array[DecisionTreeModel] = {
+ val extraParams = if (exParams.isEmpty) {
+ RFUtils.parseExtraParams(input, strategy)
+ } else {
+ exParams.get
+ }
+ val timer = new TimeTracker()
+
+ timer.start("build metadata")
+ val metadata = DecisionTreeMetadata
+ .buildMetadata(input.retag(classOf[Instance]), strategy, numTrees, featureSubsetStrategy)
+ timer.stop("build metadata")
+
+ val binnedFeaturesType = BinnedFeaturesDataType.withName(extraParams.featuresDataType)
+ val retaggedInput = input.retag(classOf[Instance])
+
+ // Find the splits and the corresponding bins (interval between the splits) using a sample
+ // of the input data.
+ timer.start("findSplits")
+ val splits = findSplits(retaggedInput, metadata, seed)
+ timer.stop("findSplits")
+ logDebug("numBins: feature: number of bins")
+ logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
+ s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
+ }.mkString("\n"))
+
+ // Bin feature values (TreePointX representation).
+ // Cache input RDD for speedup during multiple passes.
+ val treeInput = TreePointX.convertToTreeRDD(retaggedInput, splits, metadata, binnedFeaturesType)
+
+ val bcSplits = input.sparkContext.broadcast(splits)
+ val baggedInputOri = BaggedPoint
+ .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, strategy.bootstrap,
+ (tp: TreePointX) => tp.weight, seed = seed)
+ .setName("bagged tree points")
+ val baggedInput = RFUtils.transformBaggedRDD(baggedInputOri, extraParams)
+
+ val trees = runBagged(baggedInput = baggedInput, metadata = metadata, bcSplits = bcSplits,
+ strategy = strategy, numTrees = numTrees, featureSubsetStrategy = featureSubsetStrategy,
+ seed = seed, instr = instr, prune = prune, parentUID = parentUID,
+ extraParams = Some(extraParams))
+
+ baggedInput.unpersist()
+ bcSplits.destroy()
+
+ trees
+ }
+
+ /**
+ * Update node indices by newly found splits.
+ */
+ private def updateNodeIds(
+ input: RDD[BaggedPoint[TreePointX]],
+ nodeIds: RDD[Array[Int]],
+ bcSplits: Broadcast[Array[Array[Split]]],
+ bestSplits: Array[Map[Int, Split]]): RDD[Array[Int]] = {
+ require(nodeIds != null && bestSplits != null)
+ input.zip(nodeIds).map { case (point, ids) =>
+ var treeId = 0
+ while (treeId < bestSplits.length) {
+ val bestSplitsInTree = bestSplits(treeId)
+ if (bestSplitsInTree != null) {
+ val nodeId = ids(treeId)
+ bestSplitsInTree.get(nodeId).foreach { bestSplit =>
+ val featureId = bestSplit.featureIndex
+ val bin = point.datum.binnedFeatures.get(featureId)
+ val newNodeId = if (bestSplit.shouldGoLeft(bin, bcSplits.value(featureId))) {
+ LearningNode.leftChildIndex(nodeId)
+ } else {
+ LearningNode.rightChildIndex(nodeId)
+ }
+ ids(treeId) = newNodeId
+ }
+ }
+ treeId += 1
+ }
+ ids
+ }
+ }
+
+ /**
+ * Helper for binSeqOp, for data which can contain a mix of ordered and unordered features.
+ *
+ * For ordered features, a single bin is updated.
+ * For unordered features, bins correspond to subsets of categories; either the left or right bin
+ * for each subset is updated.
+ *
+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
+ * each (feature, bin).
+ * @param treePoint Data point being aggregated.
+ * @param splits Possible splits indexed (numFeatures)(numSplits)
+ * @param unorderedFeatures Set of indices of unordered features.
+ * @param numSamples Number of times this instance occurs in the sample.
+ * @param sampleWeight Weight (importance) of instance in dataset.
+ */
+ private def mixedBinSeqOp(
+ agg: DTStatsAggregator,
+ treePoint: TreePointX,
+ splits: Array[Array[Split]],
+ unorderedFeatures: Set[Int],
+ numSamples: Int,
+ sampleWeight: Double,
+ featuresForNode: Option[Array[Int]]): Unit = {
+ val numFeaturesPerNode = if (featuresForNode.nonEmpty) {
+ // Use subsampled features
+ featuresForNode.get.length
+ } else {
+ // Use all features
+ agg.metadata.numFeatures
+ }
+ // Iterate over features.
+ var featureIndexIdx = 0
+ while (featureIndexIdx < numFeaturesPerNode) {
+ val featureIndex = if (featuresForNode.nonEmpty) {
+ featuresForNode.get.apply(featureIndexIdx)
+ } else {
+ featureIndexIdx
+ }
+ if (unorderedFeatures.contains(featureIndex)) {
+ // Unordered feature
+ val featureValue = treePoint.binnedFeatures.get(featureIndex)
+ val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx)
+ // Update the left or right bin for each split.
+ val numSplits = agg.metadata.numSplits(featureIndex)
+ val featureSplits = splits(featureIndex)
+ var splitIndex = 0
+ while (splitIndex < numSplits) {
+ if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) {
+ agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, numSamples,
+ sampleWeight)
+ }
+ splitIndex += 1
+ }
+ } else {
+ // Ordered feature
+ val binIndex = treePoint.binnedFeatures.get(featureIndex)
+ agg.update(featureIndexIdx, binIndex, treePoint.label, numSamples, sampleWeight)
+ }
+ featureIndexIdx += 1
+ }
+ }
+
+ /**
+ * Helper for binSeqOp, for regression and for classification with only ordered features.
+ *
+ * For each feature, the sufficient statistics of one bin are updated.
+ *
+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
+ * each (feature, bin).
+ * @param treePoint Data point being aggregated.
+ * @param numSamples Number of times this instance occurs in the sample.
+ * @param sampleWeight Weight (importance) of instance in dataset.
+ */
+ private def orderedBinSeqOp(
+ agg: DTStatsAggregator,
+ treePoint: TreePointX,
+ numSamples: Int,
+ sampleWeight: Double,
+ featuresForNode: Option[Array[Int]]): Unit = {
+ val label = treePoint.label
+
+ // Iterate over features.
+ if (featuresForNode.nonEmpty) {
+ // Use subsampled features
+ var featureIndexIdx = 0
+ while (featureIndexIdx < featuresForNode.get.length) {
+ val binIndex = treePoint.binnedFeatures.get(featuresForNode.get.apply(featureIndexIdx))
+ agg.update(featureIndexIdx, binIndex, label, numSamples, sampleWeight)
+ featureIndexIdx += 1
+ }
+ } else {
+ // Use all features
+ val numFeatures = agg.metadata.numFeatures
+ var featureIndex = 0
+ while (featureIndex < numFeatures) {
+ val binIndex = treePoint.binnedFeatures.get(featureIndex)
+ agg.update(featureIndex, binIndex, label, numSamples, sampleWeight)
+ featureIndex += 1
+ }
+ }
+ }
+
+ // scalastyle:off
+ /**
+ * Given a group of nodes, this finds the best split for each node.
+ *
+ * @param input Training data: RDD of [[TreePointX]]
+ * @param metadata Learning and dataset metadata
+ * @param topNodesForGroup For each tree in group, tree index -> root node.
+ * Used for matching instances with nodes.
+ * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree
+ * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo,
+ * where nodeIndexInfo stores the index in the group and the
+ * feature subsets (if using feature subsets).
+ * @param bcSplits possible splits for all features, indexed (numFeatures)(numSplits)
+ * @param nodeStack Queue of nodes to split, with values (treeIndex, node).
+ * Updated with new non-leaf nodes which are created.
+ * @param nodeIds an RDD of Array[Int] where each value in the array is the data
+ * point's node Id for a corresponding tree. This is used to prevent
+ * the need to pass the entire tree to the executors during the node
+ * stat aggregation phase.
+ */
+ private[tree] def findBestSplits(
+ input: RDD[BaggedPoint[TreePointX]],
+ metadata: DecisionTreeMetadata,
+ topNodesForGroup: Map[Int, LearningNode],
+ nodesForGroup: Map[Int, Array[LearningNode]],
+ treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]],
+ bcSplits: Broadcast[Array[Array[Split]]],
+ nodeStack: mutable.ListBuffer[(Int, LearningNode)],
+ timer: TimeTracker = new TimeTracker,
+ nodeIds: RDD[Array[Int]] = null,
+ outputBestSplits: Boolean = false,
+ extraParams: Option[RFExtraParams] = None): Array[Map[Int, Split]] = {
+ // scalastyle:on
+
+ /*
+ * The high-level descriptions of the best split optimizations are noted here.
+ *
+ * *Group-wise training*
+ * We perform bin calculations for groups of nodes to reduce the number of
+ * passes over the data. Each iteration requires more computation and storage,
+ * but saves several iterations over the data.
+ *
+ * *Bin-wise computation*
+ * We use a bin-wise best split computation strategy instead of a straightforward best split
+ * computation strategy. Instead of analyzing each sample for contribution to the left/right
+ * child node impurity of every split, we first categorize each feature of a sample into a
+ * bin. We exploit this structure to calculate aggregates for bins and then use these aggregates
+ * to calculate information gain for each split.
+ *
+ * *Aggregation over partitions*
+ * Instead of performing a flatMap/reduceByKey operation, we exploit the fact that we know
+ * the number of splits in advance. Thus, we store the aggregates (at the appropriate
+ * indices) in a single array for all bins and rely upon the RDD aggregate method to
+ * drastically reduce the communication overhead.
+ */
+
+ val useNodeIdCache = nodeIds != null
+
+ // numNodes: Number of nodes in this group
+ val numNodes = nodesForGroup.values.map(_.length).sum
+ logDebug(s"numNodes = $numNodes")
+ logDebug(s"numFeatures = ${metadata.numFeatures}")
+ logDebug(s"numClasses = ${metadata.numClasses}")
+ logDebug(s"isMulticlass = ${metadata.isMulticlass}")
+ logDebug(s"isMulticlassWithCategoricalFeatures = " +
+ s"${metadata.isMulticlassWithCategoricalFeatures}")
+ logDebug(s"using nodeIdCache = $useNodeIdCache")
+
+ val groupInfo = RFUtils.getGroupInfo(numNodes, treeToNodeToIndexInfo, extraParams)
+
+ /*
+ * Performs a sequential aggregation over a partition for a particular tree and node.
+ *
+ * For each feature, the aggregate sufficient statistics are updated for the relevant
+ * bins.
+ *
+ * @param treeIndex Index of the tree that we want to perform aggregation for.
+ * @param nodeInfo The node info for the tree node.
+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics
+ * for each (node, feature, bin).
+ * @param baggedPoint Data point being aggregated.
+ */
+ def nodeBinSeqOp(
+ treeIndex: Int,
+ nodeInfo: NodeIndexInfo,
+ agg: Array[DTStatsAggregator],
+ baggedPoint: BaggedPoint[TreePointX],
+ splits: Array[Array[Split]]): Unit = {
+ if (RFUtils.isValidNodeInfo(nodeInfo, agg)) {
+ val aggNodeIndex = nodeInfo.nodeIndexInGroup
+ val featuresForNode = nodeInfo.featureSubset
+ val numSamples = baggedPoint.subsampleCounts(treeIndex)
+ val sampleWeight = baggedPoint.sampleWeight
+ if (metadata.unorderedFeatures.isEmpty) {
+ orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, numSamples, sampleWeight,
+ featuresForNode)
+ } else {
+ mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
+ metadata.unorderedFeatures, numSamples, sampleWeight, featuresForNode)
+ }
+ agg(aggNodeIndex).updateParent(baggedPoint.datum.label, numSamples, sampleWeight)
+ }
+ }
+
+ /*
+ * Performs a sequential aggregation over a partition.
+ *
+ * Each data point contributes to one node. For each feature,
+ * the aggregate sufficient statistics are updated for the relevant bins.
+ *
+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
+ * each (node, feature, bin).
+ * @param baggedPoint Data point being aggregated.
+ * @return agg
+ */
+ def binSeqOp(
+ agg: Array[DTStatsAggregator],
+ baggedPoint: BaggedPoint[TreePointX],
+ splits: Array[Array[Split]],
+ sampleId: Short): Array[DTStatsAggregator] = {
+ treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
+ if (RFUtils.isSubSampled(baggedPoint, groupInfo, treeIndex, sampleId)) {
+ val nodeIndex =
+ topNodesForGroup(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits)
+ nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null),
+ agg, baggedPoint, splits)
+ }
+ }
+ agg
+ }
+
+ /**
+ * Do the same thing as binSeqOp, but with nodeIdCache.
+ */
+ def binSeqOpWithNodeIdCache(
+ agg: Array[DTStatsAggregator],
+ dataPoint: (BaggedPoint[TreePointX], Array[Int]),
+ splits: Array[Array[Split]]): Array[DTStatsAggregator] = {
+ treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
+ val baggedPoint = dataPoint._1
+ val nodeIdCache = dataPoint._2
+ val nodeIndex = nodeIdCache(treeIndex)
+ nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null),
+ agg, baggedPoint, splits)
+ }
+ agg
+ }
+
+ /**
+ * Get node index in group --> features indices map,
+ * which is a short cut to find feature indices for a node given node index in group.
+ */
+ def getNodeToFeatures(
+ treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]): Option[Map[Int, Array[Int]]] = {
+ if (!metadata.subsamplingFeatures) {
+ None
+ } else {
+ val mutableNodeToFeatures = new mutable.HashMap[Int, Array[Int]]()
+ treeToNodeToIndexInfo.values.foreach { nodeIdToNodeInfo =>
+ nodeIdToNodeInfo.values.foreach { nodeIndexInfo =>
+ assert(nodeIndexInfo.featureSubset.isDefined)
+ mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) = nodeIndexInfo.featureSubset.get
+ }
+ }
+ Some(mutableNodeToFeatures.toMap)
+ }
+ }
+
+ // array of nodes to train indexed by node index in group
+ val nodes = new Array[LearningNode](numNodes)
+ nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
+ nodesForTree.foreach { node =>
+ nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node
+ }
+ }
+
+ // Calculate best splits for all nodes in the group
+ timer.start("chooseSplits")
+
+ // In each partition, iterate all instances and compute aggregate stats for each node,
+ // yield a (nodeIndex, nodeAggregateStats) pair for each node.
+ // After a `reduceByKey` operation,
+ // stats of a node will be shuffled to a particular partition and be combined together,
+ // then best splits for nodes are found there.
+ // Finally, only best Splits for nodes are collected to driver to construct decision tree.
+ val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
+ val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)
+
+ val partitionAggregates = if (useNodeIdCache) {
+
+ input.zip(nodeIds).mapPartitions { points =>
+ // Construct a nodeStatsAggregators array to hold node aggregate stats,
+ // each node will have a nodeStatsAggregator
+ val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
+ val featuresForNode = nodeToFeaturesBc.value.map { nodeToFeatures =>
+ nodeToFeatures(nodeIndex)
+ }
+ new DTStatsAggregator(metadata, featuresForNode)
+ }
+
+ // iterator all instances in current partition and update aggregate stats
+ points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _, bcSplits.value))
+
+ // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
+ // which can be combined with other partition using `reduceByKey`
+ nodeStatsAggregators.iterator.zipWithIndex.map(_.swap)
+ }
+ } else {
+ input.mapPartitions { points =>
+ val (firstPointOption, nodeStatsAggregators) =
+ RFUtils.initNodeStatsAgg(numNodes, nodeToFeaturesBc, metadata, points, groupInfo)
+ if (firstPointOption.isEmpty) {
+ Iterator.empty
+ } else {
+ val splits = bcSplits.value
+ val firstPoint = firstPointOption.get
+ val sampleId = firstPoint.sampleId
+ binSeqOp(nodeStatsAggregators, firstPoint, splits, sampleId)
+
+ // iterator all instances in current partition and update aggregate stats
+ points.foreach(binSeqOp(nodeStatsAggregators, _, splits, sampleId))
+
+ // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
+ // which can be combined with other partition using `reduceByKey`
+ nodeStatsAggregators.iterator.zipWithIndex
+ .filter(v => RFUtils.isValidAgg(v._1)).map(_.swap)
+ }
+ }
+ }
+
+ val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b)).map {
+ case (nodeIndex, aggStats) =>
+ val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
+ Some(nodeToFeatures(nodeIndex))
+ }
+
+ // find best split for each node
+ val (split: Split, stats: ImpurityStats) =
+ binsToBestSplit(aggStats, bcSplits.value, featuresForNode, nodes(nodeIndex))
+ (nodeIndex, (split, stats))
+ }.collectAsMap()
+ nodeToFeaturesBc.destroy()
+
+ timer.stop("chooseSplits")
+
+ val bestSplits = if (outputBestSplits) {
+ Array.ofDim[mutable.Map[Int, Split]](metadata.numTrees)
+ } else {
+ null
+ }
+
+ // Iterate over all nodes in this group.
+ nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
+ nodesForTree.foreach { node =>
+ val nodeIndex = node.id
+ val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex)
+ val aggNodeIndex = nodeInfo.nodeIndexInGroup
+ val (split: Split, stats: ImpurityStats) =
+ nodeToBestSplits(aggNodeIndex)
+ logDebug(s"best split = $split")
+
+ // Extract info for this node. Create children if not leaf.
+ val isLeaf =
+ (stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) == metadata.maxDepth)
+ node.isLeaf = isLeaf
+ node.stats = stats
+ logDebug(s"Node = $node")
+
+ if (!isLeaf) {
+ node.split = Some(split)
+ val childIsLeaf = (LearningNode.indexToLevel(nodeIndex) + 1) == metadata.maxDepth
+ val leftChildIsLeaf = childIsLeaf || (math.abs(stats.leftImpurity) < Utils.EPSILON)
+ val rightChildIsLeaf = childIsLeaf || (math.abs(stats.rightImpurity) < Utils.EPSILON)
+ node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex),
+ leftChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator)))
+ node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex),
+ rightChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.rightImpurityCalculator)))
+
+ if (outputBestSplits) {
+ val bestSplitsInTree = bestSplits(treeIndex)
+ if (bestSplitsInTree == null) {
+ bestSplits(treeIndex) = mutable.Map[Int, Split](nodeIndex -> split)
+ } else {
+ bestSplitsInTree.update(nodeIndex, split)
+ }
+ }
+
+ // enqueue left child and right child if they are not leaves
+ if (!leftChildIsLeaf) {
+ nodeStack.prepend((treeIndex, node.leftChild.get))
+ }
+ if (!rightChildIsLeaf) {
+ nodeStack.prepend((treeIndex, node.rightChild.get))
+ }
+
+ logDebug(s"leftChildIndex = ${node.leftChild.get.id}" +
+ s", impurity = ${stats.leftImpurity}")
+ logDebug(s"rightChildIndex = ${node.rightChild.get.id}" +
+ s", impurity = ${stats.rightImpurity}")
+ }
+ }
+ }
+
+ if (outputBestSplits) {
+ bestSplits.map { m => if (m == null) null else m.toMap }
+ } else {
+ null
+ }
+ }
+
+ /**
+ * Calculate the impurity statistics for a given (feature, split) based upon left/right
+ * aggregates.
+ *
+ * @param stats the recycle impurity statistics for this feature's all splits,
+ * only 'impurity' and 'impurityCalculator' are valid between each iteration
+ * @param leftImpurityCalculator left node aggregates for this (feature, split)
+ * @param rightImpurityCalculator right node aggregate for this (feature, split)
+ * @param metadata learning and dataset metadata for DecisionTree
+ * @return Impurity statistics for this (feature, split)
+ */
+ private def calculateImpurityStats(
+ stats: ImpurityStats,
+ leftImpurityCalculator: ImpurityCalculator,
+ rightImpurityCalculator: ImpurityCalculator,
+ metadata: DecisionTreeMetadata): ImpurityStats = {
+
+ val parentImpurityCalculator: ImpurityCalculator = if (stats == null) {
+ leftImpurityCalculator.copy.add(rightImpurityCalculator)
+ } else {
+ stats.impurityCalculator
+ }
+
+ val impurity: Double = if (stats == null) {
+ parentImpurityCalculator.calculate()
+ } else {
+ stats.impurity
+ }
+
+ val leftRawCount = leftImpurityCalculator.rawCount
+ val rightRawCount = rightImpurityCalculator.rawCount
+ val leftCount = leftImpurityCalculator.count
+ val rightCount = rightImpurityCalculator.count
+
+ val totalCount = leftCount + rightCount
+
+ val violatesMinInstancesPerNode = (leftRawCount < metadata.minInstancesPerNode) ||
+ (rightRawCount < metadata.minInstancesPerNode)
+ val violatesMinWeightPerNode = (leftCount < metadata.minWeightPerNode) ||
+ (rightCount < metadata.minWeightPerNode)
+ // If left child or right child doesn't satisfy minimum weight per node or minimum
+ // instances per node, then this split is invalid, return invalid information gain stats.
+ if (violatesMinInstancesPerNode || violatesMinWeightPerNode) {
+ return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
+ }
+
+ val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
+ val rightImpurity = rightImpurityCalculator.calculate()
+
+ val leftWeight = leftCount / totalCount.toDouble
+ val rightWeight = rightCount / totalCount.toDouble
+
+ val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
+
+ // if information gain doesn't satisfy minimum information gain,
+ // then this split is invalid, return invalid information gain stats.
+ if (gain < metadata.minInfoGain) {
+ return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
+ }
+
+ new ImpurityStats(gain, impurity, parentImpurityCalculator,
+ leftImpurityCalculator, rightImpurityCalculator)
+ }
+
+ /**
+ * Find the best split for a node.
+ *
+ * @param binAggregates Bin statistics.
+ * @return tuple for best split: (Split, information gain, prediction at node)
+ */
+ private[tree] def binsToBestSplit(
+ binAggregates: DTStatsAggregator,
+ splits: Array[Array[Split]],
+ featuresForNode: Option[Array[Int]],
+ node: LearningNode): (Split, ImpurityStats) = {
+
+ // Calculate InformationGain and ImpurityStats if current node is top node
+ val level = LearningNode.indexToLevel(node.id)
+ var gainAndImpurityStats: ImpurityStats = if (level == 0) {
+ null
+ } else {
+ node.stats
+ }
+
+ val validFeatureSplits =
+ Iterator.range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
+ featuresForNode.map(features => (featureIndexIdx, features(featureIndexIdx)))
+ .getOrElse((featureIndexIdx, featureIndexIdx))
+ }.withFilter { case (_, featureIndex) =>
+ binAggregates.metadata.numSplits(featureIndex) != 0
+ }
+
+ // For each (feature, split), calculate the gain, and select the best (feature, split).
+ val splitsAndImpurityInfo =
+ validFeatureSplits.map { case (featureIndexIdx, featureIndex) =>
+ val numSplits = binAggregates.metadata.numSplits(featureIndex)
+ if (binAggregates.metadata.isContinuous(featureIndex)) {
+ // Cumulative sum (scanLeft) of bin statistics.
+ // Afterwards, binAggregates for a bin is the sum of aggregates for
+ // that bin + all preceding bins.
+ val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
+ var splitIndex = 0
+ while (splitIndex < numSplits) {
+ binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
+ splitIndex += 1
+ }
+ // Find best split.
+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
+ Range(0, numSplits).map { splitIdx =>
+ val leftChildStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
+ val rightChildStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
+ rightChildStats.subtract(leftChildStats)
+ gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
+ leftChildStats, rightChildStats, binAggregates.metadata)
+ (splitIdx, gainAndImpurityStats)
+ }.maxBy(_._2.gain)
+ (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
+ } else if (binAggregates.metadata.isUnordered(featureIndex)) {
+ // Unordered categorical feature
+ val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx)
+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
+ Range(0, numSplits).map { splitIndex =>
+ val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
+ val rightChildStats = binAggregates.getParentImpurityCalculator()
+ .subtract(leftChildStats)
+ gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
+ leftChildStats, rightChildStats, binAggregates.metadata)
+ (splitIndex, gainAndImpurityStats)
+ }.maxBy(_._2.gain)
+ (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
+ } else {
+ // Ordered categorical feature
+ val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
+ val numCategories = binAggregates.metadata.numBins(featureIndex)
+
+ /* Each bin is one category (feature value).
+ * The bins are ordered based on centroidForCategories, and this ordering determines which
+ * splits are considered. (With K categories, we consider K - 1 possible splits.)
+ *
+ * centroidForCategories is a list: (category, centroid)
+ */
+ val centroidForCategories = Range(0, numCategories).map { featureValue =>
+ val categoryStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+ val centroid = if (categoryStats.count != 0) {
+ if (binAggregates.metadata.isMulticlass) {
+ // multiclass classification
+ // For categorical variables in multiclass classification,
+ // the bins are ordered by the impurity of their corresponding labels.
+ categoryStats.calculate()
+ } else if (binAggregates.metadata.isClassification) {
+ // binary classification
+ // For categorical variables in binary classification,
+ // the bins are ordered by the count of class 1.
+ categoryStats.stats(1)
+ } else {
+ // regression
+ // For categorical variables in regression and binary classification,
+ // the bins are ordered by the prediction.
+ categoryStats.predict
+ }
+ } else {
+ Double.MaxValue
+ }
+ (featureValue, centroid)
+ }
+
+ logDebug(s"Centroids for categorical variable: " +
+ s"${centroidForCategories.mkString(",")}")
+
+ // bins sorted by centroids
+ val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
+
+ logDebug(s"Sorted centroids for categorical variable = " +
+ s"${categoriesSortedByCentroid.mkString(",")}")
+
+ // Cumulative sum (scanLeft) of bin statistics.
+ // Afterwards, binAggregates for a bin is the sum of aggregates for
+ // that bin + all preceding bins.
+ var splitIndex = 0
+ while (splitIndex < numSplits) {
+ val currentCategory = categoriesSortedByCentroid(splitIndex)._1
+ val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
+ binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
+ splitIndex += 1
+ }
+ // lastCategory = index of bin with total aggregates for this (node, feature)
+ val lastCategory = categoriesSortedByCentroid.last._1
+ // Find best split.
+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
+ Range(0, numSplits).map { splitIndex =>
+ val featureValue = categoriesSortedByCentroid(splitIndex)._1
+ val leftChildStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+ val rightChildStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
+ rightChildStats.subtract(leftChildStats)
+ gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
+ leftChildStats, rightChildStats, binAggregates.metadata)
+ (splitIndex, gainAndImpurityStats)
+ }.maxBy(_._2.gain)
+ val categoriesForSplit =
+ categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
+ val bestFeatureSplit =
+ new CategoricalSplit(featureIndex, categoriesForSplit.toArray, numCategories)
+ (bestFeatureSplit, bestFeatureGainStats)
+ }
+ }
+
+ val (bestSplit, bestSplitStats) =
+ if (splitsAndImpurityInfo.isEmpty) {
+ // If no valid splits for features, then this split is invalid,
+ // return invalid information gain stats. Take any split and continue.
+ // Splits is empty, so arbitrarily choose to split on any threshold
+ val dummyFeatureIndex = featuresForNode.map(_.head).getOrElse(0)
+ val parentImpurityCalculator = binAggregates.getParentImpurityCalculator()
+ if (binAggregates.metadata.isContinuous(dummyFeatureIndex)) {
+ (new ContinuousSplit(dummyFeatureIndex, 0),
+ ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator))
+ } else {
+ val numCategories = binAggregates.metadata.featureArity(dummyFeatureIndex)
+ (new CategoricalSplit(dummyFeatureIndex, Array(), numCategories),
+ ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator))
+ }
+ } else {
+ splitsAndImpurityInfo.maxBy(_._2.gain)
+ }
+ (bestSplit, bestSplitStats)
+ }
+
+ /**
+ * Returns splits for decision tree calculation.
+ * Continuous and categorical features are handled differently.
+ *
+ * Continuous features:
+ * For each feature, there are numBins - 1 possible splits representing the possible binary
+ * decisions at each node in the tree.
+ * This finds locations (feature values) for splits using a subsample of the data.
+ *
+ * Categorical features:
+ * For each feature, there is 1 bin per split.
+ * Splits and bins are handled in 2 ways:
+ * (a) "unordered features"
+ * For multiclass classification with a low-arity feature
+ * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
+ * the feature is split based on subsets of categories.
+ * (b) "ordered features"
+ * For regression and binary classification,
+ * and for multiclass classification with a high-arity feature,
+ * there is one bin per category.
+ *
+ * @param input Training data: RDD of [[Instance]]
+ * @param metadata Learning and dataset metadata
+ * @param seed random seed
+ * @return Splits, an Array of [[Split]]
+ * of size (numFeatures, numSplits)
+ */
+ protected[tree] def findSplits(
+ input: RDD[Instance],
+ metadata: DecisionTreeMetadata,
+ seed: Long): Array[Array[Split]] = {
+
+ logDebug(s"isMulticlass = ${metadata.isMulticlass}")
+
+ val numFeatures = metadata.numFeatures
+
+ // Sample the input only if there are continuous features.
+ val continuousFeatures = Range(0, numFeatures).filter(metadata.isContinuous)
+ val sampledInput = if (continuousFeatures.nonEmpty) {
+ val fraction = samplesFractionForFindSplits(metadata)
+ logDebug(s"fraction of data used for calculating quantiles = $fraction")
+ if (fraction < 1) {
+ input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt())
+ } else {
+ input
+ }
+ } else {
+ input.sparkContext.emptyRDD[Instance]
+ }
+
+ findSplitsBySorting(sampledInput, metadata, continuousFeatures)
+ }
+
+ private def findSplitsBySorting(
+ input: RDD[Instance],
+ metadata: DecisionTreeMetadata,
+ continuousFeatures: IndexedSeq[Int]): Array[Array[Split]] = {
+
+ val continuousSplits = if (continuousFeatures.nonEmpty) {
+ // reduce the parallelism for split computations when there are less
+ // continuous features than input partitions. this prevents tasks from
+ // being spun up that will definitely do no work.
+ val numPartitions = math.min(continuousFeatures.length, input.partitions.length)
+
+ input.flatMap { point =>
+ continuousFeatures.iterator
+ .map(idx => (idx, (point.features(idx), point.weight)))
+ .filter(_._2._1 != 0.0)
+ }.aggregateByKey((new OpenHashMap[Double, Double], 0L), numPartitions)(
+ seqOp = { case ((map, c), (v, w)) =>
+ map.changeValue(v, w, _ + w)
+ (map, c + 1L)
+ },
+ combOp = { case ((map1, c1), (map2, c2)) =>
+ map2.foreach { case (v, w) =>
+ map1.changeValue(v, w, _ + w)
+ }
+ (map1, c1 + c2)
+ }
+ ).map { case (idx, (map, c)) =>
+ val thresholds = findSplitsForContinuousFeature(map.toMap, c, metadata, idx)
+ val splits: Array[Split] = thresholds.map(thresh => new ContinuousSplit(idx, thresh))
+ logDebug(s"featureIndex = $idx, numSplits = ${splits.length}")
+ (idx, splits)
+ }.collectAsMap()
+ } else Map.empty[Int, Array[Split]]
+
+ val numFeatures = metadata.numFeatures
+ val splits: Array[Array[Split]] = Array.tabulate(numFeatures) {
+ case i if metadata.isContinuous(i) =>
+ // some features may contain only zero, so continuousSplits will not have a record
+ val split = continuousSplits.getOrElse(i, Array.empty[Split])
+ metadata.setNumSplits(i, split.length)
+ split
+
+ case i if metadata.isCategorical(i) && metadata.isUnordered(i) =>
+ // Unordered features
+ // 2^(maxFeatureValue - 1) - 1 combinations
+ val featureArity = metadata.featureArity(i)
+ Array.tabulate[Split](metadata.numSplits(i)) { splitIndex =>
+ val categories = extractMultiClassCategories(splitIndex + 1, featureArity)
+ new CategoricalSplit(i, categories.toArray, featureArity)
+ }
+
+ case i if metadata.isCategorical(i) =>
+ // Ordered features
+ // Splits are constructed as needed during training.
+ Array.empty[Split]
+ }
+ splits
+ }
+
+ /**
+ * Nested method to extract list of eligible categories given an index. It extracts the
+ * position of ones in a binary representation of the input. If binary
+ * representation of an number is 01101 (13), the output list should (3.0, 2.0,
+ * 0.0). The maxFeatureValue depict the number of rightmost digits that will be tested for ones.
+ */
+ private[tree] def extractMultiClassCategories(
+ input: Int,
+ maxFeatureValue: Int): List[Double] = {
+ var categories = List[Double]()
+ var j = 0
+ var bitShiftedInput = input
+ while (j < maxFeatureValue) {
+ if (bitShiftedInput % 2 != 0) {
+ // updating the list of categories.
+ categories = j.toDouble :: categories
+ }
+ // Right shift by one
+ bitShiftedInput = bitShiftedInput >> 1
+ j += 1
+ }
+ categories
+ }
+
+ /**
+ * Find splits for a continuous feature
+ * NOTE: Returned number of splits is set based on `featureSamples` and
+ * could be different from the specified `numSplits`.
+ * The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly.
+ *
+ * @param featureSamples feature values and sample weights of each sample
+ * @param metadata decision tree metadata
+ * NOTE: `metadata.numbins` will be changed accordingly
+ * if there are not enough splits to be found
+ * @param featureIndex feature index to find splits
+ * @return array of split thresholds
+ */
+ private[tree] def findSplitsForContinuousFeature(
+ featureSamples: Iterable[(Double, Double)],
+ metadata: DecisionTreeMetadata,
+ featureIndex: Int): Array[Double] = {
+ val valueWeights = new OpenHashMap[Double, Double]
+ var count = 0L
+ featureSamples.foreach { case (weight, value) =>
+ valueWeights.changeValue(value, weight, _ + weight)
+ count += 1L
+ }
+ findSplitsForContinuousFeature(valueWeights.toMap, count, metadata, featureIndex)
+ }
+
+ /**
+ * Find splits for a continuous feature
+ * NOTE: Returned number of splits is set based on `featureSamples` and
+ * could be different from the specified `numSplits`.
+ * The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly.
+ *
+ * @param partValueWeights non-zero distinct values and their weights
+ * @param metadata decision tree metadata
+ * NOTE: `metadata.numbins` will be changed accordingly
+ * if there are not enough splits to be found
+ * @param featureIndex feature index to find splits
+ * @return array of split thresholds
+ */
+ private[tree] def findSplitsForContinuousFeature(
+ partValueWeights: Map[Double, Double],
+ count: Long,
+ metadata: DecisionTreeMetadata,
+ featureIndex: Int): Array[Double] = {
+ require(metadata.isContinuous(featureIndex),
+ "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")
+
+ val splits = if (partValueWeights.isEmpty) {
+ Array.emptyDoubleArray
+ } else {
+ val numSplits = metadata.numSplits(featureIndex)
+
+ val partNumSamples = partValueWeights.values.sum
+
+ // Calculate the expected number of samples for finding splits
+ val weightedNumSamples = samplesFractionForFindSplits(metadata) *
+ metadata.weightedNumExamples
+ // scale tolerance by number of samples with constant factor
+ // Note: constant factor was tuned by running some tests where there were no zero
+ // feature values and validating we are never within tolerance
+ val tolerance = Utils.EPSILON * count * 100
+ // add expected zero value count and get complete statistics
+ val valueCountMap = if (weightedNumSamples - partNumSamples > tolerance) {
+ partValueWeights + (0.0 -> (weightedNumSamples - partNumSamples))
+ } else {
+ partValueWeights
+ }
+
+ // sort distinct values
+ val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
+
+ val possibleSplits = valueCounts.length - 1
+ if (possibleSplits == 0) {
+ // constant feature
+ Array.emptyDoubleArray
+ } else if (possibleSplits <= numSplits) {
+ // if possible splits is not enough or just enough, just return all possible splits
+ (1 to possibleSplits)
+ .map(index => (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0)
+ .toArray
+ } else {
+ // stride between splits
+ val stride: Double = weightedNumSamples / (numSplits + 1)
+ logDebug(s"stride = $stride")
+
+ // iterate `valueCount` to find splits
+ val splitsBuilder = mutable.ArrayBuilder.make[Double]
+ var index = 1
+ // currentCount: sum of counts of values that have been visited
+ var currentCount = valueCounts(0)._2
+ // targetCount: target value for `currentCount`.
+ // If `currentCount` is closest value to `targetCount`,
+ // then current value is a split threshold.
+ // After finding a split threshold, `targetCount` is added by stride.
+ var targetCount = stride
+ while (index < valueCounts.length) {
+ val previousCount = currentCount
+ currentCount += valueCounts(index)._2
+ val previousGap = math.abs(previousCount - targetCount)
+ val currentGap = math.abs(currentCount - targetCount)
+ // If adding count of current value to currentCount
+ // makes the gap between currentCount and targetCount smaller,
+ // previous value is a split threshold.
+ if (previousGap < currentGap) {
+ splitsBuilder += (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0
+ targetCount += stride
+ }
+ index += 1
+ }
+
+ splitsBuilder.result()
+ }
+ }
+ splits
+ }
+
+ private[tree] class NodeIndexInfo(
+ val nodeIndexInGroup: Int,
+ val featureSubset: Option[Array[Int]]) extends Serializable
+
+ /**
+ * Pull nodes off of the queue, and collect a group of nodes to be split on this iteration.
+ * This tracks the memory usage for aggregates and stops adding nodes when too much memory
+ * will be needed; this allows an adaptive number of nodes since different nodes may require
+ * different amounts of memory (if featureSubsetStrategy is not "all").
+ *
+ * @param nodeStack Queue of nodes to split.
+ * @param maxMemoryUsage Bound on size of aggregate statistics.
+ * @return (nodesForGroup, treeToNodeToIndexInfo).
+ * nodesForGroup holds the nodes to split: treeIndex --> nodes in tree.
+ *
+ * treeToNodeToIndexInfo holds indices selected features for each node:
+ * treeIndex --> (global) node index --> (node index in group, feature indices).
+ * The (global) node index is the index in the tree; the node index in group is the
+ * index in [0, numNodesInGroup) of the node in this group.
+ * The feature indices are None if not subsampling features.
+ */
+ private[tree] def selectNodesToSplit(
+ nodeStack: mutable.ListBuffer[(Int, LearningNode)],
+ maxMemoryUsage: Long,
+ metadata: DecisionTreeMetadata,
+ rng: Random): (Map[Int, Array[LearningNode]], Map[Int, Map[Int, NodeIndexInfo]]) = {
+ // Collect some nodes to split:
+ // nodesForGroup(treeIndex) = nodes to split
+ val mutableNodesForGroup = new mutable.HashMap[Int, mutable.ArrayBuffer[LearningNode]]()
+ val mutableTreeToNodeToIndexInfo =
+ new mutable.HashMap[Int, mutable.HashMap[Int, NodeIndexInfo]]()
+ var memUsage: Long = 0L
+ var numNodesInGroup = 0
+ // If maxMemoryInMB is set very small, we want to still try to split 1 node,
+ // so we allow one iteration if memUsage == 0.
+ var groupDone = false
+ while (nodeStack.nonEmpty && !groupDone) {
+ val (treeIndex, node) = nodeStack.head
+ // Choose subset of features for node (if subsampling).
+ val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
+ Some(SamplingUtils.reservoirSampleAndCount(Range(0,
+ metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong())._1)
+ } else {
+ None
+ }
+ // Check if enough memory remains to add this node to the group.
+ val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
+ if (memUsage + nodeMemUsage <= maxMemoryUsage || memUsage == 0) {
+ nodeStack.remove(0)
+ mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[LearningNode]()) +=
+ node
+ mutableTreeToNodeToIndexInfo
+ .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id)
+ = new NodeIndexInfo(numNodesInGroup, featureSubset)
+ numNodesInGroup += 1
+ memUsage += nodeMemUsage
+ } else {
+ groupDone = true
+ }
+ }
+ if (memUsage > maxMemoryUsage) {
+ // If maxMemoryUsage is 0, we should still allow splitting 1 node.
+ logWarning(s"Tree learning is using approximately $memUsage bytes per iteration, which" +
+ s" exceeds requested limit maxMemoryUsage=$maxMemoryUsage. This allows splitting" +
+ s" $numNodesInGroup nodes in this iteration.")
+ }
+ logWarning(f"[group] actualMemUsage: ${memUsage/(1024d*1024d)}%.2f MB," +
+ f" maxMemoryUsage: ${maxMemoryUsage/(1024d*1024d)}%.2f MB.")
+ // Convert mutable maps to immutable ones.
+ val nodesForGroup: Map[Int, Array[LearningNode]] =
+ mutableNodesForGroup.mapValues(_.toArray).toMap
+ val treeToNodeToIndexInfo = mutableTreeToNodeToIndexInfo.mapValues(_.toMap).toMap
+ (nodesForGroup, treeToNodeToIndexInfo)
+ }
+
+ /**
+ * Get the number of values to be stored for this node in the bin aggregates.
+ *
+ * @param featureSubset Indices of features which may be split at this node.
+ * If None, then use all features.
+ */
+ private def aggregateSizeForNode(
+ metadata: DecisionTreeMetadata,
+ featureSubset: Option[Array[Int]]): Long = {
+ val totalBins = if (featureSubset.nonEmpty) {
+ featureSubset.get.map(featureIndex => metadata.numBins(featureIndex).toLong).sum
+ } else {
+ metadata.numBins.map(_.toLong).sum
+ }
+ if (metadata.isClassification) {
+ metadata.numClasses * totalBins
+ } else {
+ 3 * totalBins
+ }
+ }
+
+ /**
+ * Calculate the subsample fraction for finding splits
+ *
+ * @param metadata decision tree metadata
+ * @return subsample fraction
+ */
+ private def samplesFractionForFindSplits(
+ metadata: DecisionTreeMetadata): Double = {
+ // Calculate the number of samples for approximate quantile calculation.
+ val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
+ if (requiredSamples < metadata.numExamples) {
+ requiredSamples.toDouble / metadata.numExamples
+ } else {
+ 1.0
+ }
+ }
+}
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest4GBDTX.scala b/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest4GBDTX.scala
new file mode 100644
index 0000000..ff23fd9
--- /dev/null
+++ b/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest4GBDTX.scala
@@ -0,0 +1,689 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import scala.collection.mutable
+import scala.util.Random
+
+import it.unimi.dsi.fastutil.doubles.DoubleArrayList
+import it.unimi.dsi.fastutil.ints.{Int2ObjectOpenHashMap, IntArrayList}
+import it.unimi.dsi.fastutil.objects.ObjectArrayList
+
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.classification.DecisionTreeClassificationModel
+import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.impl.Utils
+import org.apache.spark.ml.regression.DecisionTreeRegressionModel
+import org.apache.spark.ml.tree._
+import org.apache.spark.ml.tree.impl.GradientBoostedTreesCore.NodeIndexInfo
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.model.ImpurityStats
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.collection.OpenHashMap
+import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
+
+
+
+/**
+ * ALGORITHM
+ *
+ * This is a sketch of the algorithm to help new developers.
+ *
+ * The algorithm partitions data by instances (rows).
+ * On each iteration, the algorithm splits a set of nodes. In order to choose the best split
+ * for a given node, sufficient statistics are collected from the distributed data.
+ * For each node, the statistics are collected to some worker node, and that worker selects
+ * the best split.
+ *
+ * This setup requires discretization of continuous features. This binning is done in the
+ * findSplits() method during initialization, after which each continuous feature becomes
+ * an ordered discretized feature with at most maxBins possible values.
+ *
+ * The main loop in the algorithm operates on a queue of nodes (nodeStack). These nodes
+ * lie at the periphery of the tree being trained. If multiple trees are being trained at once,
+ * then this queue contains nodes from all of them. Each iteration works roughly as follows:
+ * On the master node:
+ * - Some number of nodes are pulled off of the queue (based on the amount of memory
+ * required for their sufficient statistics).
+ * - For random forests, if featureSubsetStrategy is not "all," then a subset of candidate
+ * features are chosen for each node. See method selectNodesToSplit().
+ * On worker nodes, via method findBestSplits():
+ * - The worker makes one pass over its subset of instances.
+ * - For each (tree, node, feature, split) tuple, the worker collects statistics about
+ * splitting. Note that the set of (tree, node) pairs is limited to the nodes selected
+ * from the queue for this iteration. The set of features considered can also be limited
+ * based on featureSubsetStrategy.
+ * - For each node, the statistics for that node are aggregated to a particular worker
+ * via reduceByKey(). The designated worker chooses the best (feature, split) pair,
+ * or chooses to stop splitting if the stopping criteria are met.
+ * On the master node:
+ * - The master collects all decisions about splitting nodes and updates the model.
+ * - The updated model is passed to the workers on the next iteration.
+ * This process continues until the node queue is empty.
+ *
+ * Most of the methods in this implementation support the statistics aggregation, which is
+ * the heaviest part of the computation. In general, this implementation is bound by either
+ * the cost of statistics computation on workers or by communicating the sufficient statistics.
+ */
+private[spark] object RandomForest4GBDTX extends Logging with Serializable {
+
+ /**
+ * Train a random forest.
+ *
+ * @param input Training data: RDD of `LabeledPoint`
+ * @return an unweighted set of trees
+ */
+ def runX(
+ labelArrayBc: Broadcast[DoubleArrayList],
+ processedInput: RDD[(Int, (IntArrayList, ObjectArrayList[Split]))],
+ metadata: DecisionTreeMetadata,
+ splits: Array[Array[Split]],
+ paramsTuple: (OldStrategy, Int, Long),
+ input: RDD[TreePoint],
+ rawPartInfoBc: Broadcast[Int2ObjectOpenHashMap[IntArrayList]],
+ sampleWeightArrayBc: Broadcast[DoubleArrayList],
+ useWeight: (Boolean, Double),
+ parentUID: Option[String] = None): Array[DecisionTreeModel] = {
+
+ val timer = new TimeTracker()
+
+ timer.start("total")
+
+ timer.start("init")
+
+ val strategy = paramsTuple._1
+ val numTrees = paramsTuple._2
+ val seed = paramsTuple._3
+ // depth of the decision tree
+ val maxDepth = strategy.maxDepth
+ require(maxDepth <= 30,
+ s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.")
+
+ // Max memory usage for aggregates
+ // TODO: Calculate memory usage more precisely.
+ val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L
+ logDebug(s"max memory usage for aggregates = ${maxMemoryUsage} bytes.")
+
+ /*
+ Stack of nodes to train: (treeIndex, node)
+ The reason this is a stack is that we train many trees at once, but we want to focus on
+ completing trees, rather than training all simultaneously. If we are splitting nodes from
+ 1 tree, then the new nodes to split will be put at the top of this stack, so we will continue
+ training the same tree in the next iteration. This focus allows us to send fewer trees to
+ workers on each iteration; see topNodesForGroup below.
+ */
+ val nodeStack = new mutable.ArrayStack[(Int, LearningNode)]
+
+ val rng = new Random()
+ rng.setSeed(seed)
+
+ // Allocate and queue root nodes.
+ val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1))
+ Range(0, numTrees).foreach(treeIndex => nodeStack.push((treeIndex, topNodes(treeIndex))))
+
+ val nodeIdCacheX = GradientBoostedTreesUtil.nodeIdCacheXConstruction(topNodes, rawPartInfoBc)
+ timer.stop("init")
+
+ while (nodeStack.nonEmpty) {
+ // Collect some nodes to split, and choose features for each node (if subsampling).
+ // Each group of nodes may come from one or multiple trees, and at multiple levels.
+ val (nodesForGroup, treeToNodeToIndexInfo) =
+ RandomForest4GBDTX.selectNodesToSplitX(nodeStack, maxMemoryUsage, metadata, rng)
+ // Sanity check (should never occur):
+ assert(nodesForGroup.nonEmpty,
+ s"RandomForest selected empty nodesForGroup. Error for unknown reason.")
+
+ // Only send trees to worker if they contain nodes being split this iteration.
+ val topNodesForGroup: Map[Int, LearningNode] =
+ nodesForGroup.keys.map(treeIdx => treeIdx -> topNodes(treeIdx)).toMap
+
+ // Choose node splits, and enqueue new nodes as needed.
+ timer.start("findBestSplits")
+ RandomForest4GBDTX.findBestSplitsX(labelArrayBc, processedInput, metadata,
+ (nodesForGroup, treeToNodeToIndexInfo), splits, nodeStack, nodeIdCacheX, input,
+ rawPartInfoBc, timer, sampleWeightArrayBc, useWeight)
+ timer.stop("findBestSplits")
+ }
+
+ timer.stop("total")
+
+ logInfo("Internal timing for DecisionTree:")
+ logInfo(s"$timer")
+
+ val numFeatures = metadata.numFeatures
+
+ parentUID match {
+ case Some(uid) =>
+ if (strategy.algo == OldAlgo.Classification) {
+ // unreachable for GBDT
+ topNodes.map { rootNode =>
+ new DecisionTreeClassificationModel(uid, rootNode.toNode, numFeatures,
+ strategy.getNumClasses)
+ }
+ } else {
+ topNodes.map { rootNode =>
+ new DecisionTreeRegressionModel(uid, rootNode.toNode, numFeatures)
+ }
+ }
+ // unreachable for GBDT
+ case None =>
+ if (strategy.algo == OldAlgo.Classification) {
+ topNodes.map { rootNode =>
+ new DecisionTreeClassificationModel(rootNode.toNode, numFeatures,
+ strategy.getNumClasses)
+ }
+ } else {
+ topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode, numFeatures))
+ }
+ }
+ }
+
+ /**
+ * Given a group of nodes, this finds the best split for each node.
+ *
+ * @param input Training data: RDD of [[TreePoint]]
+ * @param metadata Learning and dataset metadata
+ * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
+ * @param nodeStack Queue of nodes to split, with values (treeIndex, node).
+ * Updated with new non-leaf nodes which are created.
+ * @param nodeIdCache Node Id cache containing an RDD of Array[Int] where
+ * each value in the array is the data point's node Id
+ * for a corresponding tree. This is used to prevent the need
+ * to pass the entire tree to the executors during
+ * the node stat aggregation phase.
+ */
+ // scalastyle:off
+ private[tree] def findBestSplitsX(
+ labelArrayBc: Broadcast[DoubleArrayList],
+ processedInput: RDD[(Int, (IntArrayList, ObjectArrayList[Split]))],
+ metadata: DecisionTreeMetadata,
+ packagedNodeInfo: (Map[Int, Array[LearningNode]], Map[Int, Map[Int, NodeIndexInfo]]),
+ splits: Array[Array[Split]],
+ nodeStack: mutable.ArrayStack[(Int, LearningNode)],
+ nodeIdCache: Int2ObjectOpenHashMap[Int2ObjectOpenHashMap[IntArrayList]],
+ input: RDD[TreePoint],
+ rawPartInfoBc: Broadcast[Int2ObjectOpenHashMap[IntArrayList]],
+ timer: TimeTracker = new TimeTracker,
+ sampleWeightArrayBc: Broadcast[DoubleArrayList],
+ useWeight: (Boolean, Double)) : Unit = {
+
+ /*
+ * The high-level descriptions of the best split optimizations are noted here.
+ *
+ * *Group-wise training*
+ * We perform bin calculations for groups of nodes to reduce the number of
+ * passes over the data. Each iteration requires more computation and storage,
+ * but saves several iterations over the data.
+ *
+ * *Bin-wise computation*
+ * We use a bin-wise best split computation strategy instead of a straightforward best split
+ * computation strategy. Instead of analyzing each sample for contribution to the left/right
+ * child node impurity of every split, we first categorize each feature of a sample into a
+ * bin. We exploit this structure to calculate aggregates for bins and then use these aggregates
+ * to calculate information gain for each split.
+ *
+ * *Aggregation over partitions*
+ * Instead of performing a flatMap/reduceByKey operation, we exploit the fact that we know
+ * the number of splits in advance. Thus, we store the aggregates (at the appropriate
+ * indices) in a single array for all bins and rely upon the RDD aggregate method to
+ * drastically reduce the communication overhead.
+ */
+
+ // Un-package node info
+ val (nodesForGroup, treeToNodeToIndexInfo) = packagedNodeInfo
+ // numNodes: Number of nodes in this group
+ val numNodes = nodesForGroup.values.map(_.length).sum
+ logDebug(s"numNodes = ${numNodes}")
+ logDebug(s"numFeatures = ${metadata.numFeatures}")
+ logDebug(s"numClasses = ${metadata.numClasses}")
+ logDebug(s"isMulticlass = ${metadata.isMulticlass}")
+ logDebug(s"isMulticlassWithCategoricalFeatures =" +
+ s"${metadata.isMulticlassWithCategoricalFeatures}")
+
+ // array of nodes to train indexed by node index in group
+ val nodes = new Array[LearningNode](numNodes)
+ nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
+ nodesForTree.foreach { node =>
+ nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node
+ }
+ }
+
+ timer.start("broadcast")
+ val nodeIdCacheBc = processedInput.sparkContext.broadcast(nodeIdCache)
+ timer.stop("broadcast")
+
+ // Calculate best splits for all nodes in the group
+ timer.start("chooseSplits")
+
+ val nodeToBestSplits = GradientBoostedTreesUtil.chooseBestSplits(processedInput,
+ treeToNodeToIndexInfo, metadata, nodeIdCacheBc, labelArrayBc, nodes, sampleWeightArrayBc, useWeight)
+
+ timer.stop("chooseSplits")
+ // Iterate over all nodes in this group.
+ nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
+ nodesForTree.foreach { node =>
+ val nodeIndex = node.id
+ val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex)
+ val aggNodeIndex = nodeInfo.nodeIndexInGroup
+ val (split: Split, stats: ImpurityStats) =
+ nodeToBestSplits(nodeIndex)
+ logDebug(s"best split = ${split}")
+
+ // Extract info for this node. Create children if not leaf.
+ val isLeaf =
+ (stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) == metadata.maxDepth)
+ node.isLeaf = isLeaf
+ node.stats = stats
+ logDebug(s"Node = ${node}")
+
+ if (!isLeaf) {
+ node.split = Some(split)
+ val childIsLeaf = (LearningNode.indexToLevel(nodeIndex) + 1) == metadata.maxDepth
+ val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0)
+ val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0)
+ node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex),
+ leftChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator)))
+ node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex),
+ rightChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.rightImpurityCalculator)))
+
+ // enqueue left child and right child if they are not leaves
+ if (!leftChildIsLeaf) {
+ nodeStack.push((treeIndex, node.leftChild.get))
+ }
+ if (!rightChildIsLeaf) {
+ nodeStack.push((treeIndex, node.rightChild.get))
+ }
+
+ logDebug(s"leftChildIndex = ${node.leftChild.get.id}" +
+ s", impurity = ${stats.leftImpurity}")
+ logDebug(s"rightChildIndex = ${node.rightChild.get.id}" +
+ s", impurity = ${stats.rightImpurity}")
+ }
+ }
+ }
+
+ GradientBoostedTreesUtil.updateNodeIdCache(nodeIdCache, nodeIdCacheBc, input, nodesForGroup,
+ treeToNodeToIndexInfo, splits, rawPartInfoBc, metadata, timer)
+ }
+
+ /**
+ * Returns splits for decision tree calculation.
+ * Continuous and categorical features are handled differently.
+ *
+ * Continuous features:
+ * For each feature, there are numBins - 1 possible splits representing the possible binary
+ * decisions at each node in the tree.
+ * This finds locations (feature values) for splits using a subsample of the data.
+ *
+ * Categorical features:
+ * For each feature, there is 1 bin per split.
+ * Splits and bins are handled in 2 ways:
+ * (a) "unordered features"
+ * For multiclass classification with a low-arity feature
+ * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
+ * the feature is split based on subsets of categories.
+ * (b) "ordered features"
+ * For regression and binary classification,
+ * and for multiclass classification with a high-arity feature,
+ * there is one bin per category.
+ *
+ * @param input Training data: RDD of [[Instance]]
+ * @param metadata Learning and dataset metadata
+ * @param seed random seed
+ * @return Splits, an Array of [[Split]]
+ * of size (numFeatures, numSplits)
+ */
+ protected[tree] def findSplits(
+ input: RDD[Instance],
+ metadata: DecisionTreeMetadata,
+ seed: Long): Array[Array[Split]] = {
+
+ logDebug(s"isMulticlass = ${metadata.isMulticlass}")
+
+ val numFeatures = metadata.numFeatures
+
+ // Sample the input only if there are continuous features.
+ val continuousFeatures = Range(0, numFeatures).filter(metadata.isContinuous)
+ val sampledInput = if (continuousFeatures.nonEmpty) {
+ val fraction = samplesFractionForFindSplits(metadata)
+ logDebug(s"fraction of data used for calculating quantiles = $fraction")
+ if (fraction < 1) {
+ input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt())
+ } else {
+ input
+ }
+ } else {
+ input.sparkContext.emptyRDD[Instance]
+ }
+
+ findSplitsBySorting(sampledInput, metadata, continuousFeatures)
+ }
+
+ private def findSplitsBySorting(
+ input: RDD[Instance],
+ metadata: DecisionTreeMetadata,
+ continuousFeatures: IndexedSeq[Int]): Array[Array[Split]] = {
+
+ val continuousSplits = if (continuousFeatures.nonEmpty) {
+ // reduce the parallelism for split computations when there are less
+ // continuous features than input partitions. this prevents tasks from
+ // being spun up that will definitely do no work.
+ val numPartitions = math.min(continuousFeatures.length, input.partitions.length)
+
+ input.flatMap { point =>
+ continuousFeatures.iterator
+ .map(idx => (idx, (point.features(idx), point.weight)))
+ .filter(_._2._1 != 0.0)
+ }.aggregateByKey((new OpenHashMap[Double, Double], 0L), numPartitions)(
+ seqOp = { case ((map, c), (v, w)) =>
+ map.changeValue(v, w, _ + w)
+ (map, c + 1L)
+ },
+ combOp = { case ((map1, c1), (map2, c2)) =>
+ map2.foreach { case (v, w) =>
+ map1.changeValue(v, w, _ + w)
+ }
+ (map1, c1 + c2)
+ }
+ ).map { case (idx, (map, c)) =>
+ val thresholds = findSplitsForContinuousFeature(map.toMap, c, metadata, idx)
+ val splits: Array[Split] = thresholds.map(thresh => new ContinuousSplit(idx, thresh))
+ logDebug(s"featureIndex = $idx, numSplits = ${splits.length}")
+ (idx, splits)
+ }.collectAsMap()
+ } else Map.empty[Int, Array[Split]]
+
+ val numFeatures = metadata.numFeatures
+ val splits: Array[Array[Split]] = Array.tabulate(numFeatures) {
+ case i if metadata.isContinuous(i) =>
+ // some features may contain only zero, so continuousSplits will not have a record
+ val split = continuousSplits.getOrElse(i, Array.empty[Split])
+ metadata.setNumSplits(i, split.length)
+ split
+
+ // unreachable for GBDT
+ case i if metadata.isCategorical(i) && metadata.isUnordered(i) =>
+ // Unordered features
+ // 2^(maxFeatureValue - 1) - 1 combinations
+ val featureArity = metadata.featureArity(i)
+ Array.tabulate[Split](metadata.numSplits(i)) { splitIndex =>
+ val categories = extractMultiClassCategories(splitIndex + 1, featureArity)
+ new CategoricalSplit(i, categories.toArray, featureArity)
+ }
+
+ case i if metadata.isCategorical(i) =>
+ // Ordered features
+ // Splits are constructed as needed during training.
+ Array.empty[Split]
+ }
+ splits
+ }
+
+ /**
+ * Nested method to extract list of eligible categories given an index. It extracts the
+ * position of ones in a binary representation of the input. If binary
+ * representation of an number is 01101 (13), the output list should (3.0, 2.0,
+ * 0.0). The maxFeatureValue depict the number of rightmost digits that will be tested for ones.
+ */
+ private[tree] def extractMultiClassCategories(
+ input: Int,
+ maxFeatureValue: Int): List[Double] = {
+ var categories = List[Double]()
+ var j = 0
+ var bitShiftedInput = input
+ while (j < maxFeatureValue) {
+ if (bitShiftedInput % 2 != 0) {
+ // updating the list of categories.
+ categories = j.toDouble :: categories
+ }
+ // Right shift by one
+ bitShiftedInput = bitShiftedInput >> 1
+ j += 1
+ }
+ categories
+ }
+
+ /**
+ * Find splits for a continuous feature
+ * NOTE: Returned number of splits is set based on `featureSamples` and
+ * could be different from the specified `numSplits`.
+ * The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly.
+ *
+ * @param featureSamples feature values and sample weights of each sample
+ * @param metadata decision tree metadata
+ * NOTE: `metadata.numbins` will be changed accordingly
+ * if there are not enough splits to be found
+ * @param featureIndex feature index to find splits
+ * @return array of split thresholds
+ */
+ private[tree] def findSplitsForContinuousFeature(
+ featureSamples: Iterable[(Double, Double)],
+ metadata: DecisionTreeMetadata,
+ featureIndex: Int): Array[Double] = {
+ val valueWeights = new OpenHashMap[Double, Double]
+ var count = 0L
+ featureSamples.foreach { case (weight, value) =>
+ valueWeights.changeValue(value, weight, _ + weight)
+ count += 1L
+ }
+ findSplitsForContinuousFeature(valueWeights.toMap, count, metadata, featureIndex)
+ }
+
+ /**
+ * Find splits for a continuous feature
+ * NOTE: Returned number of splits is set based on `featureSamples` and
+ * could be different from the specified `numSplits`.
+ * The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly.
+ *
+ * @param partValueWeights non-zero distinct values and their weights
+ * @param metadata decision tree metadata
+ * NOTE: `metadata.numbins` will be changed accordingly
+ * if there are not enough splits to be found
+ * @param featureIndex feature index to find splits
+ * @return array of split thresholds
+ */
+ private[tree] def findSplitsForContinuousFeature(
+ partValueWeights: Map[Double, Double],
+ count: Long,
+ metadata: DecisionTreeMetadata,
+ featureIndex: Int): Array[Double] = {
+ require(metadata.isContinuous(featureIndex),
+ "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")
+
+ val splits = if (partValueWeights.isEmpty) {
+ Array.emptyDoubleArray
+ } else {
+ val numSplits = metadata.numSplits(featureIndex)
+
+ val partNumSamples = partValueWeights.values.sum
+
+ // Calculate the expected number of samples for finding splits
+ val weightedNumSamples = samplesFractionForFindSplits(metadata) *
+ metadata.weightedNumExamples
+ // scale tolerance by number of samples with constant factor
+ // Note: constant factor was tuned by running some tests where there were no zero
+ // feature values and validating we are never within tolerance
+ val tolerance = Utils.EPSILON * count * 100
+ // add expected zero value count and get complete statistics
+ val valueCountMap = if (weightedNumSamples - partNumSamples > tolerance) {
+ partValueWeights + (0.0 -> (weightedNumSamples - partNumSamples))
+ } else {
+ partValueWeights
+ }
+
+ // sort distinct values
+ val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
+
+ val possibleSplits = valueCounts.length - 1
+ if (possibleSplits == 0) {
+ // constant feature
+ Array.emptyDoubleArray
+ } else if (possibleSplits <= numSplits) {
+ // if possible splits is not enough or just enough, just return all possible splits
+ (1 to possibleSplits)
+ .map(index => (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0)
+ .toArray
+ } else {
+ // stride between splits
+ val stride: Double = weightedNumSamples / (numSplits + 1)
+ logDebug(s"stride = $stride")
+
+ // iterate `valueCount` to find splits
+ val splitsBuilder = mutable.ArrayBuilder.make[Double]
+ var index = 1
+ // currentCount: sum of counts of values that have been visited
+ var currentCount = valueCounts(0)._2
+ // targetCount: target value for `currentCount`.
+ // If `currentCount` is closest value to `targetCount`,
+ // then current value is a split threshold.
+ // After finding a split threshold, `targetCount` is added by stride.
+ var targetCount = stride
+ while (index < valueCounts.length) {
+ val previousCount = currentCount
+ currentCount += valueCounts(index)._2
+ val previousGap = math.abs(previousCount - targetCount)
+ val currentGap = math.abs(currentCount - targetCount)
+ // If adding count of current value to currentCount
+ // makes the gap between currentCount and targetCount smaller,
+ // previous value is a split threshold.
+ if (previousGap < currentGap) {
+ splitsBuilder += (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0
+ targetCount += stride
+ }
+ index += 1
+ }
+
+ splitsBuilder.result()
+ }
+ }
+ splits
+ }
+
+ /**
+ * Pull nodes off of the queue, and collect a group of nodes to be split on this iteration.
+ * This tracks the memory usage for aggregates and stops adding nodes when too much memory
+ * will be needed; this allows an adaptive number of nodes since different nodes may require
+ * different amounts of memory (if featureSubsetStrategy is not "all").
+ *
+ * @param nodeStack Queue of nodes to split.
+ * @param maxMemoryUsage Bound on size of aggregate statistics.
+ * @return (nodesForGroup, treeToNodeToIndexInfo).
+ * nodesForGroup holds the nodes to split: treeIndex --> nodes in tree.
+ *
+ * treeToNodeToIndexInfo holds indices selected features for each node:
+ * treeIndex --> (global) node index --> (node index in group, feature indices).
+ * The (global) node index is the index in the tree; the node index in group is the
+ * index in [0, numNodesInGroup) of the node in this group.
+ * The feature indices are None if not subsampling features.
+ */
+ private[tree] def selectNodesToSplitX(
+ nodeStack: mutable.ArrayStack[(Int, LearningNode)],
+ maxMemoryUsage: Long,
+ metadata: DecisionTreeMetadata,
+ rng: Random): (Map[Int, Array[LearningNode]], Map[Int, Map[Int, NodeIndexInfo]]) = {
+ // Collect some nodes to split:
+ // nodesForGroup(treeIndex) = nodes to split
+ val mutableNodesForGroup = new mutable.HashMap[Int, mutable.ArrayBuffer[LearningNode]]()
+ val mutableTreeToNodeToIndexInfo =
+ new mutable.HashMap[Int, mutable.HashMap[Int, NodeIndexInfo]]()
+ var memUsage: Long = 0L
+ var numNodesInGroup = 0
+ // If maxMemoryInMB is set very small, we want to still try to split 1 node,
+ // so we allow one iteration if memUsage == 0.
+ var groupDone = false
+ while (nodeStack.nonEmpty && !groupDone) {
+ val (treeIndex, node) = nodeStack.top
+ // Choose subset of features for node (if subsampling).
+ val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
+ Some(SamplingUtils.reservoirSampleAndCount(Range(0,
+ metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong())._1)
+ } else {
+ None
+ }
+ val featureSubsetHashSetX: Option[mutable.HashSet[Int]] = if (metadata.subsamplingFeatures) {
+ Some(scala.collection.mutable.HashSet(featureSubset.get: _*))
+ } else {
+ None
+ }
+ // Check if enough memory remains to add this node to the group.
+ val nodeMemUsage = RandomForest4GBDTX.aggregateSizeForNode(metadata, featureSubset) * 8L
+ if (memUsage + nodeMemUsage <= maxMemoryUsage || memUsage == 0) {
+ nodeStack.pop()
+ mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[LearningNode]()) +=
+ node
+ mutableTreeToNodeToIndexInfo
+ .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id)
+ = new NodeIndexInfo(numNodesInGroup, featureSubset, featureSubsetHashSetX)
+ numNodesInGroup += 1
+ memUsage += nodeMemUsage
+ } else {
+ groupDone = true
+ }
+ }
+ if (memUsage > maxMemoryUsage) {
+ // If maxMemoryUsage is 0, we should still allow splitting 1 node.
+ logWarning(s"Tree learning is using approximately $memUsage bytes per iteration, which" +
+ s" exceeds requested limit maxMemoryUsage=$maxMemoryUsage. This allows splitting" +
+ s" $numNodesInGroup nodes in this iteration.")
+ }
+ // Convert mutable maps to immutable ones.
+ val nodesForGroup: Map[Int, Array[LearningNode]] =
+ mutableNodesForGroup.mapValues(_.toArray).toMap
+ val treeToNodeToIndexInfo = mutableTreeToNodeToIndexInfo.mapValues(_.toMap).toMap
+ (nodesForGroup, treeToNodeToIndexInfo)
+ }
+
+ /**
+ * Get the number of values to be stored for this node in the bin aggregates.
+ *
+ * @param featureSubset Indices of features which may be split at this node.
+ * If None, then use all features.
+ */
+ private def aggregateSizeForNode(
+ metadata: DecisionTreeMetadata,
+ featureSubset: Option[Array[Int]]): Long = {
+ val totalBins = if (featureSubset.nonEmpty) {
+ featureSubset.get.map(featureIndex => metadata.numBins(featureIndex).toLong).sum
+ } else {
+ metadata.numBins.map(_.toLong).sum
+ }
+ if (metadata.isClassification) {
+ metadata.numClasses * totalBins
+ } else {
+ 3 * totalBins
+ }
+ }
+
+ /**
+ * Calculate the subsample fraction for finding splits
+ *
+ * @param metadata decision tree metadata
+ * @return subsample fraction
+ */
+ private def samplesFractionForFindSplits(
+ metadata: DecisionTreeMetadata): Double = {
+ // Calculate the number of samples for approximate quantile calculation.
+ val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
+ if (requiredSamples < metadata.numExamples) {
+ requiredSamples.toDouble / metadata.numExamples
+ } else {
+ 1.0
+ }
+ }
+}
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/impl/RandomForestRaw.scala b/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/impl/RandomForestRaw.scala
new file mode 100644
index 0000000..68a7699
--- /dev/null
+++ b/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/impl/RandomForestRaw.scala
@@ -0,0 +1,1337 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import scala.collection.mutable
+import scala.util.Random
+
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.classification.DecisionTreeClassificationModel
+import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.impl.Utils
+import org.apache.spark.ml.regression.DecisionTreeRegressionModel
+import org.apache.spark.ml.tree._
+import org.apache.spark.ml.util.Instrumentation
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
+import org.apache.spark.mllib.tree.model.ImpurityStats
+import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.util.PeriodicRDDCheckpointer
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.collection.OpenHashMap
+import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
+
+
+/**
+ * ALGORITHM
+ *
+ * This is a sketch of the algorithm to help new developers.
+ *
+ * The algorithm partitions data by instances (rows).
+ * On each iteration, the algorithm splits a set of nodes. In order to choose the best split
+ * for a given node, sufficient statistics are collected from the distributed data.
+ * For each node, the statistics are collected to some worker node, and that worker selects
+ * the best split.
+ *
+ * This setup requires discretization of continuous features. This binning is done in the
+ * findSplits() method during initialization, after which each continuous feature becomes
+ * an ordered discretized feature with at most maxBins possible values.
+ *
+ * The main loop in the algorithm operates on a queue of nodes (nodeStack). These nodes
+ * lie at the periphery of the tree being trained. If multiple trees are being trained at once,
+ * then this queue contains nodes from all of them. Each iteration works roughly as follows:
+ * On the master node:
+ * - Some number of nodes are pulled off of the queue (based on the amount of memory
+ * required for their sufficient statistics).
+ * - For random forests, if featureSubsetStrategy is not "all," then a subset of candidate
+ * features are chosen for each node. See method selectNodesToSplit().
+ * On worker nodes, via method findBestSplits():
+ * - The worker makes one pass over its subset of instances.
+ * - For each (tree, node, feature, split) tuple, the worker collects statistics about
+ * splitting. Note that the set of (tree, node) pairs is limited to the nodes selected
+ * from the queue for this iteration. The set of features considered can also be limited
+ * based on featureSubsetStrategy.
+ * - For each node, the statistics for that node are aggregated to a particular worker
+ * via reduceByKey(). The designated worker chooses the best (feature, split) pair,
+ * or chooses to stop splitting if the stopping criteria are met.
+ * On the master node:
+ * - The master collects all decisions about splitting nodes and updates the model.
+ * - The updated model is passed to the workers on the next iteration.
+ * This process continues until the node queue is empty.
+ *
+ * Most of the methods in this implementation support the statistics aggregation, which is
+ * the heaviest part of the computation. In general, this implementation is bound by either
+ * the cost of statistics computation on workers or by communicating the sufficient statistics.
+ */
+private[spark] object RandomForestRaw extends Logging with Serializable {
+
+ /**
+ * Train a random forest.
+ *
+ * @param input Training data: RDD of `LabeledPoint`
+ * @return an unweighted set of trees
+ */
+ def run(
+ input: RDD[LabeledPoint],
+ strategy: OldStrategy,
+ numTrees: Int,
+ featureSubsetStrategy: String,
+ seed: Long): Array[DecisionTreeModel] = {
+ val instances = input.map { case LabeledPoint(label, features) =>
+ Instance(label, 1.0, features.asML)
+ }
+ run(instances, strategy, numTrees, featureSubsetStrategy, seed, None)
+ }
+
+ /**
+ * Train a random forest with metadata and splits. This method is mainly for GBT,
+ * in which bagged input can be reused among trees.
+ *
+ * @param baggedInput bagged training data: RDD of `BaggedPoint`
+ * @param metadata Learning and dataset metadata for DecisionTree.
+ * @return an unweighted set of trees
+ */
+ def runBagged(
+ baggedInput: RDD[BaggedPoint[TreePoint]],
+ metadata: DecisionTreeMetadata,
+ bcSplits: Broadcast[Array[Array[Split]]],
+ strategy: OldStrategy,
+ numTrees: Int,
+ featureSubsetStrategy: String,
+ seed: Long,
+ instr: Option[Instrumentation],
+ prune: Boolean = true, // exposed for testing only, real trees are always pruned
+ parentUID: Option[String] = None): Array[DecisionTreeModel] = {
+ val timer = new TimeTracker()
+ timer.start("total")
+
+ val sc = baggedInput.sparkContext
+
+ instr match {
+ case Some(instrumentation) =>
+ instrumentation.logNumFeatures(metadata.numFeatures)
+ instrumentation.logNumClasses(metadata.numClasses)
+ instrumentation.logNumExamples(metadata.numExamples)
+ instrumentation.logSumOfWeights(metadata.weightedNumExamples)
+ case None =>
+ logInfo(s"numFeatures: ${metadata.numFeatures}")
+ logInfo(s"numClasses: ${metadata.numClasses}")
+ logInfo(s"numExamples: ${metadata.numExamples}")
+ logInfo(s"weightedNumExamples: ${metadata.weightedNumExamples}")
+ }
+
+ timer.start("init")
+
+ // depth of the decision tree
+ val maxDepth = strategy.maxDepth
+ require(maxDepth <= 30,
+ s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.")
+
+ // Max memory usage for aggregates
+ // TODO: Calculate memory usage more precisely.
+ val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L
+ logDebug(s"max memory usage for aggregates = $maxMemoryUsage bytes.")
+
+ /*
+ * The main idea here is to perform group-wise training of the decision tree nodes thus
+ * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup).
+ * Each data sample is handled by a particular node (or it reaches a leaf and is not used
+ * in lower levels).
+ */
+
+ var nodeIds: RDD[Array[Int]] = null
+ var nodeIdCheckpointer: PeriodicRDDCheckpointer[Array[Int]] = null
+ if (strategy.useNodeIdCache) {
+ // Create an RDD of node Id cache.
+ // At first, all the rows belong to the root nodes (node Id == 1).
+ nodeIds = baggedInput.map { _ => Array.fill(numTrees)(1) }
+ nodeIdCheckpointer = new PeriodicRDDCheckpointer[Array[Int]](
+ strategy.getCheckpointInterval, sc, StorageLevel.MEMORY_AND_DISK)
+ nodeIdCheckpointer.update(nodeIds)
+ }
+
+ /*
+ Stack of nodes to train: (treeIndex, node)
+ The reason this is a stack is that we train many trees at once, but we want to focus on
+ completing trees, rather than training all simultaneously. If we are splitting nodes from
+ 1 tree, then the new nodes to split will be put at the top of this stack, so we will continue
+ training the same tree in the next iteration. This focus allows us to send fewer trees to
+ workers on each iteration; see topNodesForGroup below.
+ */
+ val nodeStack = new mutable.ListBuffer[(Int, LearningNode)]
+
+ val rng = new Random()
+ rng.setSeed(seed)
+
+ // Allocate and queue root nodes.
+ val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1))
+ for (treeIndex <- 0 until numTrees) {
+ nodeStack.prepend((treeIndex, topNodes(treeIndex)))
+ }
+
+ timer.stop("init")
+
+ while (nodeStack.nonEmpty) {
+ // Collect some nodes to split, and choose features for each node (if subsampling).
+ // Each group of nodes may come from one or multiple trees, and at multiple levels.
+ val (nodesForGroup, treeToNodeToIndexInfo) =
+ RandomForestRaw.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng)
+ // Sanity check (should never occur):
+ assert(nodesForGroup.nonEmpty,
+ s"RandomForest selected empty nodesForGroup. Error for unknown reason.")
+
+ // Only send trees to worker if they contain nodes being split this iteration.
+ val topNodesForGroup: Map[Int, LearningNode] =
+ nodesForGroup.keys.map(treeIdx => treeIdx -> topNodes(treeIdx)).toMap
+
+ // Choose node splits, and enqueue new nodes as needed.
+ timer.start("findBestSplits")
+ val bestSplit = RandomForestRaw.findBestSplits(baggedInput, metadata, topNodesForGroup,
+ nodesForGroup, treeToNodeToIndexInfo, bcSplits, nodeStack, timer, nodeIds,
+ outputBestSplits = strategy.useNodeIdCache)
+ if (strategy.useNodeIdCache) {
+ nodeIds = updateNodeIds(baggedInput, nodeIds, bcSplits, bestSplit)
+ nodeIdCheckpointer.update(nodeIds)
+ }
+
+ timer.stop("findBestSplits")
+ }
+
+ timer.stop("total")
+
+ logInfo("Internal timing for DecisionTree:")
+ logInfo(s"$timer")
+
+ if (strategy.useNodeIdCache) {
+ // Delete any remaining checkpoints used for node Id cache.
+ nodeIdCheckpointer.unpersistDataSet()
+ nodeIdCheckpointer.deleteAllCheckpoints()
+ }
+
+ val numFeatures = metadata.numFeatures
+
+ parentUID match {
+ case Some(uid) =>
+ if (strategy.algo == OldAlgo.Classification) {
+ topNodes.map { rootNode =>
+ new DecisionTreeClassificationModel(uid, rootNode.toNode(prune), numFeatures,
+ strategy.getNumClasses)
+ }
+ } else {
+ topNodes.map { rootNode =>
+ new DecisionTreeRegressionModel(uid, rootNode.toNode(prune), numFeatures)
+ }
+ }
+ case None =>
+ if (strategy.algo == OldAlgo.Classification) {
+ topNodes.map { rootNode =>
+ new DecisionTreeClassificationModel(rootNode.toNode(prune), numFeatures,
+ strategy.getNumClasses)
+ }
+ } else {
+ topNodes.map(rootNode =>
+ new DecisionTreeRegressionModel(rootNode.toNode(prune), numFeatures))
+ }
+ }
+ }
+
+ /**
+ * Train a random forest.
+ *
+ * @param input Training data: RDD of `Instance`
+ * @return an unweighted set of trees
+ */
+ def run(
+ input: RDD[Instance],
+ strategy: OldStrategy,
+ numTrees: Int,
+ featureSubsetStrategy: String,
+ seed: Long,
+ instr: Option[Instrumentation],
+ prune: Boolean = true, // exposed for testing only, real trees are always pruned
+ parentUID: Option[String] = None): Array[DecisionTreeModel] = {
+ val timer = new TimeTracker()
+
+ timer.start("build metadata")
+ val metadata = DecisionTreeMetadata
+ .buildMetadata(input.retag(classOf[Instance]), strategy, numTrees, featureSubsetStrategy)
+ timer.stop("build metadata")
+
+ val retaggedInput = input.retag(classOf[Instance])
+
+ // Find the splits and the corresponding bins (interval between the splits) using a sample
+ // of the input data.
+ timer.start("findSplits")
+ val splits = findSplits(retaggedInput, metadata, seed)
+ timer.stop("findSplits")
+ logDebug("numBins: feature: number of bins")
+ logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
+ s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
+ }.mkString("\n"))
+
+ // Bin feature values (TreePoint representation).
+ // Cache input RDD for speedup during multiple passes.
+ val treeInput = TreePoint.convertToTreeRDD(retaggedInput, splits, metadata)
+
+ val bcSplits = input.sparkContext.broadcast(splits)
+ val baggedInput = BaggedPoint
+ .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, strategy.bootstrap,
+ (tp: TreePoint) => tp.weight, seed = seed)
+ .persist(StorageLevel.MEMORY_AND_DISK)
+ .setName("bagged tree points")
+
+ val trees = runBagged(baggedInput = baggedInput, metadata = metadata, bcSplits = bcSplits,
+ strategy = strategy, numTrees = numTrees, featureSubsetStrategy = featureSubsetStrategy,
+ seed = seed, instr = instr, prune = prune, parentUID = parentUID)
+
+ baggedInput.unpersist()
+ bcSplits.destroy()
+
+ trees
+ }
+
+ /**
+ * Update node indices by newly found splits.
+ */
+ private def updateNodeIds(
+ input: RDD[BaggedPoint[TreePoint]],
+ nodeIds: RDD[Array[Int]],
+ bcSplits: Broadcast[Array[Array[Split]]],
+ bestSplits: Array[Map[Int, Split]]): RDD[Array[Int]] = {
+ require(nodeIds != null && bestSplits != null)
+ input.zip(nodeIds).map { case (point, ids) =>
+ var treeId = 0
+ while (treeId < bestSplits.length) {
+ val bestSplitsInTree = bestSplits(treeId)
+ if (bestSplitsInTree != null) {
+ val nodeId = ids(treeId)
+ bestSplitsInTree.get(nodeId).foreach { bestSplit =>
+ val featureId = bestSplit.featureIndex
+ val bin = point.datum.binnedFeatures(featureId)
+ val newNodeId = if (bestSplit.shouldGoLeft(bin, bcSplits.value(featureId))) {
+ LearningNode.leftChildIndex(nodeId)
+ } else {
+ LearningNode.rightChildIndex(nodeId)
+ }
+ ids(treeId) = newNodeId
+ }
+ }
+ treeId += 1
+ }
+ ids
+ }
+ }
+
+ /**
+ * Helper for binSeqOp, for data which can contain a mix of ordered and unordered features.
+ *
+ * For ordered features, a single bin is updated.
+ * For unordered features, bins correspond to subsets of categories; either the left or right bin
+ * for each subset is updated.
+ *
+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
+ * each (feature, bin).
+ * @param treePoint Data point being aggregated.
+ * @param splits Possible splits indexed (numFeatures)(numSplits)
+ * @param unorderedFeatures Set of indices of unordered features.
+ * @param numSamples Number of times this instance occurs in the sample.
+ * @param sampleWeight Weight (importance) of instance in dataset.
+ */
+ private def mixedBinSeqOp(
+ agg: DTStatsAggregator,
+ treePoint: TreePoint,
+ splits: Array[Array[Split]],
+ unorderedFeatures: Set[Int],
+ numSamples: Int,
+ sampleWeight: Double,
+ featuresForNode: Option[Array[Int]]): Unit = {
+ val numFeaturesPerNode = if (featuresForNode.nonEmpty) {
+ // Use subsampled features
+ featuresForNode.get.length
+ } else {
+ // Use all features
+ agg.metadata.numFeatures
+ }
+ // Iterate over features.
+ var featureIndexIdx = 0
+ while (featureIndexIdx < numFeaturesPerNode) {
+ val featureIndex = if (featuresForNode.nonEmpty) {
+ featuresForNode.get.apply(featureIndexIdx)
+ } else {
+ featureIndexIdx
+ }
+ if (unorderedFeatures.contains(featureIndex)) {
+ // Unordered feature
+ val featureValue = treePoint.binnedFeatures(featureIndex)
+ val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx)
+ // Update the left or right bin for each split.
+ val numSplits = agg.metadata.numSplits(featureIndex)
+ val featureSplits = splits(featureIndex)
+ var splitIndex = 0
+ while (splitIndex < numSplits) {
+ if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) {
+ agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, numSamples,
+ sampleWeight)
+ }
+ splitIndex += 1
+ }
+ } else {
+ // Ordered feature
+ val binIndex = treePoint.binnedFeatures(featureIndex)
+ agg.update(featureIndexIdx, binIndex, treePoint.label, numSamples, sampleWeight)
+ }
+ featureIndexIdx += 1
+ }
+ }
+
+ /**
+ * Helper for binSeqOp, for regression and for classification with only ordered features.
+ *
+ * For each feature, the sufficient statistics of one bin are updated.
+ *
+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
+ * each (feature, bin).
+ * @param treePoint Data point being aggregated.
+ * @param numSamples Number of times this instance occurs in the sample.
+ * @param sampleWeight Weight (importance) of instance in dataset.
+ */
+ private def orderedBinSeqOp(
+ agg: DTStatsAggregator,
+ treePoint: TreePoint,
+ numSamples: Int,
+ sampleWeight: Double,
+ featuresForNode: Option[Array[Int]]): Unit = {
+ val label = treePoint.label
+
+ // Iterate over features.
+ if (featuresForNode.nonEmpty) {
+ // Use subsampled features
+ var featureIndexIdx = 0
+ while (featureIndexIdx < featuresForNode.get.length) {
+ val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx))
+ agg.update(featureIndexIdx, binIndex, label, numSamples, sampleWeight)
+ featureIndexIdx += 1
+ }
+ } else {
+ // Use all features
+ val numFeatures = agg.metadata.numFeatures
+ var featureIndex = 0
+ while (featureIndex < numFeatures) {
+ val binIndex = treePoint.binnedFeatures(featureIndex)
+ agg.update(featureIndex, binIndex, label, numSamples, sampleWeight)
+ featureIndex += 1
+ }
+ }
+ }
+
+ /**
+ * Given a group of nodes, this finds the best split for each node.
+ *
+ * @param input Training data: RDD of [[TreePoint]]
+ * @param metadata Learning and dataset metadata
+ * @param topNodesForGroup For each tree in group, tree index -> root node.
+ * Used for matching instances with nodes.
+ * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree
+ * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo,
+ * where nodeIndexInfo stores the index in the group and the
+ * feature subsets (if using feature subsets).
+ * @param bcSplits possible splits for all features, indexed (numFeatures)(numSplits)
+ * @param nodeStack Queue of nodes to split, with values (treeIndex, node).
+ * Updated with new non-leaf nodes which are created.
+ * @param nodeIds an RDD of Array[Int] where each value in the array is the data
+ * point's node Id for a corresponding tree. This is used to prevent
+ * the need to pass the entire tree to the executors during the node
+ * stat aggregation phase.
+ */
+ private[tree] def findBestSplits(
+ input: RDD[BaggedPoint[TreePoint]],
+ metadata: DecisionTreeMetadata,
+ topNodesForGroup: Map[Int, LearningNode],
+ nodesForGroup: Map[Int, Array[LearningNode]],
+ treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]],
+ bcSplits: Broadcast[Array[Array[Split]]],
+ nodeStack: mutable.ListBuffer[(Int, LearningNode)],
+ timer: TimeTracker = new TimeTracker,
+ nodeIds: RDD[Array[Int]] = null,
+ outputBestSplits: Boolean = false): Array[Map[Int, Split]] = {
+
+ /*
+ * The high-level descriptions of the best split optimizations are noted here.
+ *
+ * *Group-wise training*
+ * We perform bin calculations for groups of nodes to reduce the number of
+ * passes over the data. Each iteration requires more computation and storage,
+ * but saves several iterations over the data.
+ *
+ * *Bin-wise computation*
+ * We use a bin-wise best split computation strategy instead of a straightforward best split
+ * computation strategy. Instead of analyzing each sample for contribution to the left/right
+ * child node impurity of every split, we first categorize each feature of a sample into a
+ * bin. We exploit this structure to calculate aggregates for bins and then use these aggregates
+ * to calculate information gain for each split.
+ *
+ * *Aggregation over partitions*
+ * Instead of performing a flatMap/reduceByKey operation, we exploit the fact that we know
+ * the number of splits in advance. Thus, we store the aggregates (at the appropriate
+ * indices) in a single array for all bins and rely upon the RDD aggregate method to
+ * drastically reduce the communication overhead.
+ */
+
+ val useNodeIdCache = nodeIds != null
+
+ // numNodes: Number of nodes in this group
+ val numNodes = nodesForGroup.values.map(_.length).sum
+ logDebug(s"numNodes = $numNodes")
+ logDebug(s"numFeatures = ${metadata.numFeatures}")
+ logDebug(s"numClasses = ${metadata.numClasses}")
+ logDebug(s"isMulticlass = ${metadata.isMulticlass}")
+ logDebug(s"isMulticlassWithCategoricalFeatures = " +
+ s"${metadata.isMulticlassWithCategoricalFeatures}")
+ logDebug(s"using nodeIdCache = $useNodeIdCache")
+
+ /*
+ * Performs a sequential aggregation over a partition for a particular tree and node.
+ *
+ * For each feature, the aggregate sufficient statistics are updated for the relevant
+ * bins.
+ *
+ * @param treeIndex Index of the tree that we want to perform aggregation for.
+ * @param nodeInfo The node info for the tree node.
+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics
+ * for each (node, feature, bin).
+ * @param baggedPoint Data point being aggregated.
+ */
+ def nodeBinSeqOp(
+ treeIndex: Int,
+ nodeInfo: NodeIndexInfo,
+ agg: Array[DTStatsAggregator],
+ baggedPoint: BaggedPoint[TreePoint],
+ splits: Array[Array[Split]]): Unit = {
+ if (nodeInfo != null) {
+ val aggNodeIndex = nodeInfo.nodeIndexInGroup
+ val featuresForNode = nodeInfo.featureSubset
+ val numSamples = baggedPoint.subsampleCounts(treeIndex)
+ val sampleWeight = baggedPoint.sampleWeight
+ if (metadata.unorderedFeatures.isEmpty) {
+ orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, numSamples, sampleWeight,
+ featuresForNode)
+ } else {
+ mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
+ metadata.unorderedFeatures, numSamples, sampleWeight, featuresForNode)
+ }
+ agg(aggNodeIndex).updateParent(baggedPoint.datum.label, numSamples, sampleWeight)
+ }
+ }
+
+ /*
+ * Performs a sequential aggregation over a partition.
+ *
+ * Each data point contributes to one node. For each feature,
+ * the aggregate sufficient statistics are updated for the relevant bins.
+ *
+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
+ * each (node, feature, bin).
+ * @param baggedPoint Data point being aggregated.
+ * @return agg
+ */
+ def binSeqOp(
+ agg: Array[DTStatsAggregator],
+ baggedPoint: BaggedPoint[TreePoint],
+ splits: Array[Array[Split]]): Array[DTStatsAggregator] = {
+ treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
+ val nodeIndex =
+ topNodesForGroup(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits)
+ nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null),
+ agg, baggedPoint, splits)
+ }
+ agg
+ }
+
+ /**
+ * Do the same thing as binSeqOp, but with nodeIdCache.
+ */
+ def binSeqOpWithNodeIdCache(
+ agg: Array[DTStatsAggregator],
+ dataPoint: (BaggedPoint[TreePoint], Array[Int]),
+ splits: Array[Array[Split]]): Array[DTStatsAggregator] = {
+ treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
+ val baggedPoint = dataPoint._1
+ val nodeIdCache = dataPoint._2
+ val nodeIndex = nodeIdCache(treeIndex)
+ nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null),
+ agg, baggedPoint, splits)
+ }
+ agg
+ }
+
+ /**
+ * Get node index in group --> features indices map,
+ * which is a short cut to find feature indices for a node given node index in group.
+ */
+ def getNodeToFeatures(
+ treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]): Option[Map[Int, Array[Int]]] = {
+ if (!metadata.subsamplingFeatures) {
+ None
+ } else {
+ val mutableNodeToFeatures = new mutable.HashMap[Int, Array[Int]]()
+ treeToNodeToIndexInfo.values.foreach { nodeIdToNodeInfo =>
+ nodeIdToNodeInfo.values.foreach { nodeIndexInfo =>
+ assert(nodeIndexInfo.featureSubset.isDefined)
+ mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) = nodeIndexInfo.featureSubset.get
+ }
+ }
+ Some(mutableNodeToFeatures.toMap)
+ }
+ }
+
+ // array of nodes to train indexed by node index in group
+ val nodes = new Array[LearningNode](numNodes)
+ nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
+ nodesForTree.foreach { node =>
+ nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node
+ }
+ }
+
+ // Calculate best splits for all nodes in the group
+ timer.start("chooseSplits")
+
+ // In each partition, iterate all instances and compute aggregate stats for each node,
+ // yield a (nodeIndex, nodeAggregateStats) pair for each node.
+ // After a `reduceByKey` operation,
+ // stats of a node will be shuffled to a particular partition and be combined together,
+ // then best splits for nodes are found there.
+ // Finally, only best Splits for nodes are collected to driver to construct decision tree.
+ val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
+ val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)
+
+ val partitionAggregates = if (useNodeIdCache) {
+
+ input.zip(nodeIds).mapPartitions { points =>
+ // Construct a nodeStatsAggregators array to hold node aggregate stats,
+ // each node will have a nodeStatsAggregator
+ val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
+ val featuresForNode = nodeToFeaturesBc.value.map { nodeToFeatures =>
+ nodeToFeatures(nodeIndex)
+ }
+ new DTStatsAggregator(metadata, featuresForNode)
+ }
+
+ // iterator all instances in current partition and update aggregate stats
+ points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _, bcSplits.value))
+
+ // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
+ // which can be combined with other partition using `reduceByKey`
+ nodeStatsAggregators.iterator.zipWithIndex.map(_.swap)
+ }
+ } else {
+ input.mapPartitions { points =>
+ // Construct a nodeStatsAggregators array to hold node aggregate stats,
+ // each node will have a nodeStatsAggregator
+ val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
+ val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
+ Some(nodeToFeatures(nodeIndex))
+ }
+ new DTStatsAggregator(metadata, featuresForNode)
+ }
+
+ // iterator all instances in current partition and update aggregate stats
+ points.foreach(binSeqOp(nodeStatsAggregators, _, bcSplits.value))
+
+ // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
+ // which can be combined with other partition using `reduceByKey`
+ nodeStatsAggregators.iterator.zipWithIndex.map(_.swap)
+ }
+ }
+
+ val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b)).map {
+ case (nodeIndex, aggStats) =>
+ val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
+ Some(nodeToFeatures(nodeIndex))
+ }
+
+ // find best split for each node
+ val (split: Split, stats: ImpurityStats) =
+ binsToBestSplit(aggStats, bcSplits.value, featuresForNode, nodes(nodeIndex))
+ (nodeIndex, (split, stats))
+ }.collectAsMap()
+ nodeToFeaturesBc.destroy()
+
+ timer.stop("chooseSplits")
+
+ val bestSplits = if (outputBestSplits) {
+ Array.ofDim[mutable.Map[Int, Split]](metadata.numTrees)
+ } else {
+ null
+ }
+
+ // Iterate over all nodes in this group.
+ nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
+ nodesForTree.foreach { node =>
+ val nodeIndex = node.id
+ val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex)
+ val aggNodeIndex = nodeInfo.nodeIndexInGroup
+ val (split: Split, stats: ImpurityStats) =
+ nodeToBestSplits(aggNodeIndex)
+ logDebug(s"best split = $split")
+
+ // Extract info for this node. Create children if not leaf.
+ val isLeaf =
+ (stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) == metadata.maxDepth)
+ node.isLeaf = isLeaf
+ node.stats = stats
+ logDebug(s"Node = $node")
+
+ if (!isLeaf) {
+ node.split = Some(split)
+ val childIsLeaf = (LearningNode.indexToLevel(nodeIndex) + 1) == metadata.maxDepth
+ val leftChildIsLeaf = childIsLeaf || (math.abs(stats.leftImpurity) < Utils.EPSILON)
+ val rightChildIsLeaf = childIsLeaf || (math.abs(stats.rightImpurity) < Utils.EPSILON)
+ node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex),
+ leftChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator)))
+ node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex),
+ rightChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.rightImpurityCalculator)))
+
+ if (outputBestSplits) {
+ val bestSplitsInTree = bestSplits(treeIndex)
+ if (bestSplitsInTree == null) {
+ bestSplits(treeIndex) = mutable.Map[Int, Split](nodeIndex -> split)
+ } else {
+ bestSplitsInTree.update(nodeIndex, split)
+ }
+ }
+
+ // enqueue left child and right child if they are not leaves
+ if (!leftChildIsLeaf) {
+ nodeStack.prepend((treeIndex, node.leftChild.get))
+ }
+ if (!rightChildIsLeaf) {
+ nodeStack.prepend((treeIndex, node.rightChild.get))
+ }
+
+ logDebug(s"leftChildIndex = ${node.leftChild.get.id}" +
+ s", impurity = ${stats.leftImpurity}")
+ logDebug(s"rightChildIndex = ${node.rightChild.get.id}" +
+ s", impurity = ${stats.rightImpurity}")
+ }
+ }
+ }
+
+ if (outputBestSplits) {
+ bestSplits.map { m => if (m == null) null else m.toMap }
+ } else {
+ null
+ }
+ }
+
+ /**
+ * Calculate the impurity statistics for a given (feature, split) based upon left/right
+ * aggregates.
+ *
+ * @param stats the recycle impurity statistics for this feature's all splits,
+ * only 'impurity' and 'impurityCalculator' are valid between each iteration
+ * @param leftImpurityCalculator left node aggregates for this (feature, split)
+ * @param rightImpurityCalculator right node aggregate for this (feature, split)
+ * @param metadata learning and dataset metadata for DecisionTree
+ * @return Impurity statistics for this (feature, split)
+ */
+ private def calculateImpurityStats(
+ stats: ImpurityStats,
+ leftImpurityCalculator: ImpurityCalculator,
+ rightImpurityCalculator: ImpurityCalculator,
+ metadata: DecisionTreeMetadata): ImpurityStats = {
+
+ val parentImpurityCalculator: ImpurityCalculator = if (stats == null) {
+ leftImpurityCalculator.copy.add(rightImpurityCalculator)
+ } else {
+ stats.impurityCalculator
+ }
+
+ val impurity: Double = if (stats == null) {
+ parentImpurityCalculator.calculate()
+ } else {
+ stats.impurity
+ }
+
+ val leftRawCount = leftImpurityCalculator.rawCount
+ val rightRawCount = rightImpurityCalculator.rawCount
+ val leftCount = leftImpurityCalculator.count
+ val rightCount = rightImpurityCalculator.count
+
+ val totalCount = leftCount + rightCount
+
+ val violatesMinInstancesPerNode = (leftRawCount < metadata.minInstancesPerNode) ||
+ (rightRawCount < metadata.minInstancesPerNode)
+ val violatesMinWeightPerNode = (leftCount < metadata.minWeightPerNode) ||
+ (rightCount < metadata.minWeightPerNode)
+ // If left child or right child doesn't satisfy minimum weight per node or minimum
+ // instances per node, then this split is invalid, return invalid information gain stats.
+ if (violatesMinInstancesPerNode || violatesMinWeightPerNode) {
+ return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
+ }
+
+ val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
+ val rightImpurity = rightImpurityCalculator.calculate()
+
+ val leftWeight = leftCount / totalCount.toDouble
+ val rightWeight = rightCount / totalCount.toDouble
+
+ val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
+
+ // if information gain doesn't satisfy minimum information gain,
+ // then this split is invalid, return invalid information gain stats.
+ if (gain < metadata.minInfoGain) {
+ return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
+ }
+
+ new ImpurityStats(gain, impurity, parentImpurityCalculator,
+ leftImpurityCalculator, rightImpurityCalculator)
+ }
+
+ /**
+ * Find the best split for a node.
+ *
+ * @param binAggregates Bin statistics.
+ * @return tuple for best split: (Split, information gain, prediction at node)
+ */
+ private[tree] def binsToBestSplit(
+ binAggregates: DTStatsAggregator,
+ splits: Array[Array[Split]],
+ featuresForNode: Option[Array[Int]],
+ node: LearningNode): (Split, ImpurityStats) = {
+
+ // Calculate InformationGain and ImpurityStats if current node is top node
+ val level = LearningNode.indexToLevel(node.id)
+ var gainAndImpurityStats: ImpurityStats = if (level == 0) {
+ null
+ } else {
+ node.stats
+ }
+
+ val validFeatureSplits =
+ Iterator.range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
+ featuresForNode.map(features => (featureIndexIdx, features(featureIndexIdx)))
+ .getOrElse((featureIndexIdx, featureIndexIdx))
+ }.withFilter { case (_, featureIndex) =>
+ binAggregates.metadata.numSplits(featureIndex) != 0
+ }
+
+ // For each (feature, split), calculate the gain, and select the best (feature, split).
+ val splitsAndImpurityInfo =
+ validFeatureSplits.map { case (featureIndexIdx, featureIndex) =>
+ val numSplits = binAggregates.metadata.numSplits(featureIndex)
+ if (binAggregates.metadata.isContinuous(featureIndex)) {
+ // Cumulative sum (scanLeft) of bin statistics.
+ // Afterwards, binAggregates for a bin is the sum of aggregates for
+ // that bin + all preceding bins.
+ val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
+ var splitIndex = 0
+ while (splitIndex < numSplits) {
+ binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
+ splitIndex += 1
+ }
+ // Find best split.
+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
+ Range(0, numSplits).map { splitIdx =>
+ val leftChildStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
+ val rightChildStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
+ rightChildStats.subtract(leftChildStats)
+ gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
+ leftChildStats, rightChildStats, binAggregates.metadata)
+ (splitIdx, gainAndImpurityStats)
+ }.maxBy(_._2.gain)
+ (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
+ } else if (binAggregates.metadata.isUnordered(featureIndex)) {
+ // Unordered categorical feature
+ val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx)
+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
+ Range(0, numSplits).map { splitIndex =>
+ val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
+ val rightChildStats = binAggregates.getParentImpurityCalculator()
+ .subtract(leftChildStats)
+ gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
+ leftChildStats, rightChildStats, binAggregates.metadata)
+ (splitIndex, gainAndImpurityStats)
+ }.maxBy(_._2.gain)
+ (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
+ } else {
+ // Ordered categorical feature
+ val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
+ val numCategories = binAggregates.metadata.numBins(featureIndex)
+
+ /* Each bin is one category (feature value).
+ * The bins are ordered based on centroidForCategories, and this ordering determines which
+ * splits are considered. (With K categories, we consider K - 1 possible splits.)
+ *
+ * centroidForCategories is a list: (category, centroid)
+ */
+ val centroidForCategories = Range(0, numCategories).map { featureValue =>
+ val categoryStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+ val centroid = if (categoryStats.count != 0) {
+ if (binAggregates.metadata.isMulticlass) {
+ // multiclass classification
+ // For categorical variables in multiclass classification,
+ // the bins are ordered by the impurity of their corresponding labels.
+ categoryStats.calculate()
+ } else if (binAggregates.metadata.isClassification) {
+ // binary classification
+ // For categorical variables in binary classification,
+ // the bins are ordered by the count of class 1.
+ categoryStats.stats(1)
+ } else {
+ // regression
+ // For categorical variables in regression and binary classification,
+ // the bins are ordered by the prediction.
+ categoryStats.predict
+ }
+ } else {
+ Double.MaxValue
+ }
+ (featureValue, centroid)
+ }
+
+ logDebug(s"Centroids for categorical variable: " +
+ s"${centroidForCategories.mkString(",")}")
+
+ // bins sorted by centroids
+ val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
+
+ logDebug(s"Sorted centroids for categorical variable = " +
+ s"${categoriesSortedByCentroid.mkString(",")}")
+
+ // Cumulative sum (scanLeft) of bin statistics.
+ // Afterwards, binAggregates for a bin is the sum of aggregates for
+ // that bin + all preceding bins.
+ var splitIndex = 0
+ while (splitIndex < numSplits) {
+ val currentCategory = categoriesSortedByCentroid(splitIndex)._1
+ val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
+ binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
+ splitIndex += 1
+ }
+ // lastCategory = index of bin with total aggregates for this (node, feature)
+ val lastCategory = categoriesSortedByCentroid.last._1
+ // Find best split.
+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
+ Range(0, numSplits).map { splitIndex =>
+ val featureValue = categoriesSortedByCentroid(splitIndex)._1
+ val leftChildStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+ val rightChildStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
+ rightChildStats.subtract(leftChildStats)
+ gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
+ leftChildStats, rightChildStats, binAggregates.metadata)
+ (splitIndex, gainAndImpurityStats)
+ }.maxBy(_._2.gain)
+ val categoriesForSplit =
+ categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
+ val bestFeatureSplit =
+ new CategoricalSplit(featureIndex, categoriesForSplit.toArray, numCategories)
+ (bestFeatureSplit, bestFeatureGainStats)
+ }
+ }
+
+ val (bestSplit, bestSplitStats) =
+ if (splitsAndImpurityInfo.isEmpty) {
+ // If no valid splits for features, then this split is invalid,
+ // return invalid information gain stats. Take any split and continue.
+ // Splits is empty, so arbitrarily choose to split on any threshold
+ val dummyFeatureIndex = featuresForNode.map(_.head).getOrElse(0)
+ val parentImpurityCalculator = binAggregates.getParentImpurityCalculator()
+ if (binAggregates.metadata.isContinuous(dummyFeatureIndex)) {
+ (new ContinuousSplit(dummyFeatureIndex, 0),
+ ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator))
+ } else {
+ val numCategories = binAggregates.metadata.featureArity(dummyFeatureIndex)
+ (new CategoricalSplit(dummyFeatureIndex, Array(), numCategories),
+ ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator))
+ }
+ } else {
+ splitsAndImpurityInfo.maxBy(_._2.gain)
+ }
+ (bestSplit, bestSplitStats)
+ }
+
+ /**
+ * Returns splits for decision tree calculation.
+ * Continuous and categorical features are handled differently.
+ *
+ * Continuous features:
+ * For each feature, there are numBins - 1 possible splits representing the possible binary
+ * decisions at each node in the tree.
+ * This finds locations (feature values) for splits using a subsample of the data.
+ *
+ * Categorical features:
+ * For each feature, there is 1 bin per split.
+ * Splits and bins are handled in 2 ways:
+ * (a) "unordered features"
+ * For multiclass classification with a low-arity feature
+ * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
+ * the feature is split based on subsets of categories.
+ * (b) "ordered features"
+ * For regression and binary classification,
+ * and for multiclass classification with a high-arity feature,
+ * there is one bin per category.
+ *
+ * @param input Training data: RDD of [[Instance]]
+ * @param metadata Learning and dataset metadata
+ * @param seed random seed
+ * @return Splits, an Array of [[Split]]
+ * of size (numFeatures, numSplits)
+ */
+ protected[tree] def findSplits(
+ input: RDD[Instance],
+ metadata: DecisionTreeMetadata,
+ seed: Long): Array[Array[Split]] = {
+
+ logDebug(s"isMulticlass = ${metadata.isMulticlass}")
+
+ val numFeatures = metadata.numFeatures
+
+ // Sample the input only if there are continuous features.
+ val continuousFeatures = Range(0, numFeatures).filter(metadata.isContinuous)
+ val sampledInput = if (continuousFeatures.nonEmpty) {
+ val fraction = samplesFractionForFindSplits(metadata)
+ logDebug(s"fraction of data used for calculating quantiles = $fraction")
+ if (fraction < 1) {
+ input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt())
+ } else {
+ input
+ }
+ } else {
+ input.sparkContext.emptyRDD[Instance]
+ }
+
+ findSplitsBySorting(sampledInput, metadata, continuousFeatures)
+ }
+
+ private def findSplitsBySorting(
+ input: RDD[Instance],
+ metadata: DecisionTreeMetadata,
+ continuousFeatures: IndexedSeq[Int]): Array[Array[Split]] = {
+
+ val continuousSplits = if (continuousFeatures.nonEmpty) {
+ // reduce the parallelism for split computations when there are less
+ // continuous features than input partitions. this prevents tasks from
+ // being spun up that will definitely do no work.
+ val numPartitions = math.min(continuousFeatures.length, input.partitions.length)
+
+ input.flatMap { point =>
+ continuousFeatures.iterator
+ .map(idx => (idx, (point.features(idx), point.weight)))
+ .filter(_._2._1 != 0.0)
+ }.aggregateByKey((new OpenHashMap[Double, Double], 0L), numPartitions)(
+ seqOp = { case ((map, c), (v, w)) =>
+ map.changeValue(v, w, _ + w)
+ (map, c + 1L)
+ },
+ combOp = { case ((map1, c1), (map2, c2)) =>
+ map2.foreach { case (v, w) =>
+ map1.changeValue(v, w, _ + w)
+ }
+ (map1, c1 + c2)
+ }
+ ).map { case (idx, (map, c)) =>
+ val thresholds = findSplitsForContinuousFeature(map.toMap, c, metadata, idx)
+ val splits: Array[Split] = thresholds.map(thresh => new ContinuousSplit(idx, thresh))
+ logDebug(s"featureIndex = $idx, numSplits = ${splits.length}")
+ (idx, splits)
+ }.collectAsMap()
+ } else Map.empty[Int, Array[Split]]
+
+ val numFeatures = metadata.numFeatures
+ val splits: Array[Array[Split]] = Array.tabulate(numFeatures) {
+ case i if metadata.isContinuous(i) =>
+ // some features may contain only zero, so continuousSplits will not have a record
+ val split = continuousSplits.getOrElse(i, Array.empty[Split])
+ metadata.setNumSplits(i, split.length)
+ split
+
+ case i if metadata.isCategorical(i) && metadata.isUnordered(i) =>
+ // Unordered features
+ // 2^(maxFeatureValue - 1) - 1 combinations
+ val featureArity = metadata.featureArity(i)
+ Array.tabulate[Split](metadata.numSplits(i)) { splitIndex =>
+ val categories = extractMultiClassCategories(splitIndex + 1, featureArity)
+ new CategoricalSplit(i, categories.toArray, featureArity)
+ }
+
+ case i if metadata.isCategorical(i) =>
+ // Ordered features
+ // Splits are constructed as needed during training.
+ Array.empty[Split]
+ }
+ splits
+ }
+
+ /**
+ * Nested method to extract list of eligible categories given an index. It extracts the
+ * position of ones in a binary representation of the input. If binary
+ * representation of an number is 01101 (13), the output list should (3.0, 2.0,
+ * 0.0). The maxFeatureValue depict the number of rightmost digits that will be tested for ones.
+ */
+ private[tree] def extractMultiClassCategories(
+ input: Int,
+ maxFeatureValue: Int): List[Double] = {
+ var categories = List[Double]()
+ var j = 0
+ var bitShiftedInput = input
+ while (j < maxFeatureValue) {
+ if (bitShiftedInput % 2 != 0) {
+ // updating the list of categories.
+ categories = j.toDouble :: categories
+ }
+ // Right shift by one
+ bitShiftedInput = bitShiftedInput >> 1
+ j += 1
+ }
+ categories
+ }
+
+ /**
+ * Find splits for a continuous feature
+ * NOTE: Returned number of splits is set based on `featureSamples` and
+ * could be different from the specified `numSplits`.
+ * The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly.
+ *
+ * @param featureSamples feature values and sample weights of each sample
+ * @param metadata decision tree metadata
+ * NOTE: `metadata.numbins` will be changed accordingly
+ * if there are not enough splits to be found
+ * @param featureIndex feature index to find splits
+ * @return array of split thresholds
+ */
+ private[tree] def findSplitsForContinuousFeature(
+ featureSamples: Iterable[(Double, Double)],
+ metadata: DecisionTreeMetadata,
+ featureIndex: Int): Array[Double] = {
+ val valueWeights = new OpenHashMap[Double, Double]
+ var count = 0L
+ featureSamples.foreach { case (weight, value) =>
+ valueWeights.changeValue(value, weight, _ + weight)
+ count += 1L
+ }
+ findSplitsForContinuousFeature(valueWeights.toMap, count, metadata, featureIndex)
+ }
+
+ /**
+ * Find splits for a continuous feature
+ * NOTE: Returned number of splits is set based on `featureSamples` and
+ * could be different from the specified `numSplits`.
+ * The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly.
+ *
+ * @param partValueWeights non-zero distinct values and their weights
+ * @param metadata decision tree metadata
+ * NOTE: `metadata.numbins` will be changed accordingly
+ * if there are not enough splits to be found
+ * @param featureIndex feature index to find splits
+ * @return array of split thresholds
+ */
+ private[tree] def findSplitsForContinuousFeature(
+ partValueWeights: Map[Double, Double],
+ count: Long,
+ metadata: DecisionTreeMetadata,
+ featureIndex: Int): Array[Double] = {
+ require(metadata.isContinuous(featureIndex),
+ "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")
+
+ val splits = if (partValueWeights.isEmpty) {
+ Array.emptyDoubleArray
+ } else {
+ val numSplits = metadata.numSplits(featureIndex)
+
+ val partNumSamples = partValueWeights.values.sum
+
+ // Calculate the expected number of samples for finding splits
+ val weightedNumSamples = samplesFractionForFindSplits(metadata) *
+ metadata.weightedNumExamples
+ // scale tolerance by number of samples with constant factor
+ // Note: constant factor was tuned by running some tests where there were no zero
+ // feature values and validating we are never within tolerance
+ val tolerance = Utils.EPSILON * count * 100
+ // add expected zero value count and get complete statistics
+ val valueCountMap = if (weightedNumSamples - partNumSamples > tolerance) {
+ partValueWeights + (0.0 -> (weightedNumSamples - partNumSamples))
+ } else {
+ partValueWeights
+ }
+
+ // sort distinct values
+ val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
+
+ val possibleSplits = valueCounts.length - 1
+ if (possibleSplits == 0) {
+ // constant feature
+ Array.emptyDoubleArray
+ } else if (possibleSplits <= numSplits) {
+ // if possible splits is not enough or just enough, just return all possible splits
+ (1 to possibleSplits)
+ .map(index => (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0)
+ .toArray
+ } else {
+ // stride between splits
+ val stride: Double = weightedNumSamples / (numSplits + 1)
+ logDebug(s"stride = $stride")
+
+ // iterate `valueCount` to find splits
+ val splitsBuilder = mutable.ArrayBuilder.make[Double]
+ var index = 1
+ // currentCount: sum of counts of values that have been visited
+ var currentCount = valueCounts(0)._2
+ // targetCount: target value for `currentCount`.
+ // If `currentCount` is closest value to `targetCount`,
+ // then current value is a split threshold.
+ // After finding a split threshold, `targetCount` is added by stride.
+ var targetCount = stride
+ while (index < valueCounts.length) {
+ val previousCount = currentCount
+ currentCount += valueCounts(index)._2
+ val previousGap = math.abs(previousCount - targetCount)
+ val currentGap = math.abs(currentCount - targetCount)
+ // If adding count of current value to currentCount
+ // makes the gap between currentCount and targetCount smaller,
+ // previous value is a split threshold.
+ if (previousGap < currentGap) {
+ splitsBuilder += (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0
+ targetCount += stride
+ }
+ index += 1
+ }
+
+ splitsBuilder.result()
+ }
+ }
+ splits
+ }
+
+ private[tree] class NodeIndexInfo(
+ val nodeIndexInGroup: Int,
+ val featureSubset: Option[Array[Int]]) extends Serializable
+
+ /**
+ * Pull nodes off of the queue, and collect a group of nodes to be split on this iteration.
+ * This tracks the memory usage for aggregates and stops adding nodes when too much memory
+ * will be needed; this allows an adaptive number of nodes since different nodes may require
+ * different amounts of memory (if featureSubsetStrategy is not "all").
+ *
+ * @param nodeStack Queue of nodes to split.
+ * @param maxMemoryUsage Bound on size of aggregate statistics.
+ * @return (nodesForGroup, treeToNodeToIndexInfo).
+ * nodesForGroup holds the nodes to split: treeIndex --> nodes in tree.
+ *
+ * treeToNodeToIndexInfo holds indices selected features for each node:
+ * treeIndex --> (global) node index --> (node index in group, feature indices).
+ * The (global) node index is the index in the tree; the node index in group is the
+ * index in [0, numNodesInGroup) of the node in this group.
+ * The feature indices are None if not subsampling features.
+ */
+ private[tree] def selectNodesToSplit(
+ nodeStack: mutable.ListBuffer[(Int, LearningNode)],
+ maxMemoryUsage: Long,
+ metadata: DecisionTreeMetadata,
+ rng: Random): (Map[Int, Array[LearningNode]], Map[Int, Map[Int, NodeIndexInfo]]) = {
+ // Collect some nodes to split:
+ // nodesForGroup(treeIndex) = nodes to split
+ val mutableNodesForGroup = new mutable.HashMap[Int, mutable.ArrayBuffer[LearningNode]]()
+ val mutableTreeToNodeToIndexInfo =
+ new mutable.HashMap[Int, mutable.HashMap[Int, NodeIndexInfo]]()
+ var memUsage: Long = 0L
+ var numNodesInGroup = 0
+ // If maxMemoryInMB is set very small, we want to still try to split 1 node,
+ // so we allow one iteration if memUsage == 0.
+ var groupDone = false
+ while (nodeStack.nonEmpty && !groupDone) {
+ val (treeIndex, node) = nodeStack.head
+ // Choose subset of features for node (if subsampling).
+ val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
+ Some(SamplingUtils.reservoirSampleAndCount(Range(0,
+ metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong())._1)
+ } else {
+ None
+ }
+ // Check if enough memory remains to add this node to the group.
+ val nodeMemUsage = RandomForestRaw.aggregateSizeForNode(metadata, featureSubset) * 8L
+ if (memUsage + nodeMemUsage <= maxMemoryUsage || memUsage == 0) {
+ nodeStack.remove(0)
+ mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[LearningNode]()) +=
+ node
+ mutableTreeToNodeToIndexInfo
+ .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id)
+ = new NodeIndexInfo(numNodesInGroup, featureSubset)
+ numNodesInGroup += 1
+ memUsage += nodeMemUsage
+ } else {
+ groupDone = true
+ }
+ }
+ if (memUsage > maxMemoryUsage) {
+ // If maxMemoryUsage is 0, we should still allow splitting 1 node.
+ logWarning(s"Tree learning is using approximately $memUsage bytes per iteration, which" +
+ s" exceeds requested limit maxMemoryUsage=$maxMemoryUsage. This allows splitting" +
+ s" $numNodesInGroup nodes in this iteration.")
+ }
+ // Convert mutable maps to immutable ones.
+ val nodesForGroup: Map[Int, Array[LearningNode]] =
+ mutableNodesForGroup.mapValues(_.toArray).toMap
+ val treeToNodeToIndexInfo = mutableTreeToNodeToIndexInfo.mapValues(_.toMap).toMap
+ (nodesForGroup, treeToNodeToIndexInfo)
+ }
+
+ /**
+ * Get the number of values to be stored for this node in the bin aggregates.
+ *
+ * @param featureSubset Indices of features which may be split at this node.
+ * If None, then use all features.
+ */
+ private def aggregateSizeForNode(
+ metadata: DecisionTreeMetadata,
+ featureSubset: Option[Array[Int]]): Long = {
+ val totalBins = if (featureSubset.nonEmpty) {
+ featureSubset.get.map(featureIndex => metadata.numBins(featureIndex).toLong).sum
+ } else {
+ metadata.numBins.map(_.toLong).sum
+ }
+ if (metadata.isClassification) {
+ metadata.numClasses * totalBins
+ } else {
+ 3 * totalBins
+ }
+ }
+
+ /**
+ * Calculate the subsample fraction for finding splits
+ *
+ * @param metadata decision tree metadata
+ * @return subsample fraction
+ */
+ private def samplesFractionForFindSplits(
+ metadata: DecisionTreeMetadata): Double = {
+ // Calculate the number of samples for approximate quantile calculation.
+ val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
+ if (requiredSamples < metadata.numExamples) {
+ requiredSamples.toDouble / metadata.numExamples
+ } else {
+ 1.0
+ }
+ }
+}
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
new file mode 100644
index 0000000..67b9166
--- /dev/null
+++ b/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -0,0 +1,561 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree
+
+import scala.reflect.ClassTag
+
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.ml.attribute._
+import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.ml.param.{Param, Params}
+import org.apache.spark.ml.tree.DecisionTreeModelReadWrite.NodeData
+import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter}
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
+import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
+import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.functions.{col, lit, struct}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.VersionUtils
+import org.apache.spark.util.collection.OpenHashMap
+
+/**
+ * Abstraction for Decision Tree models.
+ */
+private[spark] trait DecisionTreeModel {
+
+ /** Root of the decision tree */
+ def rootNode: Node
+
+ /** Number of nodes in tree, including leaf nodes. */
+ def numNodes: Int = {
+ 1 + rootNode.numDescendants
+ }
+
+ /**
+ * Depth of the tree.
+ * E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes.
+ */
+ lazy val depth: Int = {
+ rootNode.subtreeDepth
+ }
+
+ /** Summary of the model */
+ override def toString: String = {
+ // Implementing classes should generally override this method to be more descriptive.
+ s"DecisionTreeModel of depth $depth with $numNodes nodes"
+ }
+
+ /** Full description of model */
+ def toDebugString: String = {
+ val header = toString + "\n"
+ header + rootNode.subtreeToString(2)
+ }
+
+ /**
+ * Trace down the tree, and return the largest feature index used in any split.
+ *
+ * @return Max feature index used in a split, or -1 if there are no splits (single leaf node).
+ */
+ private[ml] def maxSplitFeatureIndex(): Int = rootNode.maxSplitFeatureIndex()
+
+ /** Convert to spark.mllib DecisionTreeModel (losing some information) */
+ private[spark] def toOld: OldDecisionTreeModel
+
+ /**
+ * @return an iterator that traverses (DFS, left to right) the leaves
+ * in the subtree of this node.
+ */
+ private def leafIterator(node: Node): Iterator[LeafNode] = {
+ node match {
+ case l: LeafNode => Iterator.single(l)
+ case n: InternalNode =>
+ leafIterator(n.leftChild) ++ leafIterator(n.rightChild)
+ }
+ }
+
+ private[ml] lazy val numLeave: Int =
+ leafIterator(rootNode).size
+
+ private[ml] lazy val leafAttr = {
+ NominalAttribute.defaultAttr
+ .withNumValues(numLeave)
+ }
+
+ private[ml] def getLeafField(leafCol: String) = {
+ leafAttr.withName(leafCol).toStructField()
+ }
+
+ @transient private lazy val leafIndices: Map[LeafNode, Int] = {
+ leafIterator(rootNode).zipWithIndex.toMap
+ }
+
+ /**
+ * @return The index of the leaf corresponding to the feature vector.
+ * Leaves are indexed in pre-order from 0.
+ */
+ def predictLeaf(features: Vector): Double = {
+ leafIndices(rootNode.predictImpl(features)).toDouble
+ }
+}
+
+/**
+ * Abstraction for models which are ensembles of decision trees
+ * @tparam M Type of tree model in this ensemble
+ */
+private[ml] trait TreeEnsembleModel[M <: DecisionTreeModel] {
+
+ // Note: We use getTrees since subclasses of TreeEnsembleModel will store subclasses of
+ // DecisionTreeModel.
+
+ /** Trees in this ensemble. Warning: These have null parent Estimators. */
+ def trees: Array[M]
+
+ /** Weights for each tree, zippable with [[trees]] */
+ def treeWeights: Array[Double]
+
+ /** Weights used by the python wrappers. */
+ // Note: An array cannot be returned directly due to serialization problems.
+ private[spark] def javaTreeWeights: Vector = Vectors.dense(treeWeights)
+
+ /** Summary of the model */
+ override def toString: String = {
+ // Implementing classes should generally override this method to be more descriptive.
+ s"TreeEnsembleModel with ${trees.length} trees"
+ }
+
+ /** Full description of model */
+ def toDebugString: String = {
+ val header = toString + "\n"
+ header + trees.zip(treeWeights).zipWithIndex.map { case ((tree, weight), treeIndex) =>
+ s" Tree $treeIndex (weight $weight):\n" + tree.rootNode.subtreeToString(4)
+ }.fold("")(_ + _)
+ }
+
+ /** Total number of nodes, summed over all trees in the ensemble. */
+ lazy val totalNumNodes: Int = trees.map(_.numNodes).sum
+
+ /**
+ * @return The indices of the leaves corresponding to the feature vector.
+ * Leaves are indexed in pre-order from 0.
+ */
+ def predictLeaf(features: Vector): Vector = {
+ val indices = trees.map(_.predictLeaf(features))
+ Vectors.dense(indices)
+ }
+
+ private[ml] def getLeafField(leafCol: String) = {
+ new AttributeGroup(leafCol, attrs = trees.map(_.leafAttr)).toStructField()
+ }
+}
+
+private[ml] object TreeEnsembleModel {
+
+ /**
+ * Given a tree ensemble model, compute the importance of each feature.
+ * This generalizes the idea of "Gini" importance to other losses,
+ * following the explanation of Gini importance from "Random Forests" documentation
+ * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ *
+ * For collections of trees, including boosting and bagging, Hastie et al.
+ * propose to use the average of single tree importances across all trees in the ensemble.
+ *
+ * This feature importance is calculated as follows:
+ * - Average over trees:
+ * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
+ * where gain is scaled by the number of instances passing through node
+ * - Normalize importances for tree to sum to 1 (only if `perTreeNormalization` is `true`).
+ * - Normalize feature importance vector to sum to 1.
+ *
+ * References:
+ * - Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.
+ *
+ * @param trees Unweighted collection of trees
+ * @param numFeatures Number of features in model (even if not all are explicitly used by
+ * the model).
+ * If -1, then numFeatures is set based on the max feature index in all trees.
+ * @param perTreeNormalization By default this is set to `true` and it means that the importances
+ * of each tree are normalized before being summed. If set to `false`,
+ * the normalization is skipped.
+ * @return Feature importance values, of length numFeatures.
+ */
+ def featureImportances[M <: DecisionTreeModel](
+ trees: Array[M],
+ numFeatures: Int,
+ perTreeNormalization: Boolean = true): Vector = {
+ val totalImportances = new OpenHashMap[Int, Double]()
+ trees.foreach { tree =>
+ // Aggregate feature importance vector for this tree
+ val importances = new OpenHashMap[Int, Double]()
+ computeFeatureImportance(tree.rootNode, importances)
+ // Normalize importance vector for this tree, and add it to total.
+ // TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count?
+ val treeNorm = if (perTreeNormalization) {
+ importances.map(_._2).sum
+ } else {
+ // We won't use it
+ Double.NaN
+ }
+ if (treeNorm != 0) {
+ importances.foreach { case (idx, impt) =>
+ val normImpt = if (perTreeNormalization) {
+ impt / treeNorm
+ } else {
+ impt
+ }
+ totalImportances.changeValue(idx, normImpt, _ + normImpt)
+ }
+ }
+ }
+ // Normalize importances
+ normalizeMapValues(totalImportances)
+ // Construct vector
+ val d = if (numFeatures != -1) {
+ numFeatures
+ } else {
+ // Find max feature index used in trees
+ val maxFeatureIndex = trees.map(_.maxSplitFeatureIndex()).max
+ maxFeatureIndex + 1
+ }
+ if (d == 0) {
+ assert(totalImportances.size == 0, s"Unknown error in computing feature" +
+ s" importance: No splits found, but some non-zero importances.")
+ }
+ val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip
+ Vectors.sparse(d, indices.toArray, values.toArray)
+ }
+
+ /**
+ * Given a Decision Tree model, compute the importance of each feature.
+ * This generalizes the idea of "Gini" importance to other losses,
+ * following the explanation of Gini importance from "Random Forests" documentation
+ * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ *
+ * This feature importance is calculated as follows:
+ * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
+ * where gain is scaled by the number of instances passing through node
+ * - Normalize importances for tree to sum to 1.
+ *
+ * @param tree Decision tree to compute importances for.
+ * @param numFeatures Number of features in model (even if not all are explicitly used by
+ * the model).
+ * If -1, then numFeatures is set based on the max feature index in all trees.
+ * @return Feature importance values, of length numFeatures.
+ */
+ def featureImportances[M <: DecisionTreeModel : ClassTag](tree: M, numFeatures: Int): Vector = {
+ featureImportances(Array(tree), numFeatures)
+ }
+
+ /**
+ * Recursive method for computing feature importances for one tree.
+ * This walks down the tree, adding to the importance of 1 feature at each node.
+ *
+ * @param node Current node in recursion
+ * @param importances Aggregate feature importances, modified by this method
+ */
+ def computeFeatureImportance(
+ node: Node,
+ importances: OpenHashMap[Int, Double]): Unit = {
+ node match {
+ case n: InternalNode =>
+ val feature = n.split.featureIndex
+ val scaledGain = n.gain * n.impurityStats.count
+ importances.changeValue(feature, scaledGain, _ + scaledGain)
+ computeFeatureImportance(n.leftChild, importances)
+ computeFeatureImportance(n.rightChild, importances)
+ case n: LeafNode =>
+ // do nothing
+ }
+ }
+
+ /**
+ * Normalize the values of this map to sum to 1, in place.
+ * If all values are 0, this method does nothing.
+ *
+ * @param map Map with non-negative values.
+ */
+ def normalizeMapValues(map: OpenHashMap[Int, Double]): Unit = {
+ val total = map.map(_._2).sum
+ if (total != 0) {
+ val keys = map.iterator.map(_._1).toArray
+ keys.foreach { key => map.changeValue(key, 0.0, _ / total) }
+ }
+ }
+}
+
+/** Helper classes for tree model persistence */
+private[ml] object DecisionTreeModelReadWrite {
+
+ /**
+ * Info for a [[org.apache.spark.ml.tree.Split]]
+ *
+ * @param featureIndex Index of feature split on
+ * @param leftCategoriesOrThreshold For categorical feature, set of leftCategories.
+ * For continuous feature, threshold.
+ * @param numCategories For categorical feature, number of categories.
+ * For continuous feature, -1.
+ */
+ case class SplitData(
+ featureIndex: Int,
+ leftCategoriesOrThreshold: Array[Double],
+ numCategories: Int) {
+
+ def getSplit: Split = {
+ if (numCategories != -1) {
+ new CategoricalSplit(featureIndex, leftCategoriesOrThreshold, numCategories)
+ } else {
+ assert(leftCategoriesOrThreshold.length == 1, s"DecisionTree split data expected" +
+ s" 1 threshold for ContinuousSplit, but found thresholds: " +
+ leftCategoriesOrThreshold.mkString(", "))
+ new ContinuousSplit(featureIndex, leftCategoriesOrThreshold(0))
+ }
+ }
+ }
+
+ object SplitData {
+ def apply(split: Split): SplitData = split match {
+ case s: CategoricalSplit =>
+ SplitData(s.featureIndex, s.leftCategories, s.numCategories)
+ case s: ContinuousSplit =>
+ SplitData(s.featureIndex, Array(s.threshold), -1)
+ }
+ }
+
+ /**
+ * Info for a [[Node]]
+ *
+ * @param id Index used for tree reconstruction. Indices follow a pre-order traversal.
+ * @param impurityStats Stats array. Impurity type is stored in metadata.
+ * @param rawCount The unweighted number of samples falling in this node.
+ * @param gain Gain, or arbitrary value if leaf node.
+ * @param leftChild Left child index, or arbitrary value if leaf node.
+ * @param rightChild Right child index, or arbitrary value if leaf node.
+ * @param split Split info, or arbitrary value if leaf node.
+ */
+ case class NodeData(
+ id: Int,
+ prediction: Double,
+ impurity: Double,
+ impurityStats: Array[Double],
+ rawCount: Long,
+ gain: Double,
+ leftChild: Int,
+ rightChild: Int,
+ split: SplitData)
+
+ object NodeData {
+ /**
+ * Create [[NodeData]] instances for this node and all children.
+ *
+ * @param id Current ID. IDs are assigned via a pre-order traversal.
+ * @return (sequence of nodes in pre-order traversal order, largest ID in subtree)
+ * The nodes are returned in pre-order traversal (root first) so that it is easy to
+ * get the ID of the subtree's root node.
+ */
+ def build(node: Node, id: Int): (Seq[NodeData], Int) = node match {
+ case n: InternalNode =>
+ val (leftNodeData, leftIdx) = build(n.leftChild, id + 1)
+ val (rightNodeData, rightIdx) = build(n.rightChild, leftIdx + 1)
+ val thisNodeData = NodeData(id, n.prediction, n.impurity, n.impurityStats.stats,
+ n.impurityStats.rawCount, n.gain, leftNodeData.head.id, rightNodeData.head.id,
+ SplitData(n.split))
+ (thisNodeData +: (leftNodeData ++ rightNodeData), rightIdx)
+ case _: LeafNode =>
+ (Seq(NodeData(id, node.prediction, node.impurity, node.impurityStats.stats,
+ node.impurityStats.rawCount, -1.0, -1, -1, SplitData(-1, Array.emptyDoubleArray, -1))),
+ id)
+ }
+ }
+
+ /**
+ * Load a decision tree from a file.
+ * @return Root node of reconstructed tree
+ */
+ def loadTreeNodes(
+ path: String,
+ metadata: DefaultParamsReader.Metadata,
+ sparkSession: SparkSession): Node = {
+ import sparkSession.implicits._
+ implicit val format = DefaultFormats
+
+ // Get impurity to construct ImpurityCalculator for each node
+ val impurityType: String = {
+ val impurityJson: JValue = metadata.getParamValue("impurity")
+ Param.jsonDecode[String](compact(render(impurityJson)))
+ }
+
+ val dataPath = new Path(path, "data").toString
+ var df = sparkSession.read.parquet(dataPath)
+ val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion)
+ if (major.toInt < 3) {
+ df = df.withColumn("rawCount", lit(-1L))
+ }
+
+ buildTreeFromNodes(df.as[NodeData].collect(), impurityType)
+ }
+
+ /**
+ * Given all data for all nodes in a tree, rebuild the tree.
+ * @param data Unsorted node data
+ * @param impurityType Impurity type for this tree
+ * @return Root node of reconstructed tree
+ */
+ def buildTreeFromNodes(data: Array[NodeData], impurityType: String): Node = {
+ // Load all nodes, sorted by ID.
+ val nodes = data.sortBy(_.id)
+ // Sanity checks; could remove
+ assert(nodes.head.id == 0, s"Decision Tree load failed. Expected smallest node ID to be 0," +
+ s" but found ${nodes.head.id}")
+ assert(nodes.last.id == nodes.length - 1, s"Decision Tree load failed. Expected largest" +
+ s" node ID to be ${nodes.length - 1}, but found ${nodes.last.id}")
+ // We fill `finalNodes` in reverse order. Since node IDs are assigned via a pre-order
+ // traversal, this guarantees that child nodes will be built before parent nodes.
+ val finalNodes = new Array[Node](nodes.length)
+ nodes.reverseIterator.foreach { case n: NodeData =>
+ val impurityStats =
+ ImpurityCalculator.getCalculator(impurityType, n.impurityStats, n.rawCount)
+ val node = if (n.leftChild != -1) {
+ val leftChild = finalNodes(n.leftChild)
+ val rightChild = finalNodes(n.rightChild)
+ new InternalNode(n.prediction, n.impurity, n.gain, leftChild, rightChild,
+ n.split.getSplit, impurityStats)
+ } else {
+ new LeafNode(n.prediction, n.impurity, impurityStats)
+ }
+ finalNodes(n.id) = node
+ }
+ // Return the root node
+ finalNodes.head
+ }
+}
+
+private[ml] object EnsembleModelReadWrite {
+
+ /**
+ * Helper method for saving a tree ensemble to disk.
+ *
+ * @param instance Tree ensemble model
+ * @param path Path to which to save the ensemble model.
+ * @param extraMetadata Metadata such as numFeatures, numClasses, numTrees.
+ */
+ def saveImpl[M <: Params with TreeEnsembleModel[_ <: DecisionTreeModel]](
+ instance: M,
+ path: String,
+ sql: SparkSession,
+ extraMetadata: JObject): Unit = {
+ DefaultParamsWriter.saveMetadata(instance, path, sql.sparkContext, Some(extraMetadata))
+ val treesMetadataWeights: Array[(Int, String, Double)] = instance.trees.zipWithIndex.map {
+ case (tree, treeID) =>
+ (treeID,
+ DefaultParamsWriter.getMetadataToSave(tree.asInstanceOf[Params], sql.sparkContext),
+ instance.treeWeights(treeID))
+ }
+ val treesMetadataPath = new Path(path, "treesMetadata").toString
+ sql.createDataFrame(treesMetadataWeights).toDF("treeID", "metadata", "weights")
+ .write.parquet(treesMetadataPath)
+ val dataPath = new Path(path, "data").toString
+ val nodeDataRDD = sql.sparkContext.parallelize(instance.trees.zipWithIndex).flatMap {
+ case (tree, treeID) => EnsembleNodeData.build(tree, treeID)
+ }
+ sql.createDataFrame(nodeDataRDD).write.parquet(dataPath)
+ }
+
+ /**
+ * Helper method for loading a tree ensemble from disk.
+ * This reconstructs all trees, returning the root nodes.
+ * @param path Path given to `saveImpl`
+ * @param className Class name for ensemble model type
+ * @param treeClassName Class name for tree model type in the ensemble
+ * @return (ensemble metadata, array over trees of (tree metadata, root node)),
+ * where the root node is linked with all descendents
+ * @see `saveImpl` for how the model was saved
+ */
+ def loadImpl(
+ path: String,
+ sql: SparkSession,
+ className: String,
+ treeClassName: String): (Metadata, Array[(Metadata, Node)], Array[Double]) = {
+ import sql.implicits._
+ implicit val format = DefaultFormats
+ val metadata = DefaultParamsReader.loadMetadata(path, sql.sparkContext, className)
+
+ // Get impurity to construct ImpurityCalculator for each node
+ val impurityType: String = {
+ val impurityJson: JValue = metadata.getParamValue("impurity")
+ Param.jsonDecode[String](compact(render(impurityJson)))
+ }
+
+ val treesMetadataPath = new Path(path, "treesMetadata").toString
+ val treesMetadataRDD = sql.read.parquet(treesMetadataPath)
+ .select("treeID", "metadata", "weights")
+ .as[(Int, String, Double)].rdd
+ .map { case (treeID: Int, json: String, weights: Double) =>
+ treeID -> ((DefaultParamsReader.parseMetadata(json, treeClassName), weights))
+ }
+
+ val treesMetadataWeights = treesMetadataRDD.sortByKey().values.collect()
+ val treesMetadata = treesMetadataWeights.map(_._1)
+ val treesWeights = treesMetadataWeights.map(_._2)
+
+ val dataPath = new Path(path, "data").toString
+ var df = sql.read.parquet(dataPath)
+ val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion)
+ if (major.toInt < 3) {
+ val newNodeDataCol = df.schema("nodeData").dataType match {
+ case StructType(fields) =>
+ val cols = fields.map(f => col(s"nodeData.${f.name}")) :+ lit(-1L).as("rawCount")
+ struct(cols: _*)
+ }
+ df = df.withColumn("nodeData", newNodeDataCol)
+ }
+
+ val rootNodesRDD = df.as[EnsembleNodeData].rdd
+ .map(d => (d.treeID, d.nodeData))
+ .groupByKey()
+ .map { case (treeID: Int, nodeData: Iterable[NodeData]) =>
+ treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType)
+ }
+ val rootNodes = rootNodesRDD.sortByKey().values.collect()
+ (metadata, treesMetadata.zip(rootNodes), treesWeights)
+ }
+
+ /**
+ * Info for one [[Node]] in a tree ensemble
+ *
+ * @param treeID Tree index
+ * @param nodeData Data for this node
+ */
+ case class EnsembleNodeData(
+ treeID: Int,
+ nodeData: NodeData)
+
+ object EnsembleNodeData {
+ /**
+ * Create [[EnsembleNodeData]] instances for the given tree.
+ *
+ * @return Sequence of nodes for this tree
+ */
+ def build(tree: DecisionTreeModel, treeID: Int): Seq[EnsembleNodeData] = {
+ val (nodeData: Seq[NodeData], _) = NodeData.build(tree.rootNode, 0)
+ nodeData.map(nd => EnsembleNodeData(treeID, nd))
+ }
+ }
+}
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
new file mode 100644
index 0000000..acb843a
--- /dev/null
+++ b/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -0,0 +1,626 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree
+
+import java.util.Locale
+
+import scala.util.Try
+
+import org.apache.spark.annotation.Since
+import org.apache.spark.ml.PredictorParams
+import org.apache.spark.ml.classification.ProbabilisticClassifierParams
+import org.apache.spark.ml.linalg.VectorUDT
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
+import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, ClassificationLoss => OldClassificationLoss, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError}
+import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
+
+/**
+ * Parameters for Decision Tree-based algorithms.
+ *
+ * Note: Marked as private since this may be made public in the future.
+ */
+private[ml] trait DecisionTreeParams extends PredictorParams
+ with HasCheckpointInterval with HasSeed with HasWeightCol {
+
+ /**
+ * Leaf indices column name.
+ * Predicted leaf index of each instance in each tree by preorder.
+ * (default = "")
+ * @group param
+ */
+ @Since("3.0.0")
+ final val leafCol: Param[String] =
+ new Param[String](this, "leafCol", "Leaf indices column name. " +
+ "Predicted leaf index of each instance in each tree by preorder")
+
+ /**
+ * Maximum depth of the tree (nonnegative).
+ * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
+ * (default = 5)
+ * @group param
+ */
+ final val maxDepth: IntParam =
+ new IntParam(this, "maxDepth", "Maximum depth of the tree. (Nonnegative)" +
+ " E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.",
+ ParamValidators.gtEq(0))
+
+ /**
+ * Maximum number of bins used for discretizing continuous features and for choosing how to split
+ * on features at each node. More bins give higher granularity.
+ * Must be at least 2 and at least number of categories in any categorical feature.
+ * (default = 32)
+ * @group param
+ */
+ final val maxBins: IntParam = new IntParam(this, "maxBins", "Max number of bins for" +
+ " discretizing continuous features. Must be at least 2 and at least number of categories" +
+ " for any categorical feature.", ParamValidators.gtEq(2))
+
+ /**
+ * Minimum number of instances each child must have after split.
+ * If a split causes the left or right child to have fewer than minInstancesPerNode,
+ * the split will be discarded as invalid.
+ * Must be at least 1.
+ * (default = 1)
+ * @group param
+ */
+ final val minInstancesPerNode: IntParam = new IntParam(this, "minInstancesPerNode", "Minimum" +
+ " number of instances each child must have after split. If a split causes the left or right" +
+ " child to have fewer than minInstancesPerNode, the split will be discarded as invalid." +
+ " Must be at least 1.", ParamValidators.gtEq(1))
+
+ /**
+ * Minimum fraction of the weighted sample count that each child must have after split.
+ * If a split causes the fraction of the total weight in the left or right child to be less than
+ * minWeightFractionPerNode, the split will be discarded as invalid.
+ * Should be in the interval [0.0, 0.5).
+ * (default = 0.0)
+ * @group param
+ */
+ final val minWeightFractionPerNode: DoubleParam = new DoubleParam(this,
+ "minWeightFractionPerNode", "Minimum fraction of the weighted sample count that each child " +
+ "must have after split. If a split causes the fraction of the total weight in the left or " +
+ "right child to be less than minWeightFractionPerNode, the split will be discarded as " +
+ "invalid. Should be in interval [0.0, 0.5)",
+ ParamValidators.inRange(0.0, 0.5, lowerInclusive = true, upperInclusive = false))
+
+ /**
+ * Minimum information gain for a split to be considered at a tree node.
+ * Should be at least 0.0.
+ * (default = 0.0)
+ * @group param
+ */
+ final val minInfoGain: DoubleParam = new DoubleParam(this, "minInfoGain",
+ "Minimum information gain for a split to be considered at a tree node.",
+ ParamValidators.gtEq(0.0))
+
+ /**
+ * Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be
+ * split per iteration, and its aggregates may exceed this size.
+ * (default = 256 MB)
+ * @group expertParam
+ */
+ final val maxMemoryInMB: IntParam = new IntParam(this, "maxMemoryInMB",
+ "Maximum memory in MB allocated to histogram aggregation.",
+ ParamValidators.gtEq(0))
+
+ /**
+ * If false, the algorithm will pass trees to executors to match instances with nodes.
+ * If true, the algorithm will cache node IDs for each instance.
+ * Caching can speed up training of deeper trees. Users can set how often should the
+ * cache be checkpointed or disable it by setting checkpointInterval.
+ * (default = false)
+ * @group expertParam
+ */
+ final val cacheNodeIds: BooleanParam = new BooleanParam(this, "cacheNodeIds", "If false, the" +
+ " algorithm will pass trees to executors to match instances with nodes. If true, the" +
+ " algorithm will cache node IDs for each instance. Caching can speed up training of deeper" +
+ " trees.")
+
+ setDefault(leafCol -> "", maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1,
+ minWeightFractionPerNode -> 0.0, minInfoGain -> 0.0, maxMemoryInMB -> 256,
+ cacheNodeIds -> false, checkpointInterval -> 10)
+
+ /** @group setParam */
+ @Since("3.0.0")
+ final def setLeafCol(value: String): this.type = set(leafCol, value)
+
+ /** @group getParam */
+ @Since("3.0.0")
+ final def getLeafCol: String = $(leafCol)
+
+ /** @group getParam */
+ final def getMaxDepth: Int = $(maxDepth)
+
+ /** @group getParam */
+ final def getMaxBins: Int = $(maxBins)
+
+ /** @group getParam */
+ final def getMinInstancesPerNode: Int = $(minInstancesPerNode)
+
+ /** @group getParam */
+ final def getMinWeightFractionPerNode: Double = $(minWeightFractionPerNode)
+
+ /** @group getParam */
+ final def getMinInfoGain: Double = $(minInfoGain)
+
+ /** @group expertGetParam */
+ final def getMaxMemoryInMB: Int = $(maxMemoryInMB)
+
+ /** @group expertGetParam */
+ final def getCacheNodeIds: Boolean = $(cacheNodeIds)
+
+ /** (private[ml]) Create a Strategy instance to use with the old API. */
+ private[ml] def getOldStrategy(
+ categoricalFeatures: Map[Int, Int],
+ numClasses: Int,
+ oldAlgo: OldAlgo.Algo,
+ oldImpurity: OldImpurity,
+ subsamplingRate: Double): OldStrategy = {
+ val strategy = OldStrategy.defaultStrategy(oldAlgo)
+ strategy.impurity = oldImpurity
+ strategy.checkpointInterval = getCheckpointInterval
+ strategy.maxBins = getMaxBins
+ strategy.maxDepth = getMaxDepth
+ strategy.maxMemoryInMB = getMaxMemoryInMB
+ strategy.minInfoGain = getMinInfoGain
+ strategy.minInstancesPerNode = getMinInstancesPerNode
+ strategy.minWeightFractionPerNode = getMinWeightFractionPerNode
+ strategy.useNodeIdCache = getCacheNodeIds
+ strategy.numClasses = numClasses
+ strategy.categoricalFeaturesInfo = categoricalFeatures
+ strategy.subsamplingRate = subsamplingRate
+ strategy
+ }
+}
+
+/**
+ * Parameters for Decision Tree-based classification algorithms.
+ */
+private[ml] trait TreeClassifierParams extends Params {
+
+ /**
+ * Criterion used for information gain calculation (case-insensitive).
+ * This impurity type is used in DecisionTreeClassifier and RandomForestClassifier,
+ * Supported: "entropy" and "gini".
+ * (default = gini)
+ * @group param
+ */
+ final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
+ " information gain calculation (case-insensitive). Supported options:" +
+ s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}",
+ (value: String) =>
+ TreeClassifierParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT)))
+
+ setDefault(impurity -> "gini")
+
+ /** @group getParam */
+ final def getImpurity: String = $(impurity).toLowerCase(Locale.ROOT)
+
+ /** Convert new impurity to old impurity. */
+ private[ml] def getOldImpurity: OldImpurity = {
+ getImpurity match {
+ case "entropy" => OldEntropy
+ case "gini" => OldGini
+ case _ =>
+ // Should never happen because of check in setter method.
+ throw new RuntimeException(
+ s"TreeClassifierParams was given unrecognized impurity: $impurity.")
+ }
+ }
+}
+
+private[ml] object TreeClassifierParams {
+ // These options should be lowercase.
+ final val supportedImpurities: Array[String] =
+ Array("entropy", "gini").map(_.toLowerCase(Locale.ROOT))
+}
+
+private[ml] trait DecisionTreeClassifierParams
+ extends DecisionTreeParams with TreeClassifierParams with ProbabilisticClassifierParams {
+
+ override protected def validateAndTransformSchema(
+ schema: StructType,
+ fitting: Boolean,
+ featuresDataType: DataType): StructType = {
+ var outputSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType)
+ if ($(leafCol).nonEmpty) {
+ outputSchema = SchemaUtils.appendColumn(outputSchema, $(leafCol), DoubleType)
+ }
+ outputSchema
+ }
+}
+
+private[ml] trait HasVarianceImpurity extends Params {
+ /**
+ * Criterion used for information gain calculation (case-insensitive).
+ * This impurity type is used in DecisionTreeRegressor, RandomForestRegressor, GBTRegressor
+ * and GBTClassifier (since GBTClassificationModel is internally composed of
+ * DecisionTreeRegressionModels).
+ * Supported: "variance".
+ * (default = variance)
+ * @group param
+ */
+ final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
+ " information gain calculation (case-insensitive). Supported options:" +
+ s" ${HasVarianceImpurity.supportedImpurities.mkString(", ")}",
+ (value: String) =>
+ HasVarianceImpurity.supportedImpurities.contains(value.toLowerCase(Locale.ROOT)))
+
+ setDefault(impurity -> "variance")
+
+ /** @group getParam */
+ final def getImpurity: String = $(impurity).toLowerCase(Locale.ROOT)
+
+ /** Convert new impurity to old impurity. */
+ private[ml] def getOldImpurity: OldImpurity = {
+ getImpurity match {
+ case "variance" => OldVariance
+ case _ =>
+ // Should never happen because of check in setter method.
+ throw new RuntimeException(
+ s"TreeRegressorParams was given unrecognized impurity: $impurity")
+ }
+ }
+}
+
+private[ml] object HasVarianceImpurity {
+ // These options should be lowercase.
+ final val supportedImpurities: Array[String] =
+ Array("variance").map(_.toLowerCase(Locale.ROOT))
+}
+
+/**
+ * Parameters for Decision Tree-based regression algorithms.
+ */
+private[ml] trait TreeRegressorParams extends HasVarianceImpurity
+
+private[ml] trait DecisionTreeRegressorParams extends DecisionTreeParams
+ with TreeRegressorParams with HasVarianceCol {
+
+ override protected def validateAndTransformSchema(
+ schema: StructType,
+ fitting: Boolean,
+ featuresDataType: DataType): StructType = {
+ var outputSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType)
+ if (isDefined(varianceCol) && $(varianceCol).nonEmpty) {
+ outputSchema = SchemaUtils.appendColumn(outputSchema, $(varianceCol), DoubleType)
+ }
+ if ($(leafCol).nonEmpty) {
+ outputSchema = SchemaUtils.appendColumn(outputSchema, $(leafCol), DoubleType)
+ }
+ outputSchema
+ }
+}
+
+private[spark] object TreeEnsembleParams {
+ // These options should be lowercase.
+ final val supportedFeatureSubsetStrategies: Array[String] =
+ Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase(Locale.ROOT))
+}
+
+/**
+ * Parameters for Decision Tree-based ensemble algorithms.
+ *
+ * Note: Marked as private since this may be made public in the future.
+ */
+private[ml] trait TreeEnsembleParams extends DecisionTreeParams {
+
+ /**
+ * Fraction of the training data used for learning each decision tree, in range (0, 1].
+ * (default = 1.0)
+ * @group param
+ */
+ final val subsamplingRate: DoubleParam = new DoubleParam(this, "subsamplingRate",
+ "Fraction of the training data used for learning each decision tree, in range (0, 1].",
+ ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))
+
+ /** @group getParam */
+ final def getSubsamplingRate: Double = $(subsamplingRate)
+
+ /**
+ * Create a Strategy instance to use with the old API.
+ * NOTE: The caller should set impurity and seed.
+ */
+ private[ml] def getOldStrategy(
+ categoricalFeatures: Map[Int, Int],
+ numClasses: Int,
+ oldAlgo: OldAlgo.Algo,
+ oldImpurity: OldImpurity): OldStrategy = {
+ super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, oldImpurity, getSubsamplingRate)
+ }
+
+ /**
+ * The number of features to consider for splits at each tree node.
+ * Supported options:
+ * - "auto": Choose automatically for task:
+ * If numTrees == 1, set to "all."
+ * If numTrees greater than 1 (forest), set to "sqrt" for classification and
+ * to "onethird" for regression.
+ * - "all": use all features
+ * - "onethird": use 1/3 of the features
+ * - "sqrt": use sqrt(number of features)
+ * - "log2": use log2(number of features)
+ * - "n": when n is in the range (0, 1.0], use n * number of features. When n
+ * is in the range (1, number of features), use n features.
+ * (default = "auto")
+ *
+ * These various settings are based on the following references:
+ * - log2: tested in Breiman (2001)
+ * - sqrt: recommended by Breiman manual for random forests
+ * - The defaults of sqrt (classification) and onethird (regression) match the R randomForest
+ * package.
+ * @see Breiman (2001)
+ * @see
+ * Breiman manual for random forests
+ *
+ * @group param
+ */
+ final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy",
+ "The number of features to consider for splits at each tree node." +
+ s" Supported options: ${TreeEnsembleParams.supportedFeatureSubsetStrategies.mkString(", ")}" +
+ s", (0.0-1.0], [1-n].",
+ (value: String) =>
+ TreeEnsembleParams.supportedFeatureSubsetStrategies.contains(
+ value.toLowerCase(Locale.ROOT))
+ || Try(value.toInt).filter(_ > 0).isSuccess
+ || Try(value.toDouble).filter(_ > 0).filter(_ <= 1.0).isSuccess)
+
+ /** @group getParam */
+ final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase(Locale.ROOT)
+
+ setDefault(subsamplingRate -> 1.0, featureSubsetStrategy -> "auto")
+}
+
+/**
+ * Parameters for Decision Tree-based ensemble classification algorithms.
+ */
+private[ml] trait TreeEnsembleClassifierParams
+ extends TreeEnsembleParams with ProbabilisticClassifierParams {
+
+ override protected def validateAndTransformSchema(
+ schema: StructType,
+ fitting: Boolean,
+ featuresDataType: DataType): StructType = {
+ var outputSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType)
+ if ($(leafCol).nonEmpty) {
+ outputSchema = SchemaUtils.appendColumn(outputSchema, $(leafCol), new VectorUDT)
+ }
+ outputSchema
+ }
+}
+
+/**
+ * Parameters for Decision Tree-based ensemble regression algorithms.
+ */
+private[ml] trait TreeEnsembleRegressorParams
+ extends TreeEnsembleParams {
+
+ override protected def validateAndTransformSchema(
+ schema: StructType,
+ fitting: Boolean,
+ featuresDataType: DataType): StructType = {
+ var outputSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType)
+ if ($(leafCol).nonEmpty) {
+ outputSchema = SchemaUtils.appendColumn(outputSchema, $(leafCol), new VectorUDT)
+ }
+ outputSchema
+ }
+}
+
+/**
+ * Parameters for Random Forest algorithms.
+ */
+private[ml] trait RandomForestParams extends TreeEnsembleParams {
+
+ /**
+ * Number of trees to train (at least 1).
+ * If 1, then no bootstrapping is used. If greater than 1, then bootstrapping is done.
+ * TODO: Change to always do bootstrapping (simpler). SPARK-7130
+ * (default = 20)
+ *
+ * Note: The reason that we cannot add this to both GBT and RF (i.e. in TreeEnsembleParams)
+ * is the param `maxIter` controls how many trees a GBT has. The semantics in the algorithms
+ * are a bit different.
+ * @group param
+ */
+ final val numTrees: IntParam =
+ new IntParam(this, "numTrees", "Number of trees to train (at least 1)",
+ ParamValidators.gtEq(1))
+
+ /** @group getParam */
+ final def getNumTrees: Int = $(numTrees)
+
+ /**
+ * Whether bootstrap samples are used when building trees.
+ * @group expertParam
+ */
+ @Since("3.0.0")
+ final val bootstrap: BooleanParam = new BooleanParam(this, "bootstrap",
+ "Whether bootstrap samples are used when building trees.")
+
+ /** @group getParam */
+ @Since("3.0.0")
+ final def getBootstrap: Boolean = $(bootstrap)
+
+ setDefault(numTrees -> 20, bootstrap -> true)
+}
+
+private[ml] trait RandomForestClassifierParams
+ extends RandomForestParams with TreeEnsembleClassifierParams with TreeClassifierParams
+
+private[ml] trait RandomForestRegressorParams
+ extends RandomForestParams with TreeEnsembleRegressorParams with TreeRegressorParams
+
+/**
+ * Parameters for Gradient-Boosted Tree algorithms.
+ *
+ * Note: Marked as private since this may be made public in the future.
+ */
+private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasStepSize
+ with HasValidationIndicatorCol {
+
+ /**
+ * Threshold for stopping early when fit with validation is used.
+ * (This parameter is ignored when fit without validation is used.)
+ * The decision to stop early is decided based on this logic:
+ * If the current loss on the validation set is greater than 0.01, the diff
+ * of validation error is compared to relative tolerance which is
+ * validationTol * (current loss on the validation set).
+ * If the current loss on the validation set is less than or equal to 0.01,
+ * the diff of validation error is compared to absolute tolerance which is
+ * validationTol * 0.01.
+ * @group param
+ * @see validationIndicatorCol
+ */
+ @Since("2.4.0")
+ final val validationTol: DoubleParam = new DoubleParam(this, "validationTol",
+ "Threshold for stopping early when fit with validation is used." +
+ "If the error rate on the validation input changes by less than the validationTol," +
+ "then learning will stop early (before `maxIter`)." +
+ "This parameter is ignored when fit without validation is used.",
+ ParamValidators.gtEq(0.0)
+ )
+
+ /** @group getParam */
+ @Since("2.4.0")
+ final def getValidationTol: Double = $(validationTol)
+
+ /**
+ * Param for Step size (a.k.a. learning rate) in interval (0, 1] for shrinking
+ * the contribution of each estimator.
+ * (default = 0.1)
+ * @group param
+ */
+ final override val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size " +
+ "(a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each estimator.",
+ ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))
+
+ setDefault(maxIter -> 20, stepSize -> 0.1, validationTol -> 0.01, featureSubsetStrategy -> "all")
+
+ /** (private[ml]) Create a BoostingStrategy instance to use with the old API. */
+ private[ml] def getOldBoostingStrategy(
+ categoricalFeatures: Map[Int, Int],
+ oldAlgo: OldAlgo.Algo): OldBoostingStrategy = {
+ val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2, oldAlgo, OldVariance)
+ // NOTE: The old API does not support "seed" so we ignore it.
+ new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize, getValidationTol)
+ }
+
+ /** Get old Gradient Boosting Loss type */
+ private[ml] def getOldLossType: OldLoss
+
+ final val doUseAcc: BooleanParam = new BooleanParam(this, "doUseAcc",
+ "If true, use the optimized algorithm; otherwise, use the raw version")
+
+ var setUseAccFlag = false
+
+ /** Set algorithm to the raw version. */
+ def setDoUseAcc(value: Boolean): this.type = {
+ setUseAccFlag = true
+ set(doUseAcc, value)
+ }
+ setDefault(doUseAcc -> true)
+
+ /** Get algorithm type. */
+ def getDoUseAcc: (Boolean, Boolean) = ($(doUseAcc), setUseAccFlag)
+}
+
+private[ml] object GBTClassifierParams {
+ // The losses below should be lowercase.
+ /** Accessor for supported loss settings: logistic */
+ final val supportedLossTypes: Array[String] =
+ Array("logistic").map(_.toLowerCase(Locale.ROOT))
+}
+
+private[ml] trait GBTClassifierParams
+ extends GBTParams with TreeEnsembleClassifierParams with HasVarianceImpurity {
+
+ /**
+ * Loss function which GBT tries to minimize. (case-insensitive)
+ * Supported: "logistic"
+ * (default = logistic)
+ * @group param
+ */
+ val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
+ " tries to minimize (case-insensitive). Supported options:" +
+ s" ${GBTClassifierParams.supportedLossTypes.mkString(", ")}",
+ (value: String) =>
+ GBTClassifierParams.supportedLossTypes.contains(value.toLowerCase(Locale.ROOT)))
+
+ setDefault(lossType -> "logistic")
+
+ /** @group getParam */
+ def getLossType: String = $(lossType).toLowerCase(Locale.ROOT)
+
+ /** (private[ml]) Convert new loss to old loss. */
+ override private[ml] def getOldLossType: OldClassificationLoss = {
+ getLossType match {
+ case "logistic" => OldLogLoss
+ case _ =>
+ // Should never happen because of check in setter method.
+ throw new RuntimeException(s"GBTClassifier was given bad loss type: $getLossType")
+ }
+ }
+}
+
+private[ml] object GBTRegressorParams {
+ // The losses below should be lowercase.
+ /** Accessor for supported loss settings: squared (L2), absolute (L1) */
+ final val supportedLossTypes: Array[String] =
+ Array("squared", "absolute").map(_.toLowerCase(Locale.ROOT))
+}
+
+private[ml] trait GBTRegressorParams
+ extends GBTParams with TreeEnsembleRegressorParams with TreeRegressorParams {
+
+ /**
+ * Loss function which GBT tries to minimize. (case-insensitive)
+ * Supported: "squared" (L2) and "absolute" (L1)
+ * (default = squared)
+ * @group param
+ */
+ val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
+ " tries to minimize (case-insensitive). Supported options:" +
+ s" ${GBTRegressorParams.supportedLossTypes.mkString(", ")}",
+ (value: String) =>
+ GBTRegressorParams.supportedLossTypes.contains(value.toLowerCase(Locale.ROOT)))
+
+ setDefault(lossType -> "squared")
+
+ /** @group getParam */
+ def getLossType: String = $(lossType).toLowerCase(Locale.ROOT)
+
+ /** (private[ml]) Convert new loss to old loss. */
+ override private[ml] def getOldLossType: OldLoss = {
+ convertToOldLossType(getLossType)
+ }
+
+ private[ml] def convertToOldLossType(loss: String): OldLoss = {
+ loss match {
+ case "squared" => OldSquaredError
+ case "absolute" => OldAbsoluteError
+ case _ =>
+ // Should never happen because of check in setter method.
+ throw new RuntimeException(s"GBTRegressorParams was given bad loss type: $getLossType")
+ }
+ }
+}
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/ml-accelerator/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
new file mode 100644
index 0000000..52555a2
--- /dev/null
+++ b/ml-accelerator/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -0,0 +1,666 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.feature
+
+import java.lang.{Iterable => JavaIterable}
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+import com.github.fommil.netlib.BLAS.{getInstance => blas}
+import org.json4s.DefaultFormats
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.annotation.Since
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.internal.Logging
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.util.{Loader, Saveable}
+import org.apache.spark.rdd._
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.util.BoundedPriorityQueue
+import org.apache.spark.util.Utils
+
+/**
+ * Entry in vocabulary
+ */
+private case class VocabWord(
+ var word: String,
+ var cn: Int,
+ var point: Array[Int],
+ var code: Array[Int],
+ var codeLen: Int
+)
+
+/**
+ * Word2Vec creates vector representation of words in a text corpus.
+ * The algorithm first constructs a vocabulary from the corpus
+ * and then learns vector representation of words in the vocabulary.
+ * The vector representation can be used as features in
+ * natural language processing and machine learning algorithms.
+ *
+ * We used skip-gram model in our implementation and hierarchical softmax
+ * method to train the model. The variable names in the implementation
+ * matches the original C implementation.
+ *
+ * For original C implementation, see https://code.google.com/p/word2vec/
+ * For research papers, see
+ * Efficient Estimation of Word Representations in Vector Space
+ * and
+ * Distributed Representations of Words and Phrases and their Compositionality.
+ */
+@Since("1.1.0")
+class Word2Vec extends Serializable with Logging {
+
+ private var vectorSize = 100
+ private var learningRate = 0.025
+ private var regularization = 0.05f
+ private var repetition = 3
+ private var numPartitions = 1
+ private var numIterations = 1
+ private var seed = Utils.random.nextLong()
+ private var minCount = 5
+ private var maxSentenceLength = 1000
+
+ /**
+ * Sets the regularization coefficient.
+ */
+ def setRegularization(regularization: Float): this.type = {
+ this.regularization = regularization
+ this
+ }
+
+ /**
+ * Sets the number of repetitions of data.
+ */
+ def setRepetition(repetition: Int): this.type = {
+ require(repetition >= 0,
+ s"The number of repetitions of data must not be smaller than 0 but got ${repetition}")
+ this.repetition = repetition
+ this
+ }
+
+ /**
+ * Sets the maximum length (in words) of each sentence in the input data.
+ * Any sentence longer than this threshold will be divided into chunks of
+ * up to `maxSentenceLength` size (default: 1000)
+ */
+ @Since("2.0.0")
+ def setMaxSentenceLength(maxSentenceLength: Int): this.type = {
+ require(maxSentenceLength > 0,
+ s"Maximum length of sentences must be positive but got ${maxSentenceLength}")
+ this.maxSentenceLength = maxSentenceLength
+ this
+ }
+
+ /**
+ * Sets vector size (default: 100).
+ */
+ @Since("1.1.0")
+ def setVectorSize(vectorSize: Int): this.type = {
+ require(vectorSize > 0,
+ s"vector size must be positive but got ${vectorSize}")
+ this.vectorSize = vectorSize
+ this
+ }
+
+ /**
+ * Sets initial learning rate (default: 0.025).
+ */
+ @Since("1.1.0")
+ def setLearningRate(learningRate: Double): this.type = {
+ require(learningRate > 0,
+ s"Initial learning rate must be positive but got ${learningRate}")
+ this.learningRate = learningRate
+ this
+ }
+
+ /**
+ * Sets number of partitions (default: 1). Use a small number for accuracy.
+ */
+ @Since("1.1.0")
+ def setNumPartitions(numPartitions: Int): this.type = {
+ require(numPartitions > 0,
+ s"Number of partitions must be positive but got ${numPartitions}")
+ this.numPartitions = numPartitions
+ this
+ }
+
+ /**
+ * Sets number of iterations (default: 1), which should be smaller than or equal to number of
+ * partitions.
+ */
+ @Since("1.1.0")
+ def setNumIterations(numIterations: Int): this.type = {
+ require(numIterations >= 0,
+ s"Number of iterations must be nonnegative but got ${numIterations}")
+ this.numIterations = numIterations
+ this
+ }
+
+ /**
+ * Sets random seed (default: a random long integer).
+ */
+ @Since("1.1.0")
+ def setSeed(seed: Long): this.type = {
+ this.seed = seed
+ this
+ }
+
+ /**
+ * Sets the window of words (default: 5)
+ */
+ @Since("1.6.0")
+ def setWindowSize(window: Int): this.type = {
+ require(window > 0,
+ s"Window of words must be positive but got ${window}")
+ this.window = window
+ this
+ }
+
+ /**
+ * Sets minCount, the minimum number of times a token must appear to be included in the word2vec
+ * model's vocabulary (default: 5).
+ */
+ @Since("1.3.0")
+ def setMinCount(minCount: Int): this.type = {
+ require(minCount >= 0,
+ s"Minimum number of times must be nonnegative but got ${minCount}")
+ this.minCount = minCount
+ this
+ }
+
+ /**
+ * Set parameters of algorithm.
+ */
+ private def setParams(conf: SparkConf): Unit = {
+ setRegularization(Word2Vec.getSparkConfFloatValue(
+ conf, "spark.boostkit.mllib.feature.word2vec.regularization", 0.05f))
+ setRepetition(Word2Vec.getSparkConfIntValue(
+ conf, "spark.boostkit.mllib.feature.word2vec.repetition", 3))
+ }
+
+ private val EXP_TABLE_SIZE = 1000
+ private val MAX_EXP = 6
+ private val MAX_CODE_LENGTH = 40
+
+ /** context words from [-window, window] */
+ private var window = 5
+
+ private var trainWordsCount = 0L
+ private var vocabSize = 0
+ @transient private var vocab: Array[VocabWord] = null
+ @transient private var vocabHash = mutable.HashMap.empty[String, Int]
+
+ private def learnVocab[S <: Iterable[String]](dataset: RDD[S]): Unit = {
+ val words = dataset.flatMap(x => x)
+
+ vocab = words.map(w => (w, 1))
+ .reduceByKey(_ + _)
+ .filter(_._2 >= minCount)
+ .map(x => VocabWord(
+ x._1,
+ x._2,
+ new Array[Int](MAX_CODE_LENGTH),
+ new Array[Int](MAX_CODE_LENGTH),
+ 0))
+ .collect()
+ .sortWith((a, b) => a.cn > b.cn)
+
+ vocabSize = vocab.length
+ require(vocabSize > 0, "The vocabulary size should be > 0. You may need to check " +
+ "the setting of minCount, which could be large enough to remove all your words in sentences.")
+
+ var a = 0
+ while (a < vocabSize) {
+ vocabHash += vocab(a).word -> a
+ trainWordsCount += vocab(a).cn
+ a += 1
+ }
+ logInfo(s"vocabSize = $vocabSize, trainWordsCount = $trainWordsCount")
+ }
+
+ private def createExpTable(): Array[Float] = {
+ val expTable = new Array[Float](EXP_TABLE_SIZE)
+ var i = 0
+ while (i < EXP_TABLE_SIZE) {
+ val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP)
+ expTable(i) = (tmp / (tmp + 1.0)).toFloat
+ i += 1
+ }
+ expTable
+ }
+
+ private def createBinaryTree(): Unit = {
+ val count = new Array[Long](vocabSize * 2 + 1)
+ val binary = new Array[Int](vocabSize * 2 + 1)
+ val parentNode = new Array[Int](vocabSize * 2 + 1)
+ val code = new Array[Int](MAX_CODE_LENGTH)
+ val point = new Array[Int](MAX_CODE_LENGTH)
+ var a = 0
+ while (a < vocabSize) {
+ count(a) = vocab(a).cn
+ a += 1
+ }
+ while (a < 2 * vocabSize) {
+ count(a) = 1e9.toInt
+ a += 1
+ }
+ var pos1 = vocabSize - 1
+ var pos2 = vocabSize
+
+ var min1i = 0
+ var min2i = 0
+
+ a = 0
+ while (a < vocabSize - 1) {
+ if (pos1 >= 0) {
+ if (count(pos1) < count(pos2)) {
+ min1i = pos1
+ pos1 -= 1
+ } else {
+ min1i = pos2
+ pos2 += 1
+ }
+ } else {
+ min1i = pos2
+ pos2 += 1
+ }
+ if (pos1 >= 0) {
+ if (count(pos1) < count(pos2)) {
+ min2i = pos1
+ pos1 -= 1
+ } else {
+ min2i = pos2
+ pos2 += 1
+ }
+ } else {
+ min2i = pos2
+ pos2 += 1
+ }
+ count(vocabSize + a) = count(min1i) + count(min2i)
+ parentNode(min1i) = vocabSize + a
+ parentNode(min2i) = vocabSize + a
+ binary(min2i) = 1
+ a += 1
+ }
+ // Now assign binary code to each vocabulary word
+ var i = 0
+ a = 0
+ while (a < vocabSize) {
+ var b = a
+ i = 0
+ while (b != vocabSize * 2 - 2) {
+ code(i) = binary(b)
+ point(i) = b
+ i += 1
+ b = parentNode(b)
+ }
+ vocab(a).codeLen = i
+ vocab(a).point(0) = vocabSize - 2
+ b = 0
+ while (b < i) {
+ vocab(a).code(i - b - 1) = code(b)
+ vocab(a).point(i - b) = point(b) - vocabSize
+ b += 1
+ }
+ a += 1
+ }
+ }
+
+ /**
+ * Computes the vector representation of each word in vocabulary.
+ * @param dataset an RDD of sentences,
+ * each sentence is expressed as an iterable collection of words
+ * @return a Word2VecModel
+ */
+ @Since("1.1.0")
+ def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = {
+
+ learnVocab(dataset)
+
+ createBinaryTree()
+
+ val sc = dataset.context
+
+ val expTable = sc.broadcast(createExpTable())
+ val bcVocab = sc.broadcast(vocab)
+ val bcVocabHash = sc.broadcast(vocabHash)
+ try {
+ setParams(dataset.sparkContext.conf)
+
+ val modelSGHS = new Word2VecSGHS(minCount, window, vectorSize, vocabSize, trainWordsCount,
+ learningRate, numIterations, seed, maxSentenceLength, regularization, repetition)
+ val vectors = modelSGHS.fit(dataset, expTable, bcVocab, bcVocabHash)
+
+ new Word2VecModel(vocabHash.toMap, vectors)
+ } finally {
+ expTable.destroy(blocking = false)
+ bcVocab.destroy(blocking = false)
+ bcVocabHash.destroy(blocking = false)
+ }
+ }
+
+ /**
+ * Computes the vector representation of each word in vocabulary (Java version).
+ * @param dataset a JavaRDD of words
+ * @return a Word2VecModel
+ */
+ @Since("1.1.0")
+ def fit[S <: JavaIterable[String]](dataset: JavaRDD[S]): Word2VecModel = {
+ fit(dataset.rdd.map(_.asScala))
+ }
+}
+
+/**
+ * Word2Vec model
+ * @param wordIndex maps each word to an index, which can retrieve the corresponding
+ * vector from wordVectors
+ * @param wordVectors array of length numWords * vectorSize, vector corresponding
+ * to the word mapped with index i can be retrieved by the slice
+ * (i * vectorSize, i * vectorSize + vectorSize)
+ */
+@Since("1.1.0")
+class Word2VecModel private[spark] (
+ private[spark] val wordIndex: Map[String, Int],
+ private[spark] val wordVectors: Array[Float]) extends Serializable with Saveable {
+
+ private val numWords = wordIndex.size
+ // vectorSize: Dimension of each word's vector.
+ private val vectorSize = wordVectors.length / numWords
+
+ // wordList: Ordered list of words obtained from wordIndex.
+ private val wordList: Array[String] = {
+ val (wl, _) = wordIndex.toSeq.sortBy(_._2).unzip
+ wl.toArray
+ }
+
+ // wordVecNorms: Array of length numWords, each value being the Euclidean norm
+ // of the wordVector.
+ private val wordVecNorms: Array[Float] = {
+ val wordVecNorms = new Array[Float](numWords)
+ var i = 0
+ while (i < numWords) {
+ val vec = wordVectors.slice(i * vectorSize, i * vectorSize + vectorSize)
+ wordVecNorms(i) = blas.snrm2(vectorSize, vec, 1)
+ i += 1
+ }
+ wordVecNorms
+ }
+
+ @Since("1.5.0")
+ def this(model: Map[String, Array[Float]]) = {
+ this(Word2VecModel.buildWordIndex(model), Word2VecModel.buildWordVectors(model))
+ }
+
+ @Since("1.4.0")
+ def save(sc: SparkContext, path: String): Unit = {
+ Word2VecModel.SaveLoadV1_0.save(sc, path, getVectors)
+ }
+
+ /**
+ * Transforms a word to its vector representation
+ * @param word a word
+ * @return vector representation of word
+ */
+ @Since("1.1.0")
+ def transform(word: String): Vector = {
+ wordIndex.get(word) match {
+ case Some(ind) =>
+ val vec = wordVectors.slice(ind * vectorSize, ind * vectorSize + vectorSize)
+ Vectors.dense(vec.map(_.toDouble))
+ case None =>
+ throw new IllegalStateException(s"$word not in vocabulary")
+ }
+ }
+
+ /**
+ * Find synonyms of a word; do not include the word itself in results.
+ * @param word a word
+ * @param num number of synonyms to find
+ * @return array of (word, cosineSimilarity)
+ */
+ @Since("1.1.0")
+ def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
+ val vector = transform(word)
+ findSynonyms(vector, num, Some(word))
+ }
+
+ /**
+ * Find synonyms of the vector representation of a word, possibly
+ * including any words in the model vocabulary whose vector respresentation
+ * is the supplied vector.
+ * @param vector vector representation of a word
+ * @param num number of synonyms to find
+ * @return array of (word, cosineSimilarity)
+ */
+ @Since("1.1.0")
+ def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
+ findSynonyms(vector, num, None)
+ }
+
+ /**
+ * Find synonyms of the vector representation of a word, rejecting
+ * words identical to the value of wordOpt, if one is supplied.
+ * @param vector vector representation of a word
+ * @param num number of synonyms to find
+ * @param wordOpt optionally, a word to reject from the results list
+ * @return array of (word, cosineSimilarity)
+ */
+ private def findSynonyms(
+ vector: Vector,
+ num: Int,
+ wordOpt: Option[String]): Array[(String, Double)] = {
+ require(num > 0, "Number of similar words should > 0")
+
+ val fVector = vector.toArray.map(_.toFloat)
+ val cosineVec = new Array[Float](numWords)
+ val alpha: Float = 1
+ val beta: Float = 0
+ // Normalize input vector before blas.sgemv to avoid Inf value
+ val vecNorm = blas.snrm2(vectorSize, fVector, 1)
+ if (vecNorm != 0.0f) {
+ blas.sscal(vectorSize, 1 / vecNorm, fVector, 0, 1)
+ }
+ blas.sgemv(
+ "T", vectorSize, numWords, alpha, wordVectors, vectorSize, fVector, 1, beta, cosineVec, 1)
+
+ var i = 0
+ while (i < numWords) {
+ val norm = wordVecNorms(i)
+ if (norm == 0.0f) {
+ cosineVec(i) = 0.0f
+ } else {
+ cosineVec(i) /= norm
+ }
+ i += 1
+ }
+
+ val pq = new BoundedPriorityQueue[(String, Float)](num + 1)(Ordering.by(_._2))
+
+ var j = 0
+ while (j < numWords) {
+ pq += Tuple2(wordList(j), cosineVec(j))
+ j += 1
+ }
+
+ val scored = pq.toSeq.sortBy(-_._2)
+
+ val filtered = wordOpt match {
+ case Some(w) => scored.filter(tup => w != tup._1)
+ case None => scored
+ }
+
+ filtered
+ .take(num)
+ .map { case (word, score) => (word, score.toDouble) }
+ .toArray
+ }
+
+ /**
+ * Returns a map of words to their vector representations.
+ */
+ @Since("1.2.0")
+ def getVectors: Map[String, Array[Float]] = {
+ wordIndex.map { case (word, ind) =>
+ (word, wordVectors.slice(vectorSize * ind, vectorSize * ind + vectorSize))
+ }
+ }
+
+}
+
+object Word2Vec {
+ /**
+ * Get the float value from spark conf.
+ */
+ def getSparkConfFloatValue(
+ conf: SparkConf,
+ valueName: String,
+ defaultValue: Float): Float = {
+ var value = defaultValue
+ val valueOption = conf.getOption(valueName)
+
+ if(valueOption.nonEmpty) {
+ try {
+ value = valueOption.get.toFloat
+ } catch {
+ case _: Exception =>
+ throw new IllegalArgumentException(s"$valueName should be Float, " +
+ s"but was ${valueOption.get}")
+ }
+ }
+
+ value
+ }
+
+ /**
+ * Get the int value from spark conf.
+ */
+ def getSparkConfIntValue(
+ conf: SparkConf,
+ valueName: String,
+ defaultValue: Int): Int = {
+ var value = defaultValue
+ val valueOption = conf.getOption(valueName)
+
+ if(valueOption.nonEmpty) {
+ try {
+ value = valueOption.get.toInt
+ } catch {
+ case _: Exception =>
+ throw new IllegalArgumentException(s"$valueName should be Int and greater than 0, " +
+ s"but was ${valueOption.get}")
+ }
+ }
+
+ value
+ }
+}
+
+@Since("1.4.0")
+object Word2VecModel extends Loader[Word2VecModel] {
+
+ private def buildWordIndex(model: Map[String, Array[Float]]): Map[String, Int] = {
+ model.keys.zipWithIndex.toMap
+ }
+
+ private def buildWordVectors(model: Map[String, Array[Float]]): Array[Float] = {
+ require(model.nonEmpty, "Word2VecMap should be non-empty")
+ val (vectorSize, numWords) = (model.head._2.length, model.size)
+ val wordList = model.keys.toArray
+ val wordVectors = new Array[Float](vectorSize * numWords)
+ var i = 0
+ while (i < numWords) {
+ Array.copy(model(wordList(i)), 0, wordVectors, i * vectorSize, vectorSize)
+ i += 1
+ }
+ wordVectors
+ }
+
+ private object SaveLoadV1_0 {
+
+ val formatVersionV1_0 = "1.0"
+
+ val classNameV1_0 = "org.apache.spark.mllib.feature.Word2VecModel"
+
+ case class Data(word: String, vector: Array[Float])
+
+ def load(sc: SparkContext, path: String): Word2VecModel = {
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
+ val dataFrame = spark.read.parquet(Loader.dataPath(path))
+ // Check schema explicitly since erasure makes it hard to use match-case for checking.
+ Loader.checkSchema[Data](dataFrame.schema)
+
+ val dataArray = dataFrame.select("word", "vector").collect()
+ val word2VecMap = dataArray.map(i => (i.getString(0), i.getSeq[Float](1).toArray)).toMap
+ new Word2VecModel(word2VecMap)
+ }
+
+ def save(sc: SparkContext, path: String, model: Map[String, Array[Float]]): Unit = {
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
+
+ val vectorSize = model.values.head.length
+ val numWords = model.size
+ val metadata = compact(render(
+ ("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~
+ ("vectorSize" -> vectorSize) ~ ("numWords" -> numWords)))
+ sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+ // We want to partition the model in partitions smaller than
+ // spark.kryoserializer.buffer.max
+ val bufferSize = Utils.byteStringAsBytes(
+ spark.conf.get("spark.kryoserializer.buffer.max", "64m"))
+ // We calculate the approximate size of the model
+ // We only calculate the array size, considering an
+ // average string size of 15 bytes, the formula is:
+ // (floatSize * vectorSize + 15) * numWords
+ val approxSize = (4L * vectorSize + 15) * numWords
+ val nPartitions = ((approxSize / bufferSize) + 1).toInt
+ val dataArray = model.toSeq.map { case (w, v) => Data(w, v) }
+ spark.createDataFrame(dataArray).repartition(nPartitions).write.parquet(Loader.dataPath(path))
+ }
+ }
+
+ @Since("1.4.0")
+ override def load(sc: SparkContext, path: String): Word2VecModel = {
+
+ val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path)
+ implicit val formats = DefaultFormats
+ val expectedVectorSize = (metadata \ "vectorSize").extract[Int]
+ val expectedNumWords = (metadata \ "numWords").extract[Int]
+ val classNameV1_0 = SaveLoadV1_0.classNameV1_0
+ (loadedClassName, loadedVersion) match {
+ case (classNameV1_0, "1.0") =>
+ val model = SaveLoadV1_0.load(sc, path)
+ val vectorSize = model.getVectors.values.head.length
+ val numWords = model.getVectors.size
+ require(expectedVectorSize == vectorSize,
+ s"Word2VecModel requires each word to be mapped to a vector of size " +
+ s"$expectedVectorSize, got vector of size $vectorSize")
+ require(expectedNumWords == numWords,
+ s"Word2VecModel requires $expectedNumWords words, but got $numWords")
+ model
+ case _ => throw new Exception(
+ s"Word2VecModel.load did not recognize model with (className, format version):" +
+ s"($loadedClassName, $loadedVersion). Supported:\n" +
+ s" ($classNameV1_0, 1.0)")
+ }
+ }
+}
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/ml-accelerator/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
new file mode 100644
index 0000000..55685db
--- /dev/null
+++ b/ml-accelerator/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.annotation.Since
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.tree.impl.{DecisionForest, DTUtils}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
+import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.impurity._
+import org.apache.spark.mllib.tree.model._
+import org.apache.spark.rdd.RDD
+
+
+/**
+ * A class which implements a decision tree learning algorithm for classification and regression.
+ * It supports both continuous and categorical features.
+ *
+ * @param strategy The configuration parameters for the tree algorithm which specify the type
+ * of decision tree (classification or regression), feature type (continuous,
+ * categorical), depth of the tree, quantile calculation strategy, etc.
+ * @param seed Random seed.
+ */
+@Since("1.0.0")
+class DecisionTree private[spark] (private val strategy: Strategy, private val seed: Int)
+ extends Serializable with Logging {
+
+ /**
+ * @param strategy The configuration parameters for the tree algorithm which specify the type
+ * of decision tree (classification or regression), feature type (continuous,
+ * categorical), depth of the tree, quantile calculation strategy, etc.
+ */
+ @Since("1.0.0")
+ def this(strategy: Strategy) = this(strategy, seed = 0)
+
+ strategy.assertValid()
+
+ /**
+ * Method to train a decision tree model over an RDD
+ *
+ * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * @return DecisionTreeModel that can be used for prediction.
+ */
+ @Since("1.2.0")
+ def run(input: RDD[LabeledPoint]): DecisionTreeModel = {
+ val instances = input.map { case LabeledPoint(label, features) =>
+ Instance(label, 1.0, features.asML)
+ }
+ val trees = DecisionForest.run(instances, strategy, numTrees = 1,
+ featureSubsetStrategy = "all", seed = seed.toLong, None)
+ val rfModel = new RandomForestModel(strategy.algo, trees.map(_.toOld))
+ rfModel.trees(0)
+ }
+}
+
+@Since("1.0.0")
+object DecisionTree extends Serializable with Logging {
+
+ /**
+ * Method to train a decision tree model.
+ * The method supports binary and multiclass classification and regression.
+ *
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * For classification, labels should take values {0, 1, ..., numClasses-1}.
+ * For regression, labels are real numbers.
+ * @param strategy The configuration parameters for the tree algorithm which specify the type
+ * of decision tree (classification or regression), feature type (continuous,
+ * categorical), depth of the tree, quantile calculation strategy, etc.
+ * @return DecisionTreeModel that can be used for prediction.
+ *
+ * @note Using `org.apache.spark.mllib.tree.DecisionTree.trainClassifier`
+ * and `org.apache.spark.mllib.tree.DecisionTree.trainRegressor`
+ * is recommended to clearly separate classification and regression.
+ */
+ @Since("1.0.0")
+ def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = {
+ new DecisionTree(strategy).run(input)
+ }
+
+ /**
+ * Method to train a decision tree model.
+ * The method supports binary and multiclass classification and regression.
+ *
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * For classification, labels should take values {0, 1, ..., numClasses-1}.
+ * For regression, labels are real numbers.
+ * @param algo Type of decision tree, either classification or regression.
+ * @param impurity Criterion used for information gain calculation.
+ * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means
+ * 1 internal node + 2 leaf nodes).
+ * @return DecisionTreeModel that can be used for prediction.
+ *
+ * @note Using `org.apache.spark.mllib.tree.DecisionTree.trainClassifier`
+ * and `org.apache.spark.mllib.tree.DecisionTree.trainRegressor`
+ * is recommended to clearly separate classification and regression.
+ */
+ @Since("1.0.0")
+ def train(
+ input: RDD[LabeledPoint],
+ algo: Algo,
+ impurity: Impurity,
+ maxDepth: Int): DecisionTreeModel = {
+ val (_, maxMemInMB) = DTUtils.getInvisibleParamsForMLLib(input)
+ val strategy =
+ new Strategy(algo, impurity, maxDepth, maxMemoryInMB = maxMemInMB)
+ new DecisionTree(strategy).run(input)
+ }
+
+ /**
+ * Method to train a decision tree model.
+ * The method supports binary and multiclass classification and regression.
+ *
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * For classification, labels should take values {0, 1, ..., numClasses-1}.
+ * For regression, labels are real numbers.
+ * @param algo Type of decision tree, either classification or regression.
+ * @param impurity Criterion used for information gain calculation.
+ * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means
+ * 1 internal node + 2 leaf nodes).
+ * @param numClasses Number of classes for classification. Default value of 2.
+ * @return DecisionTreeModel that can be used for prediction.
+ *
+ * @note Using `org.apache.spark.mllib.tree.DecisionTree.trainClassifier`
+ * and `org.apache.spark.mllib.tree.DecisionTree.trainRegressor`
+ * is recommended to clearly separate classification and regression.
+ */
+ @Since("1.2.0")
+ def train(
+ input: RDD[LabeledPoint],
+ algo: Algo,
+ impurity: Impurity,
+ maxDepth: Int,
+ numClasses: Int): DecisionTreeModel = {
+ val (_, maxMemInMB) = DTUtils.getInvisibleParamsForMLLib(input)
+ val strategy = new Strategy(algo, impurity, maxDepth, numClasses,
+ maxMemoryInMB = maxMemInMB)
+ new DecisionTree(strategy).run(input)
+ }
+
+ /**
+ * Method to train a decision tree model.
+ * The method supports binary and multiclass classification and regression.
+ *
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * For classification, labels should take values {0, 1, ..., numClasses-1}.
+ * For regression, labels are real numbers.
+ * @param algo Type of decision tree, either classification or regression.
+ * @param impurity Criterion used for information gain calculation.
+ * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means
+ * 1 internal node + 2 leaf nodes).
+ * @param numClasses Number of classes for classification. Default value of 2.
+ * @param maxBins Maximum number of bins used for splitting features.
+ * @param quantileCalculationStrategy Algorithm for calculating quantiles.
+ * @param categoricalFeaturesInfo Map storing arity of categorical features. An entry (n to k)
+ * indicates that feature n is categorical with k categories
+ * indexed from 0: {0, 1, ..., k-1}.
+ * @return DecisionTreeModel that can be used for prediction.
+ *
+ * @note Using `org.apache.spark.mllib.tree.DecisionTree.trainClassifier`
+ * and `org.apache.spark.mllib.tree.DecisionTree.trainRegressor`
+ * is recommended to clearly separate classification and regression.
+ */
+ @Since("1.0.0")
+ def train(
+ input: RDD[LabeledPoint],
+ algo: Algo,
+ impurity: Impurity,
+ maxDepth: Int,
+ numClasses: Int,
+ maxBins: Int,
+ quantileCalculationStrategy: QuantileStrategy,
+ categoricalFeaturesInfo: Map[Int, Int]): DecisionTreeModel = {
+ val (_, maxMemInMB) = DTUtils.getInvisibleParamsForMLLib(input)
+ val strategy = new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
+ quantileCalculationStrategy, categoricalFeaturesInfo, maxMemoryInMB = maxMemInMB)
+ new DecisionTree(strategy).run(input)
+ }
+
+ /**
+ * Method to train a decision tree model for binary or multiclass classification.
+ *
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * Labels should take values {0, 1, ..., numClasses-1}.
+ * @param numClasses Number of classes for classification.
+ * @param categoricalFeaturesInfo Map storing arity of categorical features. An entry (n to k)
+ * indicates that feature n is categorical with k categories
+ * indexed from 0: {0, 1, ..., k-1}.
+ * @param impurity Criterion used for information gain calculation.
+ * Supported values: "gini" (recommended) or "entropy".
+ * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means
+ * 1 internal node + 2 leaf nodes).
+ * (suggested value: 5)
+ * @param maxBins Maximum number of bins used for splitting features.
+ * (suggested value: 32)
+ * @return DecisionTreeModel that can be used for prediction.
+ */
+ @Since("1.1.0")
+ def trainClassifier(
+ input: RDD[LabeledPoint],
+ numClasses: Int,
+ categoricalFeaturesInfo: Map[Int, Int],
+ impurity: String,
+ maxDepth: Int,
+ maxBins: Int): DecisionTreeModel = {
+ val impurityType = Impurities.fromString(impurity)
+ train(input, Classification, impurityType, maxDepth, numClasses, maxBins, Sort,
+ categoricalFeaturesInfo)
+ }
+
+ /**
+ * Java-friendly API for `org.apache.spark.mllib.tree.DecisionTree.trainClassifier`
+ */
+ @Since("1.1.0")
+ def trainClassifier(
+ input: JavaRDD[LabeledPoint],
+ numClasses: Int,
+ categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
+ impurity: String,
+ maxDepth: Int,
+ maxBins: Int): DecisionTreeModel = {
+ trainClassifier(input.rdd, numClasses,
+ categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
+ impurity, maxDepth, maxBins)
+ }
+
+ /**
+ * Method to train a decision tree model for regression.
+ *
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * Labels are real numbers.
+ * @param categoricalFeaturesInfo Map storing arity of categorical features. An entry (n to k)
+ * indicates that feature n is categorical with k categories
+ * indexed from 0: {0, 1, ..., k-1}.
+ * @param impurity Criterion used for information gain calculation.
+ * The only supported value for regression is "variance".
+ * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means
+ * 1 internal node + 2 leaf nodes).
+ * (suggested value: 5)
+ * @param maxBins Maximum number of bins used for splitting features.
+ * (suggested value: 32)
+ * @return DecisionTreeModel that can be used for prediction.
+ */
+ @Since("1.1.0")
+ def trainRegressor(
+ input: RDD[LabeledPoint],
+ categoricalFeaturesInfo: Map[Int, Int],
+ impurity: String,
+ maxDepth: Int,
+ maxBins: Int): DecisionTreeModel = {
+ val impurityType = Impurities.fromString(impurity)
+ train(input, Regression, impurityType, maxDepth, 0, maxBins, Sort, categoricalFeaturesInfo)
+ }
+
+ /**
+ * Java-friendly API for `org.apache.spark.mllib.tree.DecisionTree.trainRegressor`
+ */
+ @Since("1.1.0")
+ def trainRegressor(
+ input: JavaRDD[LabeledPoint],
+ categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
+ impurity: String,
+ maxDepth: Int,
+ maxBins: Int): DecisionTreeModel = {
+ trainRegressor(input.rdd,
+ categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
+ impurity, maxDepth, maxBins)
+ }
+}
diff --git a/ml-core/pom.xml b/ml-core/pom.xml
index 443d9b3..38cb635 100644
--- a/ml-core/pom.xml
+++ b/ml-core/pom.xml
@@ -2,12 +2,12 @@
org.apache.spark
boostkit-ml
- 2.1.0
+ 2.2.0
4.0.0
boostkit-ml-core_2.12
- 2.1.0
+ 2.2.0
${project.artifactId}
Spark ml core
@@ -15,7 +15,7 @@
org.apache.spark
boostkit-ml-kernel-client-core_2.12
- 2.1.0
+ 2.2.0
${spark.version}
compile
@@ -67,6 +67,25 @@
+
+ org.codehaus.mojo
+ build-helper-maven-plugin
+ 3.0.0
+
+
+ generate-sources
+
+ add-source
+
+
+
+ src/main/java
+ src/main/native
+
+
+
+
+
diff --git a/ml-core/src/main/java/dev/ludovic/netlib/BLAS.java b/ml-core/src/main/java/dev/ludovic/netlib/BLAS.java
new file mode 100644
index 0000000..e6cc5fe
--- /dev/null
+++ b/ml-core/src/main/java/dev/ludovic/netlib/BLAS.java
@@ -0,0 +1,240 @@
+/*
+ * Copyright 2020, 2021, Ludovic Henry
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ *
+ * Please contact git@ludovic.dev or visit ludovic.dev if you need additional
+ * information or have any questions.
+ */
+
+package dev.ludovic.netlib;
+
+public interface BLAS {
+
+ public static BLAS getInstance() {
+ return InstanceBuilder.BLAS.getInstance();
+ }
+
+ public double dasum(int n, double[] x, int incx);
+ public double dasum(int n, double[] x, int offsetx, int incx);
+
+ public float sasum(int n, float[] x, int incx);
+ public float sasum(int n, float[] x, int offsetx, int incx);
+
+ public void daxpy(int n, double alpha, double[] x, int incx, double[] y, int incy);
+ public void daxpy(int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy);
+
+ public void saxpy(int n, float alpha, float[] x, int incx, float[] y, int incy);
+ public void saxpy(int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy);
+
+ public void dcopy(int n, double[] x, int incx, double[] y, int incy);
+ public void dcopy(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy);
+
+ public void scopy(int n, float[] x, int incx, float[] y, int incy);
+ public void scopy(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy);
+
+ public double ddot(int n, double[] x, int incx, double[] y, int incy);
+ public double ddot(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy);
+
+ public float sdot(int n, float[] x, int incx, float[] y, int incy);
+ public float sdot(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy);
+
+ public float sdsdot(int n, float sb, float[] sx, int incx, float[] sy, int incy);
+ public float sdsdot(int n, float sb, float[] sx, int offsetsx, int incx, float[] sy, int offsetsy, int incy);
+
+ public void dgbmv(String trans, int m, int n, int kl, int ku, double alpha, double[] a, int lda, double[] x, int incx, double beta, double[] y, int incy);
+ public void dgbmv(String trans, int m, int n, int kl, int ku, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy);
+
+ public void sgbmv(String trans, int m, int n, int kl, int ku, float alpha, float[] a, int lda, float[] x, int incx, float beta, float[] y, int incy);
+ public void sgbmv(String trans, int m, int n, int kl, int ku, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy);
+
+ public void dgemm(String transa, String transb, int m, int n, int k, double alpha, double[] a, int lda, double[] b, int ldb, double beta, double[] c, int ldc);
+ public void dgemm(String transa, String transb, int m, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc);
+
+ public void sgemm(String transa, String transb, int m, int n, int k, float alpha, float[] a, int lda, float[] b, int ldb, float beta, float[] c, int Ldc);
+ public void sgemm(String transa, String transb, int m, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int Ldc);
+
+ public void dgemv(String trans, int m, int n, double alpha, double[] a, int lda, double[] x, int incx, double beta, double[] y, int incy);
+ public void dgemv(String trans, int m, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy);
+
+ public void sgemv(String trans, int m, int n, float alpha, float[] a, int lda, float[] x, int incx, float beta, float[] y, int incy);
+ public void sgemv(String trans, int m, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy);
+
+ public void dger(int m, int n, double alpha, double[] x, int incx, double[] y, int incy, double[] a, int lda);
+ public void dger(int m, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta, int lda);
+
+ public void sger(int m, int n, float alpha, float[] x, int incx, float[] y, int incy, float[] a, int lda);
+ public void sger(int m, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta, int lda);
+
+ public double dnrm2(int n, double[] x, int incx);
+ public double dnrm2(int n, double[] x, int offsetx, int incx);
+
+ public float snrm2(int n, float[] x, int incx);
+ public float snrm2(int n, float[] x, int offsetx, int incx);
+
+ public void drot(int n, double[] dx, int incx, double[] dy, int incy, double c, double s);
+ public void drot(int n, double[] dx, int offsetdx, int incx, double[] dy, int offsetdy, int incy, double c, double s);
+
+ public void srot(int n, float[] sx, int incx, float[] sy, int incy, float c, float s);
+ public void srot(int n, float[] sx, int offsetsx, int incx, float[] sy, int offsetsy, int incy, float c, float s);
+
+ public void drotg(org.netlib.util.doubleW da, org.netlib.util.doubleW db, org.netlib.util.doubleW c, org.netlib.util.doubleW s);
+
+ public void srotg(org.netlib.util.floatW sa, org.netlib.util.floatW sb, org.netlib.util.floatW c, org.netlib.util.floatW s);
+
+ public void drotm(int n, double[] dx, int incx, double[] dy, int incy, double[] dparam);
+ public void drotm(int n, double[] dx, int offsetdx, int incx, double[] dy, int offsetdy, int incy, double[] dparam, int offsetdparam);
+
+ public void srotm(int n, float[] sx, int incx, float[] sy, int incy, float[] sparam);
+ public void srotm(int n, float[] sx, int offsetsx, int incx, float[] sy, int offsetsy, int incy, float[] sparam, int offsetsparam);
+
+ public void drotmg(org.netlib.util.doubleW dd1, org.netlib.util.doubleW dd2, org.netlib.util.doubleW dx1, double dy1, double[] dparam);
+ public void drotmg(org.netlib.util.doubleW dd1, org.netlib.util.doubleW dd2, org.netlib.util.doubleW dx1, double dy1, double[] dparam, int offsetdparam);
+
+ public void srotmg(org.netlib.util.floatW sd1, org.netlib.util.floatW sd2, org.netlib.util.floatW sx1, float sy1, float[] sparam);
+ public void srotmg(org.netlib.util.floatW sd1, org.netlib.util.floatW sd2, org.netlib.util.floatW sx1, float sy1, float[] sparam, int offsetsparam);
+
+ public void dsbmv(String uplo, int n, int k, double alpha, double[] a, int lda, double[] x, int incx, double beta, double[] y, int incy);
+ public void dsbmv(String uplo, int n, int k, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy);
+
+ public void ssbmv(String uplo, int n, int k, float alpha, float[] a, int lda, float[] x, int incx, float beta, float[] y, int incy);
+ public void ssbmv(String uplo, int n, int k, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy);
+
+ public void dscal(int n, double alpha, double[] x, int incx);
+ public void dscal(int n, double alpha, double[] x, int offsetx, int incx);
+
+ public void sscal(int n, float alpha, float[] x, int incx);
+ public void sscal(int n, float alpha, float[] x, int offsetx, int incx);
+
+ public void dspmv(String uplo, int n, double alpha, double[] a, double[] x, int incx, double beta, double[] y, int incy);
+ public void dspmv(String uplo, int n, double alpha, double[] a, int offseta, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy);
+
+ public void sspmv(String uplo, int n, float alpha, float[] ap, float[] x, int incx, float beta, float[] y, int incy);
+ public void sspmv(String uplo, int n, float alpha, float[] ap, int offsetap, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy);
+
+ public void dspr(String uplo, int n, double alpha, double[] x, int incx, double[] a);
+ public void dspr(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] a, int offseta);
+
+ public void sspr(String uplo, int n, float alpha, float[] x, int incx, float[] ap);
+ public void sspr(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] ap, int offsetap);
+
+ public void dspr2(String uplo, int n, double alpha, double[] x, int incx, double[] y, int incy, double[] ap);
+ public void dspr2(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] ap, int offsetap);
+
+ public void sspr2(String uplo, int n, float alpha, float[] x, int incx, float[] y, int incy, float[] ap);
+ public void sspr2(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] ap, int offsetap);
+
+ public void dswap(int n, double[] x, int incx, double[] y, int incy);
+ public void dswap(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy);
+
+ public void sswap(int n, float[] x, int incx, float[] y, int incy);
+ public void sswap(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy);
+
+ public void dsymm(String side, String uplo, int m, int n, double alpha, double[] a, int lda, double[] b, int ldb, double beta, double[] c, int Ldc);
+ public void dsymm(String side, String uplo, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int Ldc);
+
+ public void ssymm(String side, String uplo, int m, int n, float alpha, float[] a, int lda, float[] b, int ldb, float beta, float[] c, int Ldc);
+ public void ssymm(String side, String uplo, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int Ldc);
+
+ public void dsymv(String uplo, int n, double alpha, double[] a, int lda, double[] x, int incx, double beta, double[] y, int incy);
+ public void dsymv(String uplo, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy);
+
+ public void ssymv(String uplo, int n, float alpha, float[] a, int lda, float[] x, int incx, float beta, float[] y, int incy);
+ public void ssymv(String uplo, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy);
+
+ public void dsyr(String uplo, int n, double alpha, double[] x, int incx, double[] a, int lda);
+ public void dsyr(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] a, int offseta, int lda);
+
+ public void ssyr(String uplo, int n, float alpha, float[] x, int incx, float[] a, int lda);
+ public void ssyr(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] a, int offseta, int lda);
+
+ public void dsyr2(String uplo, int n, double alpha, double[] x, int incx, double[] y, int incy, double[] a, int lda);
+ public void dsyr2(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta, int lda);
+
+ public void ssyr2(String uplo, int n, float alpha, float[] x, int incx, float[] y, int incy, float[] a, int lda);
+ public void ssyr2(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta, int lda);
+
+ public void dsyr2k(String uplo, String trans, int n, int k, double alpha, double[] a, int lda, double[] b, int ldb, double beta, double[] c, int Ldc);
+ public void dsyr2k(String uplo, String trans, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int Ldc);
+
+ public void ssyr2k(String uplo, String trans, int n, int k, float alpha, float[] a, int lda, float[] b, int ldb, float beta, float[] c, int Ldc);
+ public void ssyr2k(String uplo, String trans, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int Ldc);
+
+ public void dsyrk(String uplo, String trans, int n, int k, double alpha, double[] a, int lda, double beta, double[] c, int Ldc);
+ public void dsyrk(String uplo, String trans, int n, int k, double alpha, double[] a, int offseta, int lda, double beta, double[] c, int offsetc, int Ldc);
+
+ public void ssyrk(String uplo, String trans, int n, int k, float alpha, float[] a, int lda, float beta, float[] c, int Ldc);
+ public void ssyrk(String uplo, String trans, int n, int k, float alpha, float[] a, int offseta, int lda, float beta, float[] c, int offsetc, int Ldc);
+
+ public void dtbmv(String uplo, String trans, String diag, int n, int k, double[] a, int lda, double[] x, int incx);
+ public void dtbmv(String uplo, String trans, String diag, int n, int k, double[] a, int offseta, int lda, double[] x, int offsetx, int incx);
+
+ public void stbmv(String uplo, String trans, String diag, int n, int k, float[] a, int lda, float[] x, int incx);
+ public void stbmv(String uplo, String trans, String diag, int n, int k, float[] a, int offseta, int lda, float[] x, int offsetx, int incx);
+
+ public void dtbsv(String uplo, String trans, String diag, int n, int k, double[] a, int lda, double[] x, int incx);
+ public void dtbsv(String uplo, String trans, String diag, int n, int k, double[] a, int offseta, int lda, double[] x, int offsetx, int incx);
+
+ public void stbsv(String uplo, String trans, String diag, int n, int k, float[] a, int lda, float[] x, int incx);
+ public void stbsv(String uplo, String trans, String diag, int n, int k, float[] a, int offseta, int lda, float[] x, int offsetx, int incx);
+
+ public void dtpmv(String uplo, String trans, String diag, int n, double[] ap, double[] x, int incx);
+ public void dtpmv(String uplo, String trans, String diag, int n, double[] ap, int offsetap, double[] x, int offsetx, int incx);
+
+ public void stpmv(String uplo, String trans, String diag, int n, float[] ap, float[] x, int incx);
+ public void stpmv(String uplo, String trans, String diag, int n, float[] ap, int offsetap, float[] x, int offsetx, int incx);
+
+ public void dtpsv(String uplo, String trans, String diag, int n, double[] ap, double[] x, int incx);
+ public void dtpsv(String uplo, String trans, String diag, int n, double[] ap, int offsetap, double[] x, int offsetx, int incx);
+
+ public void stpsv(String uplo, String trans, String diag, int n, float[] ap, float[] x, int incx);
+ public void stpsv(String uplo, String trans, String diag, int n, float[] ap, int offsetap, float[] x, int offsetx, int incx);
+
+ public void dtrmm(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int lda, double[] b, int ldb);
+ public void dtrmm(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb);
+
+ public void strmm(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int lda, float[] b, int ldb);
+ public void strmm(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb);
+
+ public void dtrmv(String uplo, String trans, String diag, int n, double[] a, int lda, double[] x, int incx);
+ public void dtrmv(String uplo, String trans, String diag, int n, double[] a, int offseta, int lda, double[] x, int offsetx, int incx);
+
+ public void strmv(String uplo, String trans, String diag, int n, float[] a, int lda, float[] x, int incx);
+ public void strmv(String uplo, String trans, String diag, int n, float[] a, int offseta, int lda, float[] x, int offsetx, int incx);
+
+ public void dtrsm(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int lda, double[] b, int ldb);
+ public void dtrsm(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb);
+
+ public void strsm(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int lda, float[] b, int ldb);
+ public void strsm(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb);
+
+ public void dtrsv(String uplo, String trans, String diag, int n, double[] a, int lda, double[] x, int incx);
+ public void dtrsv(String uplo, String trans, String diag, int n, double[] a, int offseta, int lda, double[] x, int offsetx, int incx);
+
+ public void strsv(String uplo, String trans, String diag, int n, float[] a, int lda, float[] x, int incx);
+ public void strsv(String uplo, String trans, String diag, int n, float[] a, int offseta, int lda, float[] x, int offsetx, int incx);
+
+ public int idamax(int n, double[] x, int incx);
+ public int idamax(int n, double[] x, int offsetx, int incx);
+
+ public int isamax(int n, float[] sx, int incx);
+ public int isamax(int n, float[] sx, int offsetsx, int incx);
+
+ public boolean lsame(String ca, String cb);
+}
diff --git a/ml-core/src/main/java/dev/ludovic/netlib/InstanceBuilder.java b/ml-core/src/main/java/dev/ludovic/netlib/InstanceBuilder.java
new file mode 100644
index 0000000..0d3eee6
--- /dev/null
+++ b/ml-core/src/main/java/dev/ludovic/netlib/InstanceBuilder.java
@@ -0,0 +1,77 @@
+/*
+ * Copyright 2020, 2021, Ludovic Henry
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ *
+ * Please contact git@ludovic.dev or visit ludovic.dev if you need additional
+ * information or have any questions.
+ */
+
+package dev.ludovic.netlib;
+
+import java.util.logging.Logger;
+
+final class InstanceBuilder {
+
+ public static final class BLAS {
+ private static final dev.ludovic.netlib.BLAS instance = getInstanceImpl();
+
+ public static dev.ludovic.netlib.BLAS getInstance() {
+ return instance;
+ }
+
+ private static dev.ludovic.netlib.BLAS getInstanceImpl() {
+ try {
+ return dev.ludovic.netlib.NativeBLAS.getInstance();
+ } catch (Throwable t) {
+ Logger.getLogger(BLAS.class.getName()).warning("Failed to load implementation from:" + dev.ludovic.netlib.NativeBLAS.class.getName());
+ }
+ return dev.ludovic.netlib.JavaBLAS.getInstance();
+ }
+ }
+
+ public static final class NativeBLAS {
+ private static final dev.ludovic.netlib.NativeBLAS instance = getInstanceImpl();
+
+ public static dev.ludovic.netlib.NativeBLAS getInstance() {
+ return instance;
+ }
+
+ private static dev.ludovic.netlib.NativeBLAS getInstanceImpl() {
+ try {
+ return dev.ludovic.netlib.blas.JNIBLAS.getInstance();
+ } catch (Throwable t) {
+ Logger.getLogger(NativeBLAS.class.getName()).warning("Failed to load implementation from:" + dev.ludovic.netlib.blas.JNIBLAS.class.getName());
+ }
+ throw new RuntimeException("Unable to load native implementation");
+ }
+ }
+
+ public static final class JavaBLAS {
+ private static final dev.ludovic.netlib.JavaBLAS instance = getInstanceImpl();
+
+ public static dev.ludovic.netlib.JavaBLAS getInstance() {
+ return instance;
+ }
+
+ private static dev.ludovic.netlib.JavaBLAS getInstanceImpl() {
+ return dev.ludovic.netlib.blas.Java8BLAS.getInstance();
+ }
+ }
+}
diff --git a/ml-core/src/main/java/dev/ludovic/netlib/JavaBLAS.java b/ml-core/src/main/java/dev/ludovic/netlib/JavaBLAS.java
new file mode 100644
index 0000000..834aa39
--- /dev/null
+++ b/ml-core/src/main/java/dev/ludovic/netlib/JavaBLAS.java
@@ -0,0 +1,33 @@
+/*
+ * Copyright 2020, 2021, Ludovic Henry
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ *
+ * Please contact git@ludovic.dev or visit ludovic.dev if you need additional
+ * information or have any questions.
+ */
+
+package dev.ludovic.netlib;
+
+public interface JavaBLAS extends BLAS {
+
+ public static JavaBLAS getInstance() {
+ return InstanceBuilder.JavaBLAS.getInstance();
+ }
+}
diff --git a/ml-core/src/main/java/dev/ludovic/netlib/NativeBLAS.java b/ml-core/src/main/java/dev/ludovic/netlib/NativeBLAS.java
new file mode 100644
index 0000000..a6fd83a
--- /dev/null
+++ b/ml-core/src/main/java/dev/ludovic/netlib/NativeBLAS.java
@@ -0,0 +1,33 @@
+/*
+ * Copyright 2020, 2021, Ludovic Henry
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ *
+ * Please contact git@ludovic.dev or visit ludovic.dev if you need additional
+ * information or have any questions.
+ */
+
+package dev.ludovic.netlib;
+
+public interface NativeBLAS extends BLAS {
+
+ public static NativeBLAS getInstance() {
+ return InstanceBuilder.NativeBLAS.getInstance();
+ }
+}
diff --git a/ml-core/src/main/java/dev/ludovic/netlib/blas/AbstractBLAS.java b/ml-core/src/main/java/dev/ludovic/netlib/blas/AbstractBLAS.java
new file mode 100644
index 0000000..00f3c64
--- /dev/null
+++ b/ml-core/src/main/java/dev/ludovic/netlib/blas/AbstractBLAS.java
@@ -0,0 +1,1689 @@
+/*
+ * Copyright 2020, 2021, Ludovic Henry
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ *
+ * Please contact git@ludovic.dev or visit ludovic.dev if you need additional
+ * information or have any questions.
+ */
+
+package dev.ludovic.netlib.blas;
+
+import java.util.Objects;
+
+import dev.ludovic.netlib.BLAS;
+
+abstract class AbstractBLAS implements BLAS {
+
+ private final static boolean debug = System.getProperty("dev.ludovic.netlib.blas.debug", "false").equals("true");
+
+ protected int loopAlign(int index, int max, int size) {
+ return Math.min(loopBound(index + size - 1, size), max);
+ }
+
+ protected int loopBound(int index, int size) {
+ return index - (index % size);
+ }
+
+ private void checkArgument(String method, int arg, boolean check) {
+ if (!check) {
+ throw new IllegalArgumentException(String.format("** On entry to '%s' parameter number %d had an illegal value", method, arg));
+ }
+ }
+
+ private void checkIndex(int index, int length) {
+ //FIXME: switch to Objects.checkIndex when the minimum version becomes JDK 11
+ if (index < 0 || index >= length) {
+ throw new IndexOutOfBoundsException(String.format("Index %s out of bounds for length %s", index, length));
+ }
+ }
+
+ private void requireNonNull(T obj) {
+ Objects.requireNonNull(obj);
+ }
+
+ public double dasum(int n, double[] x, int incx) {
+ if (debug) System.err.println("dasum");
+ return dasum(n, x, 0, incx);
+ }
+
+ public double dasum(int n, double[] x, int offsetx, int incx) {
+ if (debug) System.err.println("dasum");
+ if (n <= 0) {
+ return 0.0;
+ }
+ requireNonNull(x);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ return dasumK(n, x, offsetx, incx);
+ }
+
+ protected abstract double dasumK(int n, double[] x, int offsetx, int incx);
+
+ public float sasum(int n, float[] x, int incx) {
+ if (debug) System.err.println("sasum");
+ return sasum(n, x, 0, incx);
+ }
+
+ public float sasum(int n, float[] x, int offsetx, int incx) {
+ if (debug) System.err.println("sasum");
+ if (n <= 0) {
+ return 0.0f;
+ }
+ requireNonNull(x);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ return sasumK(n, x, offsetx, incx);
+ }
+
+ protected abstract float sasumK(int n, float[] x, int offsetx, int incx);
+
+ public void daxpy(int n, double alpha, double[] x, int incx, double[] y, int incy) {
+ if (debug) System.err.println("daxpy");
+ daxpy(n, alpha, x, 0, incx, y, 0, incy);
+ }
+
+ // y += alpha * x
+ public void daxpy(int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy) {
+ if (debug) System.err.println("daxpy");
+ if (n <= 0) {
+ return;
+ }
+ if (alpha == 0.0) {
+ return;
+ }
+ requireNonNull(x);
+ requireNonNull(y);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
+ daxpyK(n, alpha, x, offsetx, incx, y, offsety, incy);
+ }
+
+ protected abstract void daxpyK(int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy);
+
+ public void saxpy(int n, float alpha, float[] x, int incx, float[] y, int incy) {
+ if (debug) System.err.println("saxpy");
+ saxpy(n, alpha, x, 0, incx, y, 0, incy);
+ }
+
+ // y += alpha * x
+ public void saxpy(int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) {
+ if (debug) System.err.println("saxpy");
+ if (n <= 0) {
+ return;
+ }
+ if (alpha == 0.0f) {
+ return;
+ }
+ requireNonNull(x);
+ requireNonNull(y);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
+ saxpyK(n, alpha, x, offsetx, incx, y, offsety, incy);
+ }
+
+ protected abstract void saxpyK(int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy);
+
+ public void dcopy(int n, double[] x, int incx, double[] y, int incy) {
+ if (debug) System.err.println("dcopy");
+ dcopy(n, x, 0, incx, y, 0, incy);
+ }
+
+ public void dcopy(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy) {
+ if (debug) System.err.println("dcopy");
+ if (n <= 0) {
+ return;
+ }
+ requireNonNull(x);
+ requireNonNull(y);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
+ dcopyK(n, x, offsetx, incx, y, offsety, incy);
+ }
+
+ protected abstract void dcopyK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy);
+
+ public void scopy(int n, float[] x, int incx, float[] y, int incy) {
+ if (debug) System.err.println("scopy");
+ scopy(n, x, 0, incx, y, 0, incy);
+ }
+
+ public void scopy(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) {
+ if (debug) System.err.println("scopy");
+ if (n <= 0) {
+ return;
+ }
+ requireNonNull(x);
+ requireNonNull(y);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
+ scopyK(n, x, offsetx, incx, y, offsety, incy);
+ }
+
+ protected abstract void scopyK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy);
+
+ public double ddot(int n, double[] x, int incx, double[] y, int incy) {
+ if (debug) System.err.println("ddot");
+ return ddot(n, x, 0, incx, y, 0, incy);
+ }
+
+ // sum(x * y)
+ public double ddot(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy) {
+ if (debug) System.err.println("ddot");
+ if (n <= 0) {
+ return 0.0;
+ }
+ requireNonNull(x);
+ requireNonNull(y);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
+ return ddotK(n, x, offsetx, incx, y, offsety, incy);
+ }
+
+ protected abstract double ddotK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy);
+
+ public float sdot(int n, float[] x, int incx, float[] y, int incy) {
+ if (debug) System.err.println("sdot");
+ return sdot(n, x, 0, incx, y, 0, incy);
+ }
+
+ // sum(x * y)
+ public float sdot(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) {
+ if (debug) System.err.println("sdot");
+ if (n <= 0) {
+ return 0.0f;
+ }
+ requireNonNull(x);
+ requireNonNull(y);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
+ return sdotK(n, x, offsetx, incx, y, offsety, incy);
+ }
+
+ protected abstract float sdotK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy);
+
+ public float sdsdot(int n, float sb, float[] x, int incx, float[] y, int incy) {
+ if (debug) System.err.println("sdsdot");
+ return sdsdot(n, sb, x, 0, incx, y, 0, incy);
+ }
+
+ public float sdsdot(int n, float sb, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) {
+ if (debug) System.err.println("sdsdot");
+ if (n <= 0) {
+ return 0.0f;
+ }
+ requireNonNull(x);
+ requireNonNull(y);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
+ return sdsdotK(n, sb, x, offsetx, incx, y, offsety, incy);
+ }
+
+ protected abstract float sdsdotK(int n, float sb, float[] x, int offsetx, int incx, float[] y, int offsety, int incy);
+
+ public void dgbmv(String trans, int m, int n, int kl, int ku, double alpha, double[] a, int lda, double[] x, int incx, double beta, double[] y, int incy) {
+ if (debug) System.err.println("dgbmv");
+ dgbmv(trans, m, n, kl, ku, alpha, a, 0, lda, x, 0, incx, beta, y, 0, incy);
+ }
+
+ public void dgbmv(String trans, int m, int n, int kl, int ku, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
+ if (debug) System.err.println("dgbmv");
+ requireNonNull(a);
+ requireNonNull(x);
+ requireNonNull(y);
+ checkIndex(offseta + n * lda - 1, a.length);
+ checkIndex(offsetx + ((lsame("N", trans) ? n : m) - 1) * Math.abs(incx), x.length);
+ checkIndex(offsety + ((lsame("N", trans) ? m : n) - 1) * Math.abs(incy), y.length);
+ dgbmvK(trans, m, n, kl, ku, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
+ }
+
+ protected abstract void dgbmvK(String trans, int m, int n, int kl, int ku, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy);
+
+ public void sgbmv(String trans, int m, int n, int kl, int ku, float alpha, float[] a, int lda, float[] x, int incx, float beta, float[] y, int incy) {
+ if (debug) System.err.println("sgbmv");
+ sgbmv(trans, m, n, kl, ku, alpha, a, 0, lda, x, 0, incx, beta, y, 0, incy);
+ }
+
+ public void sgbmv(String trans, int m, int n, int kl, int ku, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
+ if (debug) System.err.println("sgbmv");
+ requireNonNull(a);
+ requireNonNull(x);
+ requireNonNull(y);
+ checkIndex(offseta + n * lda - 1, a.length);
+ checkIndex(offsetx + ((lsame("N", trans) ? n : m) - 1) * Math.abs(incx), x.length);
+ checkIndex(offsety + ((lsame("N", trans) ? m : n) - 1) * Math.abs(incy), y.length);
+ sgbmvK(trans, m, n, kl, ku, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
+ }
+
+ protected abstract void sgbmvK(String trans, int m, int n, int kl, int ku, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy);
+
+ public void dgemm(String transa, String transb, int m, int n, int k, double alpha, double[] a, int lda, double[] b, int ldb, double beta, double[] c, int ldc) {
+ if (debug) System.err.println("dgemm");
+ dgemm(transa, transb, m, n, k, alpha, a, 0, lda, b, 0, ldb, beta, c, 0, ldc);
+ }
+
+ // c = alpha * a * b + beta * c
+ public void dgemm(String transa, String transb, int m, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
+ if (debug) System.err.println("dgemm");
+ checkArgument("DGEMM", 1, lsame("T", transa) || lsame("N", transa) || lsame("C", transa));
+ checkArgument("DGEMM", 2, lsame("T", transb) || lsame("N", transb) || lsame("C", transb));
+ checkArgument("DGEMM", 3, m >= 0);
+ checkArgument("DGEMM", 4, n >= 0);
+ checkArgument("DGEMM", 5, k >= 0);
+ checkArgument("DGEMM", 8, lda >= Math.max(1, lsame("N", transa) ? m : k));
+ checkArgument("DGEMM", 10, ldb >= Math.max(1, lsame("N", transb) ? k : n));
+ checkArgument("DGEMM", 13, ldc >= Math.max(1, m));
+ if (m == 0 || n == 0 || ((alpha == 0.0 || k == 0) && beta == 1.0)) {
+ return;
+ }
+ requireNonNull(a);
+ requireNonNull(b);
+ requireNonNull(c);
+ checkIndex(offseta + (lsame("N", transa) ? k : m) * lda - 1, a.length);
+ checkIndex(offsetb + (lsame("N", transb) ? n : k) * ldb - 1, b.length);
+ checkIndex(offsetc + m * n - 1, c.length);
+ dgemmK(transa, transb, m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
+ }
+
+ protected abstract void dgemmK(String transa, String transb, int m, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc);
+
+ public void sgemm(String transa, String transb, int m, int n, int k, float alpha, float[] a, int lda, float[] b, int ldb, float beta, float[] c, int ldc) {
+ if (debug) System.err.println("sgemm");
+ sgemm(transa, transb, m, n, k, alpha, a, 0, lda, b, 0, ldb, beta, c, 0, ldc);
+ }
+
+ // c = alpha * a * b + beta * c
+ public void sgemm(String transa, String transb, int m, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
+ if (debug) System.err.println("sgemm");
+ checkArgument("SGEMM", 1, lsame("T", transa) || lsame("N", transa) || lsame("C", transa));
+ checkArgument("SGEMM", 2, lsame("T", transb) || lsame("N", transb) || lsame("C", transb));
+ checkArgument("SGEMM", 3, m >= 0);
+ checkArgument("SGEMM", 4, n >= 0);
+ checkArgument("SGEMM", 5, k >= 0);
+ checkArgument("SGEMM", 8, lda >= Math.max(1, lsame("N", transa) ? m : k));
+ checkArgument("SGEMM", 10, ldb >= Math.max(1, lsame("N", transb) ? k : n));
+ checkArgument("SGEMM", 13, ldc >= Math.max(1, m));
+ if (m == 0 || n == 0 || ((alpha == 0.0f || k == 0) && beta == 1.0f)) {
+ return;
+ }
+ requireNonNull(a);
+ requireNonNull(b);
+ requireNonNull(c);
+ checkIndex(offseta + (lsame("N", transa) ? k : m) * lda - 1, a.length);
+ checkIndex(offsetb + (lsame("N", transb) ? n : k) * ldb - 1, b.length);
+ checkIndex(offsetc + m * n - 1, c.length);
+ sgemmK(transa, transb, m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
+ }
+
+ protected abstract void sgemmK(String transa, String transb, int m, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc);
+
+ public void dgemv(String trans, int m, int n, double alpha, double[] a, int lda, double[] x, int incx, double beta, double[] y, int incy) {
+ if (debug) System.err.println("dgemv");
+ dgemv(trans, m, n, alpha, a, 0, lda, x, 0, incx, beta, y, 0, incy);
+ }
+
+ // y = alpha * A * x + beta * y
+ public void dgemv(String trans, int m, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
+ if (debug) System.err.println("dgemv");
+ checkArgument("DGEMV", 1, lsame("N", trans) || lsame("T", trans) || lsame("C", trans));
+ checkArgument("DGEMV", 2, m >= 0);
+ checkArgument("DGEMV", 3, n >= 0);
+ checkArgument("DGEMV", 6, lda >= Math.max(1, m));
+ checkArgument("DGEMV", 8, incx != 0);
+ checkArgument("DGEMV", 11, incy != 0);
+ if (m == 0 || n == 0) {
+ return;
+ }
+ requireNonNull(a);
+ requireNonNull(x);
+ requireNonNull(y);
+ checkIndex(offseta + n * lda - 1, a.length);
+ checkIndex(offsetx + ((lsame("N", trans) ? n : m) - 1) * Math.abs(incx), x.length);
+ checkIndex(offsety + ((lsame("N", trans) ? m : n) - 1) * Math.abs(incy), y.length);
+ dgemvK(trans, m, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
+ }
+
+ protected abstract void dgemvK(String trans, int m, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy);
+
+ public void sgemv(String trans, int m, int n, float alpha, float[] a, int lda, float[] x, int incx, float beta, float[] y, int incy) {
+ if (debug) System.err.println("sgemv");
+ sgemv(trans, m, n, alpha, a, 0, lda, x, 0, incx, beta, y, 0, incy);
+ }
+
+ // y = alpha * A * x + beta * y
+ public void sgemv(String trans, int m, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
+ if (debug) System.err.println("sgemv");
+ checkArgument("SGEMV", 1, lsame("N", trans) || lsame("T", trans) || lsame("C", trans));
+ checkArgument("SGEMV", 2, m >= 0);
+ checkArgument("SGEMV", 3, n >= 0);
+ checkArgument("SGEMV", 6, lda >= Math.max(1, m));
+ checkArgument("SGEMV", 8, incx != 0);
+ checkArgument("SGEMV", 11, incy != 0);
+ if (m == 0 || n == 0) {
+ return;
+ }
+ requireNonNull(a);
+ requireNonNull(x);
+ requireNonNull(y);
+ checkIndex(offseta + n * lda - 1, a.length);
+ checkIndex(offsetx + ((lsame("N", trans) ? n : m) - 1) * Math.abs(incx), x.length);
+ checkIndex(offsety + ((lsame("N", trans) ? m : n) - 1) * Math.abs(incy), y.length);
+ sgemvK(trans, m, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
+ }
+
+ protected abstract void sgemvK(String trans, int m, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy);
+
+ // A += alpha * x * y.t
+ public void dger(int m, int n, double alpha, double[] x, int incx, double[] y, int incy, double[] a, int lda) {
+ if (debug) System.err.println("dger");
+ dger(m, n, alpha, x, 0, incx, y, 0, incy, a, 0, lda);
+ }
+
+ public void dger(int m, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta, int lda) {
+ if (debug) System.err.println("dger");
+ checkArgument("DGER", 1, m >= 0);
+ checkArgument("DGER", 2, n >= 0);
+ checkArgument("DGER", 5, incx != 0);
+ checkArgument("DGER", 7, incy != 0);
+ checkArgument("DGER", 9, lda >= Math.max(1, m));
+ if (m == 0 || n == 0) {
+ return;
+ }
+ requireNonNull(x);
+ requireNonNull(y);
+ requireNonNull(a);
+ checkIndex(offsetx + (m - 1) * Math.abs(incx), x.length);
+ checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
+ checkIndex(offseta + n * lda - 1, a.length);
+ if (alpha != 0.0) {
+ dgerK(m, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta, lda);
+ }
+ }
+
+ protected abstract void dgerK(int m, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta, int lda);
+
+ public void sger(int m, int n, float alpha, float[] x, int incx, float[] y, int incy, float[] a, int lda) {
+ if (debug) System.err.println("sger");
+ sger(m, n, alpha, x, 0, incx, y, 0, incy, a, 0, lda);
+ }
+
+ public void sger(int m, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta, int lda) {
+ if (debug) System.err.println("sger");
+ checkArgument("SGER", 1, m >= 0);
+ checkArgument("SGER", 2, n >= 0);
+ checkArgument("SGER", 5, incx != 0);
+ checkArgument("SGER", 7, incy != 0);
+ checkArgument("SGER", 9, lda >= Math.max(1, m));
+ if (m == 0 || n == 0) {
+ return;
+ }
+ requireNonNull(x);
+ requireNonNull(y);
+ requireNonNull(a);
+ checkIndex(offsetx + (m - 1) * Math.abs(incx), x.length);
+ checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
+ checkIndex(offseta + n * lda - 1, a.length);
+ if (alpha != 0.0f) {
+ sgerK(m, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta, lda);
+ }
+ }
+
+ protected abstract void sgerK(int m, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta, int lda);
+
+ public double dnrm2(int n, double[] x, int incx) {
+ if (debug) System.err.println("dnrm2");
+ return dnrm2(n, x, 0, incx);
+ }
+
+ public double dnrm2(int n, double[] x, int offsetx, int incx) {
+ if (debug) System.err.println("dnrm2");
+ if (n <= 0) {
+ return 0.0;
+ }
+ if (incx <= 0) {
+ return 0.0;
+ }
+ if (n == 1) {
+ return Math.abs(x[offsetx + 0]);
+ }
+ requireNonNull(x);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ return dnrm2K(n, x, offsetx, incx);
+ }
+
+ protected abstract double dnrm2K(int n, double[] x, int offsetx, int incx);
+
+ public float snrm2(int n, float[] x, int incx) {
+ if (debug) System.err.println("snrm2");
+ return snrm2(n, x, 0, incx);
+ }
+
+ public float snrm2(int n, float[] x, int offsetx, int incx) {
+ if (debug) System.err.println("snrm2");
+ if (n <= 0) {
+ return 0.0f;
+ }
+ if (incx <= 0) {
+ return 0.0f;
+ }
+ if (n == 1) {
+ return Math.abs(x[offsetx + 0]);
+ }
+ requireNonNull(x);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ return snrm2K(n, x, offsetx, incx);
+ }
+
+ protected abstract float snrm2K(int n, float[] x, int offsetx, int incx);
+
+ public void drot(int n, double[] x, int incx, double[] y, int incy, double c, double s) {
+ if (debug) System.err.println("drot");
+ drot(n, x, 0, incx, y, 0, incy, c, s);
+ }
+
+ public void drot(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double c, double s) {
+ if (debug) System.err.println("drot");
+ if (n <= 0) {
+ return;
+ }
+ requireNonNull(x);
+ requireNonNull(y);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
+ drotK(n, x, offsetx, incx, y, offsety, incy, c, s);
+ }
+
+ protected abstract void drotK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double c, double s);
+
+ public void srot(int n, float[] x, int incx, float[] y, int incy, float c, float s) {
+ if (debug) System.err.println("srot");
+ srot(n, x, 0, incx, y, 0, incy, c, s);
+ }
+
+ public void srot(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float c, float s) {
+ if (debug) System.err.println("srot");
+ if (n <= 0) {
+ return;
+ }
+ requireNonNull(x);
+ requireNonNull(y);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
+ srotK(n, x, offsetx, incx, y, offsety, incy, c, s);
+ }
+
+ protected abstract void srotK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float c, float s);
+
+ public void drotg(org.netlib.util.doubleW da, org.netlib.util.doubleW db, org.netlib.util.doubleW c, org.netlib.util.doubleW s) {
+ if (debug) System.err.println("drotg");
+ double scale = Math.abs(da.val) + Math.abs(db.val);
+ if (scale == 0.0) {
+ c.val = 1.0;
+ s.val = 0.0;
+ da.val = 0.0;
+ db.val = 0.0;
+ } else {
+ double r = scale * Math.sqrt(Math.pow(da.val / scale, 2) + Math.pow(db.val / scale, 2))
+ * ((Math.abs(da.val) > Math.abs(db.val) ? da.val : db.val) >= 0.0 ? 1.0 : -1.0);
+ c.val = da.val / r;
+ s.val = db.val / r;
+ double z = 1.0;
+ if (Math.abs(da.val) > Math.abs(db.val)) {
+ z = s.val;
+ } else if (c.val != 0.0) {
+ z = 1.0 / c.val;
+ }
+ da.val = r;
+ db.val = z;
+ }
+ }
+
+ public void srotg(org.netlib.util.floatW sa, org.netlib.util.floatW sb, org.netlib.util.floatW c, org.netlib.util.floatW s) {
+ if (debug) System.err.println("srotg");
+ float scale = Math.abs(sa.val) + Math.abs(sb.val);
+ if (scale == 0.0f) {
+ c.val = 1.0f;
+ s.val = 0.0f;
+ sa.val = 0.0f;
+ sb.val = 0.0f;
+ } else {
+ float r = (float)(scale * Math.sqrt(Math.pow(sa.val / scale, 2) + Math.pow(sb.val / scale, 2))
+ * ((Math.abs(sa.val) > Math.abs(sb.val) ? sa.val : sb.val) >= 0.0f ? 1.0 : -1.0));
+ c.val = sa.val / r;
+ s.val = sb.val / r;
+ float z = 1.0f;
+ if (Math.abs(sa.val) > Math.abs(sb.val)) {
+ z = s.val;
+ } else if (c.val != 0.0f) {
+ z = 1.0f / c.val;
+ }
+ sa.val = r;
+ sb.val = z;
+ }
+ }
+
+ public void drotm(int n, double[] x, int incx, double[] y, int incy, double[] param) {
+ if (debug) System.err.println("drotm");
+ drotm(n, x, 0, incx, y, 0, incy, param, 0);
+ }
+
+ public void drotm(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] param, int offsetparam) {
+ if (debug) System.err.println("drotm");
+ requireNonNull(x);
+ requireNonNull(y);
+ requireNonNull(param);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
+ checkIndex(offsetparam + 4, param.length); /* param.length == 5 */
+ drotmK(n, x, offsetx, incx, y, offsety, incy, param, offsetparam);
+ }
+
+ protected abstract void drotmK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] param, int offsetparam);
+
+ public void srotm(int n, float[] x, int incx, float[] y, int incy, float[] param) {
+ if (debug) System.err.println("srotm");
+ srotm(n, x, 0, incx, y, 0, incy, param, 0);
+ }
+
+ public void srotm(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] param, int offsetparam) {
+ if (debug) System.err.println("srotm");
+ requireNonNull(x);
+ requireNonNull(y);
+ requireNonNull(param);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
+ checkIndex(offsetparam + 4, param.length); /* param.length == 5 */
+ srotmK(n, x, offsetx, incx, y, offsety, incy, param, offsetparam);
+ }
+
+ protected abstract void srotmK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] param, int offsetparam);
+
+ public void drotmg(org.netlib.util.doubleW dd1, org.netlib.util.doubleW dd2, org.netlib.util.doubleW dx1, double dy1, double[] param) {
+ if (debug) System.err.println("drotmg");
+ drotmg(dd1, dd2, dx1, dy1, param, 0);
+ }
+
+ public void drotmg(org.netlib.util.doubleW dd1, org.netlib.util.doubleW dd2, org.netlib.util.doubleW dx1, double dy1, double[] param, int offsetparam) {
+ if (debug) System.err.println("drotmg");
+ requireNonNull(dd1);
+ requireNonNull(dd2);
+ requireNonNull(dx1);
+ requireNonNull(param);
+ checkIndex(offsetparam + 4, param.length);
+ drotmgK(dd1, dd2, dx1, dy1, param, offsetparam);
+ }
+
+ protected abstract void drotmgK(org.netlib.util.doubleW dd1, org.netlib.util.doubleW dd2, org.netlib.util.doubleW dx1, double dy1, double[] param, int offsetparam);
+
+ public void srotmg(org.netlib.util.floatW sd1, org.netlib.util.floatW sd2, org.netlib.util.floatW sx1, float sy1, float[] param) {
+ if (debug) System.err.println("srotmg");
+ srotmg(sd1, sd2, sx1, sy1, param, 0);
+ }
+
+ public void srotmg(org.netlib.util.floatW sd1, org.netlib.util.floatW sd2, org.netlib.util.floatW sx1, float sy1, float[] param, int offsetparam) {
+ if (debug) System.err.println("srotmg");
+ requireNonNull(sd1);
+ requireNonNull(sd2);
+ requireNonNull(sx1);
+ requireNonNull(param);
+ checkIndex(offsetparam + 4, param.length);
+ srotmgK(sd1, sd2, sx1, sy1, param, offsetparam);
+ }
+
+ protected abstract void srotmgK(org.netlib.util.floatW sd1, org.netlib.util.floatW sd2, org.netlib.util.floatW sx1, float sy1, float[] param, int offsetparam);
+
+ public void dsbmv(String uplo, int n, int k, double alpha, double[] a, int lda, double[] x, int incx, double beta, double[] y, int incy) {
+ if (debug) System.err.println("dsbmv");
+ dsbmv(uplo, n, k, alpha, a, 0, lda, x, 0, incx, beta, y, 0, incy);
+ }
+
+ public void dsbmv(String uplo, int n, int k, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
+ if (debug) System.err.println("dsbmv");
+ requireNonNull(a);
+ requireNonNull(x);
+ requireNonNull(y);
+ checkIndex(offseta + n * lda - 1, a.length);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
+ dsbmvK(uplo, n, k, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
+ }
+
+ protected abstract void dsbmvK(String uplo, int n, int k, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy);
+
+ public void ssbmv(String uplo, int n, int k, float alpha, float[] a, int lda, float[] x, int incx, float beta, float[] y, int incy) {
+ if (debug) System.err.println("ssbmv");
+ ssbmv(uplo, n, k, alpha, a, 0, lda, x, 0, incx, beta, y, 0, incy);
+ }
+
+ public void ssbmv(String uplo, int n, int k, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
+ if (debug) System.err.println("ssbmv");
+ requireNonNull(a);
+ requireNonNull(x);
+ requireNonNull(y);
+ checkIndex(offseta + n * lda - 1, a.length);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
+ ssbmvK(uplo, n, k, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
+ }
+
+ protected abstract void ssbmvK(String uplo, int n, int k, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy);
+
+ public void dscal(int n, double alpha, double[] x, int incx) {
+ if (debug) System.err.println("dscal");
+ dscal(n, alpha, x, 0, incx);
+ }
+
+ // x = alpha * x
+ public void dscal(int n, double alpha, double[] x, int offsetx, int incx) {
+ if (debug) System.err.println("dscal");
+ if (n <= 0) {
+ return;
+ }
+ if (incx <= 0) {
+ return;
+ }
+ if (alpha == 1.0) {
+ return;
+ }
+ requireNonNull(x);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ dscalK(n, alpha, x, offsetx, incx);
+ }
+
+ protected abstract void dscalK(int n, double alpha, double[] x, int offsetx, int incx);
+
+ public void sscal(int n, float alpha, float[] x, int incx) {
+ if (debug) System.err.println("sscal");
+ sscal(n, alpha, x, 0, incx);
+ }
+
+ // x = alpha * x
+ public void sscal(int n, float alpha, float[] x, int offsetx, int incx) {
+ if (debug) System.err.println("sscal");
+ if (n <= 0) {
+ return;
+ }
+ if (incx <= 0) {
+ return;
+ }
+ if (alpha == 1.0f) {
+ return;
+ }
+ requireNonNull(x);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ sscalK(n, alpha, x, offsetx, incx);
+ }
+
+ protected abstract void sscalK(int n, float alpha, float[] x, int offsetx, int incx);
+
+ public void dspmv(String uplo, int n, double alpha, double[] a, double[] x, int incx, double beta, double[] y, int incy) {
+ if (debug) System.err.println("dspmv");
+ dspmv(uplo, n, alpha, a, 0, x, 0, incx, beta, y, 0, incy);
+ }
+
+ // y = alpha * a * x + beta * y
+ public void dspmv(String uplo, int n, double alpha, double[] a, int offseta, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
+ if (debug) System.err.println("dspmv");
+ checkArgument("DSPMV", 1, lsame("U", uplo) || lsame("L", uplo));
+ checkArgument("DSPMV", 2, n >= 0);
+ checkArgument("DSPMV", 6, incx != 0);
+ checkArgument("DSPMV", 9, incy != 0);
+ if (n == 0) {
+ return;
+ }
+ requireNonNull(a);
+ requireNonNull(x);
+ requireNonNull(y);
+ checkIndex(offseta + (n * (n + 1) / 2) - 1, a.length);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
+ dspmvK(uplo, n, alpha, a, offseta, x, offsetx, incx, beta, y, offsety, incy);
+ }
+
+ protected abstract void dspmvK(String uplo, int n, double alpha, double[] a, int offseta, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy);
+
+ public void sspmv(String uplo, int n, float alpha, float[] a, float[] x, int incx, float beta, float[] y, int incy) {
+ if (debug) System.err.println("sspmv");
+ sspmv(uplo, n, alpha, a, 0, x, 0, incx, beta, y, 0, incy);
+ }
+
+ public void sspmv(String uplo, int n, float alpha, float[] a, int offseta, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
+ if (debug) System.err.println("sspmv");
+ checkArgument("SSPMV", 1, lsame("U", uplo) || lsame("L", uplo));
+ checkArgument("SSPMV", 2, n >= 0);
+ checkArgument("SSPMV", 6, incx != 0);
+ checkArgument("SSPMV", 9, incy != 0);
+ if (n == 0) {
+ return;
+ }
+ requireNonNull(a);
+ requireNonNull(x);
+ requireNonNull(y);
+ checkIndex(offseta + (n * (n + 1) / 2) - 1, a.length);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
+ sspmvK(uplo, n, alpha, a, offseta, x, offsetx, incx, beta, y, offsety, incy);
+ }
+
+ protected abstract void sspmvK(String uplo, int n, float alpha, float[] a, int offseta, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy);
+
+ public void dspr(String uplo, int n, double alpha, double[] x, int incx, double[] a) {
+ if (debug) System.err.println("dspr");
+ dspr(uplo, n, alpha, x, 0, incx, a, 0);
+ }
+
+ // a += alpha * x * x.t
+ public void dspr(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] a, int offseta) {
+ if (debug) System.err.println("dspr");
+ checkArgument("DSPR", 1, lsame("U", uplo) || lsame("L", uplo));
+ checkArgument("DSPR", 2, n >= 0);
+ checkArgument("DSPR", 5, incx != 0);
+ if (n == 0) {
+ return;
+ }
+ requireNonNull(x);
+ requireNonNull(a);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ checkIndex(offseta + (n * (n + 1) / 2) - 1, a.length);
+ dsprK(uplo, n, alpha, x, offsetx, incx, a, offseta);
+ }
+
+ protected abstract void dsprK(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] a, int offseta);
+
+ public void sspr(String uplo, int n, float alpha, float[] x, int incx, float[] a) {
+ if (debug) System.err.println("sspr");
+ sspr(uplo, n, alpha, x, 0, incx, a, 0);
+ }
+
+ // a += alpha * x * x.t
+ public void sspr(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] a, int offseta) {
+ if (debug) System.err.println("sspr");
+ checkArgument("SSPR", 1, lsame("U", uplo) || lsame("L", uplo));
+ checkArgument("SSPR", 2, n >= 0);
+ checkArgument("SSPR", 5, incx != 0);
+ if (n == 0) {
+ return;
+ }
+ requireNonNull(x);
+ requireNonNull(a);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ checkIndex(offseta + (n * (n + 1) / 2) - 1, a.length);
+ ssprK(uplo, n, alpha, x, offsetx, incx, a, offseta);
+ }
+
+ protected abstract void ssprK(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] a, int offseta);
+
+ public void dspr2(String uplo, int n, double alpha, double[] x, int incx, double[] y, int incy, double[] a) {
+ if (debug) System.err.println("dspr2");
+ dspr2(uplo, n, alpha, x, 0, incx, y, 0, incy, a, 0);
+ }
+
+ // a += alpha * x * y.t + alpha * y * x.t
+ public void dspr2(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta) {
+ if (debug) System.err.println("dspr2");
+ checkArgument("DSPR2", 1, lsame("U", uplo) || lsame("L", uplo));
+ checkArgument("DSPR2", 2, n >= 0);
+ checkArgument("DSPR2", 5, incx != 0);
+ checkArgument("DSPR2", 7, incy != 0);
+ if (n == 0) {
+ return;
+ }
+ requireNonNull(x);
+ requireNonNull(y);
+ requireNonNull(a);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
+ checkIndex(offseta + (n * (n + 1) / 2) - 1, a.length);
+ dspr2K(uplo, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta);
+ }
+
+ protected abstract void dspr2K(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta);
+
+ public void sspr2(String uplo, int n, float alpha, float[] x, int incx, float[] y, int incy, float[] a) {
+ if (debug) System.err.println("sspr2");
+ sspr2(uplo, n, alpha, x, 0, incx, y, 0, incy, a, 0);
+ }
+
+ // a += alpha * x * y.t + alpha * y * x.t
+ public void sspr2(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta) {
+ if (debug) System.err.println("sspr2");
+ checkArgument("SSPR2", 1, lsame("U", uplo) || lsame("L", uplo));
+ checkArgument("SSPR2", 2, n >= 0);
+ checkArgument("SSPR2", 5, incx != 0);
+ checkArgument("SSPR2", 7, incy != 0);
+ if (n == 0) {
+ return;
+ }
+ requireNonNull(x);
+ requireNonNull(y);
+ requireNonNull(a);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
+ checkIndex(offseta + (n * (n + 1) / 2) - 1, a.length);
+ sspr2K(uplo, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta);
+ }
+
+ protected abstract void sspr2K(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta);
+
+ public void dswap(int n, double[] x, int incx, double[] y, int incy) {
+ if (debug) System.err.println("dswap");
+ dswap(n, x, 0, incx, y, 0, incy);
+ }
+
+ public void dswap(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy) {
+ if (debug) System.err.println("dswap");
+ if (n <= 0) {
+ return;
+ }
+ requireNonNull(x);
+ requireNonNull(y);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
+ dswapK(n, x, offsetx, incx, y, offsety, incy);
+ }
+
+ protected abstract void dswapK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy);
+
+ public void sswap(int n, float[] x, int incx, float[] y, int incy) {
+ if (debug) System.err.println("sswap");
+ sswap(n, x, 0, incx, y, 0, incy);
+ }
+
+ public void sswap(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) {
+ if (debug) System.err.println("sswap");
+ if (n <= 0) {
+ return;
+ }
+ requireNonNull(x);
+ requireNonNull(y);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
+ sswapK(n, x, offsetx, incx, y, offsety, incy);
+ }
+
+ protected abstract void sswapK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy);
+
+ public void dsymm(String side, String uplo, int m, int n, double alpha, double[] a, int lda, double[] b, int ldb, double beta, double[] c, int ldc) {
+ if (debug) System.err.println("dsymm");
+ dsymm(side, uplo, m, n, alpha, a, 0, lda, b, 0, ldb, beta, c, 0, ldc);
+ }
+
+ public void dsymm(String side, String uplo, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
+ if (debug) System.err.println("dsymm");
+ checkArgument("DSYMM", 1, lsame("L", side) || lsame("R", side));
+ checkArgument("DSYMM", 2, lsame("U", uplo) || lsame("L", uplo));
+ checkArgument("DSYMM", 3, m >= 0);
+ checkArgument("DSYMM", 4, n >= 0);
+ checkArgument("DSYMM", 7, lda >= Math.max(1, lsame("L", side) ? m : n));
+ checkArgument("DSYMM", 9, ldb >= Math.max(1, m));
+ checkArgument("DSYMM", 12, ldc >= Math.max(1, m));
+ if (m == 0 || n == 0 || (alpha == 0.0 && beta == 1.0)) {
+ return;
+ }
+ requireNonNull(a);
+ requireNonNull(b);
+ requireNonNull(c);
+ checkIndex(offseta + (lsame("L", side) ? m : n) * lda - 1, a.length);
+ checkIndex(offsetb + n * ldb - 1, b.length);
+ checkIndex(offsetc + n * ldc - 1, c.length);
+ dsymmK(side, uplo, m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
+ }
+
+ protected abstract void dsymmK(String side, String uplo, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc);
+
+ public void ssymm(String side, String uplo, int m, int n, float alpha, float[] a, int lda, float[] b, int ldb, float beta, float[] c, int ldc) {
+ if (debug) System.err.println("ssymm");
+ ssymm(side, uplo, m, n, alpha, a, 0, lda, b, 0, ldb, beta, c, 0, ldc);
+ }
+
+ public void ssymm(String side, String uplo, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
+ if (debug) System.err.println("ssymm");
+ checkArgument("SSYMM", 1, lsame("L", side) || lsame("R", side));
+ checkArgument("SSYMM", 2, lsame("U", uplo) || lsame("L", uplo));
+ checkArgument("SSYMM", 3, m >= 0);
+ checkArgument("SSYMM", 4, n >= 0);
+ checkArgument("SSYMM", 7, lda >= Math.max(1, lsame("L", side) ? m : n));
+ checkArgument("SSYMM", 9, ldb >= Math.max(1, m));
+ checkArgument("SSYMM", 12, ldc >= Math.max(1, m));
+ if (m == 0 || n == 0 || (alpha == 0.0f && beta == 1.0f)) {
+ return;
+ }
+ requireNonNull(a);
+ requireNonNull(b);
+ requireNonNull(c);
+ checkIndex(offseta + (lsame("L", side) ? m : n) * lda - 1, a.length);
+ checkIndex(offsetb + n * ldb - 1, b.length);
+ checkIndex(offsetc + n * ldc - 1, c.length);
+ ssymmK(side, uplo, m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
+ }
+
+ protected abstract void ssymmK(String side, String uplo, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc);
+
+ public void dsymv(String uplo, int n, double alpha, double[] a, int lda, double[] x, int incx, double beta, double[] y, int incy) {
+ if (debug) System.err.println("dsymv");
+ dsymv(uplo, n, alpha, a, 0, lda, x, 0, incx, beta, y, 0, incy);
+ }
+
+ public void dsymv(String uplo, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
+ if (debug) System.err.println("dsymv");
+ checkArgument("DSYMV", 1, lsame("U", uplo) || lsame("L", uplo));
+ checkArgument("DSYMV", 2, n >= 0);
+ checkArgument("DSYMV", 5, lda >= Math.max(1, n));
+ checkArgument("DSYMV", 7, incx != 0);
+ checkArgument("DSYMV", 10, incy != 0);
+ if (n == 0) {
+ return;
+ }
+ requireNonNull(a);
+ requireNonNull(x);
+ requireNonNull(y);
+ checkIndex(offseta + n * lda - 1, a.length);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
+ dsymvK(uplo, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
+ }
+
+ protected abstract void dsymvK(String uplo, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy);
+
+ public void ssymv(String uplo, int n, float alpha, float[] a, int lda, float[] x, int incx, float beta, float[] y, int incy) {
+ if (debug) System.err.println("ssymv");
+ ssymv(uplo, n, alpha, a, 0, lda, x, 0, incx, beta, y, 0, incy);
+ }
+
+ public void ssymv(String uplo, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
+ if (debug) System.err.println("ssymv");
+ checkArgument("SSYMV", 1, lsame("U", uplo) || lsame("L", uplo));
+ checkArgument("SSYMV", 2, n >= 0);
+ checkArgument("SSYMV", 5, lda >= Math.max(1, n));
+ checkArgument("SSYMV", 7, incx != 0);
+ checkArgument("SSYMV", 10, incy != 0);
+ if (n == 0) {
+ return;
+ }
+ requireNonNull(a);
+ requireNonNull(x);
+ requireNonNull(y);
+ checkIndex(offseta + n * lda - 1, a.length);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
+ ssymvK(uplo, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
+ }
+
+ protected abstract void ssymvK(String uplo, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy);
+
+ public void dsyr(String uplo, int n, double alpha, double[] x, int incx, double[] a, int lda) {
+ if (debug) System.err.println("dsyr");
+ dsyr(uplo, n, alpha, x, 0, incx, a, 0, lda);
+ }
+
+ // a += alpha * x * x.t
+ public void dsyr(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] a, int offseta, int lda) {
+ if (debug) System.err.println("dsyr");
+ checkArgument("DSYR", 1, lsame("U", uplo) || lsame("L", uplo));
+ checkArgument("DSYR", 2, n >= 0);
+ checkArgument("DSYR", 5, incx != 0);
+ checkArgument("DSYR", 7, lda >= Math.max(1, n));
+ if (n == 0) {
+ return;
+ }
+ requireNonNull(x);
+ requireNonNull(a);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ checkIndex(offseta + n * lda - 1, a.length);
+ dsyrK(uplo, n, alpha, x, offsetx, incx, a, offseta, lda);
+ }
+
+ protected abstract void dsyrK(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] a, int offseta, int lda);
+
+ public void ssyr(String uplo, int n, float alpha, float[] x, int incx, float[] a, int lda) {
+ if (debug) System.err.println("ssyr");
+ ssyr(uplo, n, alpha, x, 0, incx, a, 0, lda);
+ }
+
+ public void ssyr(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] a, int offseta, int lda) {
+ if (debug) System.err.println("ssyr");
+ checkArgument("SSYR", 1, lsame("U", uplo) || lsame("L", uplo));
+ checkArgument("SSYR", 2, n >= 0);
+ checkArgument("SSYR", 5, incx != 0);
+ checkArgument("SSYR", 7, lda >= Math.max(1, n));
+ if (n == 0) {
+ return;
+ }
+ requireNonNull(x);
+ requireNonNull(a);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ checkIndex(offseta + n * lda - 1, a.length);
+ ssyrK(uplo, n, alpha, x, offsetx, incx, a, offseta, lda);
+ }
+
+ protected abstract void ssyrK(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] a, int offseta, int lda);
+
+ public void dsyr2(String uplo, int n, double alpha, double[] x, int incx, double[] y, int incy, double[] a, int lda) {
+ if (debug) System.err.println("dsyr2");
+ dsyr2(uplo, n, alpha, x, 0, incx, y, 0, incy, a, 0, lda);
+ }
+
+ public void dsyr2(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta, int lda) {
+ if (debug) System.err.println("dsyr2");
+ checkArgument("DSYR2", 1, lsame("U", uplo) || lsame("L", uplo));
+ checkArgument("DSYR2", 2, n >= 0);
+ checkArgument("DSYR2", 5, incx != 0);
+ checkArgument("DSYR2", 7, incy != 0);
+ checkArgument("DSYR2", 9, lda >= Math.max(1, n));
+ if (n == 0) {
+ return;
+ }
+ requireNonNull(x);
+ requireNonNull(y);
+ requireNonNull(a);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
+ checkIndex(offseta + n * lda - 1, a.length);
+ dsyr2K(uplo, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta, lda);
+ }
+
+ protected abstract void dsyr2K(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta, int lda);
+
+ public void ssyr2(String uplo, int n, float alpha, float[] x, int incx, float[] y, int incy, float[] a, int lda) {
+ if (debug) System.err.println("ssyr2");
+ ssyr2(uplo, n, alpha, x, 0, incx, y, 0, incy, a, 0, lda);
+ }
+
+ public void ssyr2(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta, int lda) {
+ if (debug) System.err.println("ssyr2");
+ checkArgument("SSYR2", 1, lsame("U", uplo) || lsame("L", uplo));
+ checkArgument("SSYR2", 2, n >= 0);
+ checkArgument("SSYR2", 5, incx != 0);
+ checkArgument("SSYR2", 7, incy != 0);
+ checkArgument("SSYR2", 9, lda >= Math.max(1, n));
+ if (n == 0) {
+ return;
+ }
+ requireNonNull(x);
+ requireNonNull(y);
+ requireNonNull(a);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
+ checkIndex(offseta + n * lda - 1, a.length);
+ ssyr2K(uplo, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta, lda);
+ }
+
+ protected abstract void ssyr2K(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta, int lda);
+
+ public void dsyr2k(String uplo, String trans, int n, int k, double alpha, double[] a, int lda, double[] b, int ldb, double beta, double[] c, int ldc) {
+ if (debug) System.err.println("dsyr2k");
+ dsyr2k(uplo, trans, n, k, alpha, a, 0, lda, b, 0, ldb, beta, c, 0, ldc);
+ }
+
+ public void dsyr2k(String uplo, String trans, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
+ if (debug) System.err.println("dsyr2k");
+ checkArgument("DSYR2K", 1, lsame("U", uplo) || lsame("L", uplo));
+ checkArgument("DSYR2K", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans));
+ checkArgument("DSYR2K", 3, n >= 0);
+ checkArgument("DSYR2K", 4, k >= 0);
+ checkArgument("DSYR2K", 7, lda >= Math.max(1, lsame("N", trans) ? n : k));
+ checkArgument("DSYR2K", 9, ldb >= Math.max(1, lsame("N", trans) ? n : k));
+ checkArgument("DSYR2K", 12, ldc >= Math.max(1, n));
+ if (n == 0 || ((alpha == 0 || k == 0) && beta == 1.0))
+ return;
+ requireNonNull(a);
+ requireNonNull(b);
+ requireNonNull(c);
+ checkIndex(offseta + (lsame("N", trans) ? k : n) * lda - 1, a.length);
+ checkIndex(offsetb + (lsame("N", trans) ? k : n) * ldb - 1, b.length);
+ checkIndex(offsetc + n * ldc - 1, c.length);
+ dsyr2kK(uplo, trans, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
+ }
+
+ protected abstract void dsyr2kK(String uplo, String trans, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc);
+
+ public void ssyr2k(String uplo, String trans, int n, int k, float alpha, float[] a, int lda, float[] b, int ldb, float beta, float[] c, int ldc) {
+ if (debug) System.err.println("ssyr2k");
+ ssyr2k(uplo, trans, n, k, alpha, a, 0, lda, b, 0, ldb, beta, c, 0, ldc);
+ }
+
+ public void ssyr2k(String uplo, String trans, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
+ if (debug) System.err.println("ssyr2k");
+ checkArgument("SSYR2K", 1, lsame("U", uplo) || lsame("L", uplo));
+ checkArgument("SSYR2K", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans));
+ checkArgument("SSYR2K", 3, n >= 0);
+ checkArgument("SSYR2K", 4, k >= 0);
+ checkArgument("SSYR2K", 7, lda >= Math.max(1, lsame("N", trans) ? n : k));
+ checkArgument("SSYR2K", 9, ldb >= Math.max(1, lsame("N", trans) ? n : k));
+ checkArgument("SSYR2K", 12, ldc >= Math.max(1, n));
+ if (n == 0 || ((alpha == 0 || k == 0) && beta == 1.0f))
+ return;
+ requireNonNull(a);
+ requireNonNull(b);
+ requireNonNull(c);
+ checkIndex(offseta + (lsame("N", trans) ? k : n) * lda - 1, a.length);
+ checkIndex(offsetb + (lsame("N", trans) ? k : n) * ldb - 1, b.length);
+ checkIndex(offsetc + n * ldc - 1, c.length);
+ ssyr2kK(uplo, trans, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
+ }
+
+ protected abstract void ssyr2kK(String uplo, String trans, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc);
+
+ public void dsyrk(String uplo, String trans, int n, int k, double alpha, double[] a, int lda, double beta, double[] c, int ldc) {
+ if (debug) System.err.println("dsyrk");
+ dsyrk(uplo, trans, n, k, alpha, a, 0, lda, beta, c, 0, ldc);
+ }
+
+ public void dsyrk(String uplo, String trans, int n, int k, double alpha, double[] a, int offseta, int lda, double beta, double[] c, int offsetc, int ldc) {
+ if (debug) System.err.println("dsyrk");
+ checkArgument("DSYRK", 1, lsame("U", uplo) || lsame("L", uplo));
+ checkArgument("DSYRK", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans));
+ checkArgument("DSYRK", 3, n >= 0);
+ checkArgument("DSYRK", 4, k >= 0);
+ checkArgument("DSYRK", 7, lda >= Math.max(1, lsame("N", trans) ? n : k));
+ checkArgument("DSYRK", 10, ldc >= Math.max(1, n));
+ if (n == 0 || ((alpha == 0 || k == 0) && beta == 1.0))
+ return;
+ requireNonNull(a);
+ requireNonNull(c);
+ checkIndex(offseta + (lsame("N", trans) ? k : n) * lda - 1, a.length);
+ checkIndex(offsetc + n * ldc - 1, c.length);
+ dsyrkK(uplo, trans, n, k, alpha, a, offseta, lda, beta, c, offsetc, ldc);
+ }
+
+ protected abstract void dsyrkK(String uplo, String trans, int n, int k, double alpha, double[] a, int offseta, int lda, double beta, double[] c, int offsetc, int ldc);
+
+ public void ssyrk(String uplo, String trans, int n, int k, float alpha, float[] a, int lda, float beta, float[] c, int ldc) {
+ if (debug) System.err.println("ssyrk");
+ ssyrk(uplo, trans, n, k, alpha, a, 0, lda, beta, c, 0, ldc);
+ }
+
+ public void ssyrk(String uplo, String trans, int n, int k, float alpha, float[] a, int offseta, int lda, float beta, float[] c, int offsetc, int ldc) {
+ if (debug) System.err.println("ssyrk");
+ checkArgument("SSYRK", 1, lsame("U", uplo) || lsame("L", uplo));
+ checkArgument("SSYRK", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans));
+ checkArgument("SSYRK", 3, n >= 0);
+ checkArgument("SSYRK", 4, k >= 0);
+ checkArgument("SSYRK", 7, lda >= Math.max(1, lsame("N", trans) ? n : k));
+ checkArgument("SSYRK", 10, ldc >= Math.max(1, n));
+ if (n == 0 || ((alpha == 0 || k == 0) && beta == 1.0f))
+ return;
+ requireNonNull(a);
+ requireNonNull(c);
+ checkIndex(offseta + (lsame("N", trans) ? k : n) * lda - 1, a.length);
+ checkIndex(offsetc + n * ldc - 1, c.length);
+ ssyrkK(uplo, trans, n, k, alpha, a, offseta, lda, beta, c, offsetc, ldc);
+ }
+
+ protected abstract void ssyrkK(String uplo, String trans, int n, int k, float alpha, float[] a, int offseta, int lda, float beta, float[] c, int offsetc, int ldc);
+
+ public void dtbmv(String uplo, String trans, String diag, int n, int k, double[] a, int lda, double[] x, int incx) {
+ if (debug) System.err.println("dtbmv");
+ dtbmv(uplo, trans, diag, n, k, a, 0, lda, x, 0, incx);
+ }
+
+ public void dtbmv(String uplo, String trans, String diag, int n, int k, double[] a, int offseta, int lda, double[] x, int offsetx, int incx) {
+ if (debug) System.err.println("dtbmv");
+ checkArgument("DTBMV", 1, lsame("U", uplo) || lsame("L", uplo));
+ checkArgument("DTBMV", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans));
+ checkArgument("DTBMV", 3, lsame("U", diag) || lsame("N", diag));
+ checkArgument("DTBMV", 4, n >= 0);
+ checkArgument("DTBMV", 5, k >= 0);
+ checkArgument("DTBMV", 7, lda >= Math.max(1, k));
+ checkArgument("DTBMV", 9, incx != 0);
+ if (n == 0) {
+ return;
+ }
+ requireNonNull(a);
+ requireNonNull(x);
+ checkIndex(offseta + n * lda - 1, a.length);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ dtbmvK(uplo, trans, diag, n, k, a, offseta, lda, x, offsetx, incx);
+ }
+
+ protected abstract void dtbmvK(String uplo, String trans, String diag, int n, int k, double[] a, int offseta, int lda, double[] x, int offsetx, int incx);
+
+ public void stbmv(String uplo, String trans, String diag, int n, int k, float[] a, int lda, float[] x, int incx) {
+ if (debug) System.err.println("stbmv");
+ stbmv(uplo, trans, diag, n, k, a, 0, lda, x, 0, incx);
+ }
+
+ public void stbmv(String uplo, String trans, String diag, int n, int k, float[] a, int offseta, int lda, float[] x, int offsetx, int incx) {
+ if (debug) System.err.println("stbmv");
+ checkArgument("STBMV", 1, lsame("U", uplo) || lsame("L", uplo));
+ checkArgument("STBMV", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans));
+ checkArgument("STBMV", 3, lsame("U", diag) || lsame("N", diag));
+ checkArgument("STBMV", 4, n >= 0);
+ checkArgument("STBMV", 5, k >= 0);
+ checkArgument("STBMV", 7, lda >= Math.max(1, k));
+ checkArgument("STBMV", 9, incx != 0);
+ if (n == 0) {
+ return;
+ }
+ requireNonNull(a);
+ requireNonNull(x);
+ checkIndex(offseta + n * lda - 1, a.length);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ stbmvK(uplo, trans, diag, n, k, a, offseta, lda, x, offsetx, incx);
+ }
+
+ protected abstract void stbmvK(String uplo, String trans, String diag, int n, int k, float[] a, int offseta, int lda, float[] x, int offsetx, int incx);
+
+ public void dtbsv(String uplo, String trans, String diag, int n, int k, double[] a, int lda, double[] x, int incx) {
+ if (debug) System.err.println("dtbsv");
+ dtbsv(uplo, trans, diag, n, k, a, 0, lda, x, 0, incx);
+ }
+
+ public void dtbsv(String uplo, String trans, String diag, int n, int k, double[] a, int offseta, int lda, double[] x, int offsetx, int incx) {
+ if (debug) System.err.println("dtbsv");
+ checkArgument("DTBSV", 1, lsame("U", uplo) || lsame("L", uplo));
+ checkArgument("DTBSV", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans));
+ checkArgument("DTBSV", 3, lsame("U", diag) || lsame("N", diag));
+ checkArgument("DTBSV", 4, n >= 0);
+ checkArgument("DTBSV", 5, k >= 0);
+ checkArgument("DTBSV", 7, lda >= Math.max(1, k));
+ checkArgument("DTBSV", 9, incx != 0);
+ if (n == 0) {
+ return;
+ }
+ requireNonNull(a);
+ requireNonNull(x);
+ checkIndex(offseta + n * lda - 1, a.length);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ dtbsvK(uplo, trans, diag, n, k, a, offseta, lda, x, offsetx, incx);
+ }
+
+ protected abstract void dtbsvK(String uplo, String trans, String diag, int n, int k, double[] a, int offseta, int lda, double[] x, int offsetx, int incx);
+
+ public void stbsv(String uplo, String trans, String diag, int n, int k, float[] a, int lda, float[] x, int incx) {
+ if (debug) System.err.println("stbsv");
+ stbsv(uplo, trans, diag, n, k, a, 0, lda, x, 0, incx);
+ }
+
+ public void stbsv(String uplo, String trans, String diag, int n, int k, float[] a, int offseta, int lda, float[] x, int offsetx, int incx) {
+ if (debug) System.err.println("stbsv");
+ checkArgument("STBSV", 1, lsame("U", uplo) || lsame("L", uplo));
+ checkArgument("STBSV", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans));
+ checkArgument("STBSV", 3, lsame("U", diag) || lsame("N", diag));
+ checkArgument("STBSV", 4, n >= 0);
+ checkArgument("STBSV", 5, k >= 0);
+ checkArgument("STBSV", 7, lda >= Math.max(1, k));
+ checkArgument("STBSV", 9, incx != 0);
+ if (n == 0) {
+ return;
+ }
+ requireNonNull(a);
+ requireNonNull(x);
+ checkIndex(offseta + n * lda - 1, a.length);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ stbsvK(uplo, trans, diag, n, k, a, offseta, lda, x, offsetx, incx);
+ }
+
+ protected abstract void stbsvK(String uplo, String trans, String diag, int n, int k, float[] a, int offseta, int lda, float[] x, int offsetx, int incx);
+
+ public void dtpmv(String uplo, String trans, String diag, int n, double[] a, double[] x, int incx) {
+ if (debug) System.err.println("dtpmv");
+ dtpmv(uplo, trans, diag, n, a, 0, x, 0, incx);
+ }
+
+ public void dtpmv(String uplo, String trans, String diag, int n, double[] a, int offseta, double[] x, int offsetx, int incx) {
+ if (debug) System.err.println("dtpmv");
+ checkArgument("DTPMV", 1, lsame("U", uplo) || lsame("L", uplo));
+ checkArgument("DTPMV", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans));
+ checkArgument("DTPMV", 3, lsame("U", diag) || lsame("N", diag));
+ checkArgument("DTPMV", 4, n >= 0);
+ checkArgument("DTPMV", 7, incx != 0);
+ if (n == 0) {
+ return;
+ }
+ requireNonNull(a);
+ requireNonNull(x);
+ checkIndex(offseta + n * (n + 1) / 2 - 1, a.length);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ dtpmvK(uplo, trans, diag, n, a, offseta, x, offsetx, incx);
+ }
+
+ protected abstract void dtpmvK(String uplo, String trans, String diag, int n, double[] a, int offseta, double[] x, int offsetx, int incx);
+
+ public void stpmv(String uplo, String trans, String diag, int n, float[] a, float[] x, int incx) {
+ if (debug) System.err.println("stpmv");
+ stpmv(uplo, trans, diag, n, a, 0, x, 0, incx);
+ }
+
+ public void stpmv(String uplo, String trans, String diag, int n, float[] a, int offseta, float[] x, int offsetx, int incx) {
+ if (debug) System.err.println("stpmv");
+ checkArgument("STPMV", 1, lsame("U", uplo) || lsame("L", uplo));
+ checkArgument("STPMV", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans));
+ checkArgument("STPMV", 3, lsame("U", diag) || lsame("N", diag));
+ checkArgument("STPMV", 4, n >= 0);
+ checkArgument("STPMV", 7, incx != 0);
+ if (n == 0) {
+ return;
+ }
+ requireNonNull(a);
+ requireNonNull(x);
+ checkIndex(offseta + n * (n + 1) / 2 - 1, a.length);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ stpmvK(uplo, trans, diag, n, a, offseta, x, offsetx, incx);
+ }
+
+ protected abstract void stpmvK(String uplo, String trans, String diag, int n, float[] a, int offseta, float[] x, int offsetx, int incx);
+
+ public void dtpsv(String uplo, String trans, String diag, int n, double[] a, double[] x, int incx) {
+ if (debug) System.err.println("dtpsv");
+ dtpsv(uplo, trans, diag, n, a, 0, x, 0, incx);
+ }
+
+ public void dtpsv(String uplo, String trans, String diag, int n, double[] a, int offseta, double[] x, int offsetx, int incx) {
+ if (debug) System.err.println("dtpsv");
+ checkArgument("DTPSV", 1, lsame("U", uplo) || lsame("L", uplo));
+ checkArgument("DTPSV", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans));
+ checkArgument("DTPSV", 3, lsame("U", diag) || lsame("N", diag));
+ checkArgument("DTPSV", 4, n >= 0);
+ checkArgument("DTPSV", 7, incx != 0);
+ if (n == 0) {
+ return;
+ }
+ requireNonNull(a);
+ requireNonNull(x);
+ checkIndex(offseta + n * (n + 1) / 2 - 1, a.length);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ dtpsvK(uplo, trans, diag, n, a, offseta, x, offsetx, incx);
+ }
+
+ protected abstract void dtpsvK(String uplo, String trans, String diag, int n, double[] a, int offseta, double[] x, int offsetx, int incx);
+
+ public void stpsv(String uplo, String trans, String diag, int n, float[] a, float[] x, int incx) {
+ if (debug) System.err.println("stpsv");
+ stpsv(uplo, trans, diag, n, a, 0, x, 0, incx);
+ }
+
+ public void stpsv(String uplo, String trans, String diag, int n, float[] a, int offseta, float[] x, int offsetx, int incx) {
+ if (debug) System.err.println("stpsv");
+ checkArgument("STPSV", 1, lsame("U", uplo) || lsame("L", uplo));
+ checkArgument("STPSV", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans));
+ checkArgument("STPSV", 3, lsame("U", diag) || lsame("N", diag));
+ checkArgument("STPSV", 4, n >= 0);
+ checkArgument("STPSV", 7, incx != 0);
+ if (n == 0) {
+ return;
+ }
+ requireNonNull(a);
+ requireNonNull(x);
+ checkIndex(offseta + n * (n + 1) / 2 - 1, a.length);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ stpsvK(uplo, trans, diag, n, a, offseta, x, offsetx, incx);
+ }
+
+ protected abstract void stpsvK(String uplo, String trans, String diag, int n, float[] a, int offseta, float[] x, int offsetx, int incx);
+
+ public void dtrmm(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int lda, double[] b, int ldb) {
+ if (debug) System.err.println("dtrmm");
+ dtrmm(side, uplo, transa, diag, m, n, alpha, a, 0, lda, b, 0, ldb);
+ }
+
+ public void dtrmm(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb) {
+ if (debug) System.err.println("dtrmm");
+ checkArgument("DTRMM", 1, lsame("L", side) || lsame("R", side));
+ checkArgument("DTRMM", 2, lsame("U", uplo) || lsame("L", uplo));
+ checkArgument("DTRMM", 3, lsame("N", transa) || lsame("T", transa) || lsame("C", transa));
+ checkArgument("DTRMM", 4, lsame("U", diag) || lsame("N", diag));
+ checkArgument("DTRMM", 5, m >= 0);
+ checkArgument("DTRMM", 6, n >= 0);
+ checkArgument("DTRMM", 9, lda >= Math.max(1, lsame("L", side) ? m : n));
+ checkArgument("DTRMM", 11, ldb >= Math.max(1, m));
+ if (n == 0) {
+ return;
+ }
+ requireNonNull(a);
+ requireNonNull(b);
+ checkIndex(offseta + (lsame("L", side) ? m : n) * lda - 1, a.length);
+ checkIndex(offsetb + n * ldb - 1, b.length);
+ dtrmmK(side, uplo, transa, diag, m, n, alpha, a, offseta, lda, b, offsetb, ldb);
+ }
+
+ protected abstract void dtrmmK(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb);
+
+ public void strmm(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int lda, float[] b, int ldb) {
+ if (debug) System.err.println("strmm");
+ strmm(side, uplo, transa, diag, m, n, alpha, a, 0, lda, b, 0, ldb);
+ }
+
+ public void strmm(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb) {
+ if (debug) System.err.println("strmm");
+ checkArgument("STRMM", 1, lsame("L", side) || lsame("R", side));
+ checkArgument("STRMM", 2, lsame("U", uplo) || lsame("L", uplo));
+ checkArgument("STRMM", 3, lsame("N", transa) || lsame("T", transa) || lsame("C", transa));
+ checkArgument("STRMM", 4, lsame("U", diag) || lsame("N", diag));
+ checkArgument("STRMM", 5, m >= 0);
+ checkArgument("STRMM", 6, n >= 0);
+ checkArgument("STRMM", 9, lda >= Math.max(1, lsame("L", side) ? m : n));
+ checkArgument("STRMM", 11, ldb >= Math.max(1, m));
+ if (n == 0) {
+ return;
+ }
+ requireNonNull(a);
+ requireNonNull(b);
+ checkIndex(offseta + (lsame("L", side) ? m : n) * lda - 1, a.length);
+ checkIndex(offsetb + n * ldb - 1, b.length);
+ strmmK(side, uplo, transa, diag, m, n, alpha, a, offseta, lda, b, offsetb, ldb);
+ }
+
+ protected abstract void strmmK(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb);
+
+ public void dtrmv(String uplo, String trans, String diag, int n, double[] a, int lda, double[] x, int incx) {
+ if (debug) System.err.println("dtrmv");
+ dtrmv(uplo, trans, diag, n, a, 0, lda, x, 0, incx);
+ }
+
+ public void dtrmv(String uplo, String trans, String diag, int n, double[] a, int offseta, int lda, double[] x, int offsetx, int incx) {
+ if (debug) System.err.println("dtrmv");
+ checkArgument("DTRMV", 1, lsame("U", uplo) || lsame("L", uplo));
+ checkArgument("DTRMV", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans));
+ checkArgument("DTRMV", 3, lsame("U", diag) || lsame("N", diag));
+ checkArgument("DTRMV", 4, n >= 0);
+ checkArgument("DTRMV", 6, lda >= Math.max(1, n));
+ checkArgument("DTRMV", 8, incx != 0);
+ if (n == 0) {
+ return;
+ }
+ requireNonNull(a);
+ requireNonNull(x);
+ checkIndex(offseta + n * lda - 1, a.length);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ dtrmvK(uplo, trans, diag, n, a, offseta, lda, x, offsetx, incx);
+ }
+
+ protected abstract void dtrmvK(String uplo, String trans, String diag, int n, double[] a, int offseta, int lda, double[] x, int offsetx, int incx);
+
+ public void strmv(String uplo, String trans, String diag, int n, float[] a, int lda, float[] x, int incx) {
+ if (debug) System.err.println("strmv");
+ strmv(uplo, trans, diag, n, a, 0, lda, x, 0, incx);
+ }
+
+ public void strmv(String uplo, String trans, String diag, int n, float[] a, int offseta, int lda, float[] x, int offsetx, int incx) {
+ if (debug) System.err.println("strmv");
+ checkArgument("STRMV", 1, lsame("U", uplo) || lsame("L", uplo));
+ checkArgument("STRMV", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans));
+ checkArgument("STRMV", 3, lsame("U", diag) || lsame("N", diag));
+ checkArgument("STRMV", 4, n >= 0);
+ checkArgument("STRMV", 6, lda >= Math.max(1, n));
+ checkArgument("STRMV", 8, incx != 0);
+ if (n == 0) {
+ return;
+ }
+ requireNonNull(a);
+ requireNonNull(x);
+ checkIndex(offseta + n * lda - 1, a.length);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ strmvK(uplo, trans, diag, n, a, offseta, lda, x, offsetx, incx);
+ }
+
+ protected abstract void strmvK(String uplo, String trans, String diag, int n, float[] a, int offseta, int lda, float[] x, int offsetx, int incx);
+
+ public void dtrsm(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int lda, double[] b, int ldb) {
+ if (debug) System.err.println("dtrsm");
+ dtrsm(side, uplo, transa, diag, m, n, alpha, a, 0, lda, b, 0, ldb);
+ }
+
+ public void dtrsm(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb) {
+ if (debug) System.err.println("dtrsm");
+ checkArgument("DTRSM", 1, lsame("L", side) || lsame("R", side));
+ checkArgument("DTRSM", 2, lsame("U", uplo) || lsame("L", uplo));
+ checkArgument("DTRSM", 3, lsame("N", transa) || lsame("T", transa) || lsame("C", transa));
+ checkArgument("DTRSM", 4, lsame("U", diag) || lsame("N", diag));
+ checkArgument("DTRSM", 5, m >= 0);
+ checkArgument("DTRSM", 6, n >= 0);
+ checkArgument("DTRSM", 9, lda >= Math.max(1, lsame("L", side) ? m : n));
+ checkArgument("DTRSM", 11, ldb >= Math.max(1, m));
+ if (n == 0) {
+ return;
+ }
+ requireNonNull(a);
+ requireNonNull(b);
+ checkIndex(offseta + (lsame("L", side) ? m : n) * lda - 1, a.length);
+ checkIndex(offsetb + n * ldb - 1, b.length);
+ dtrsmK(side, uplo, transa, diag, m, n, alpha, a, offseta, lda, b, offsetb, ldb);
+ }
+
+ protected abstract void dtrsmK(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb);
+
+ public void strsm(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int lda, float[] b, int ldb) {
+ if (debug) System.err.println("strsm");
+ strsm(side, uplo, transa, diag, m, n, alpha, a, 0, lda, b, 0, ldb);
+ }
+
+ public void strsm(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb) {
+ if (debug) System.err.println("strsm");
+ checkArgument("STRSM", 1, lsame("L", side) || lsame("R", side));
+ checkArgument("STRSM", 2, lsame("U", uplo) || lsame("L", uplo));
+ checkArgument("STRSM", 3, lsame("N", transa) || lsame("T", transa) || lsame("C", transa));
+ checkArgument("STRSM", 4, lsame("U", diag) || lsame("N", diag));
+ checkArgument("STRSM", 5, m >= 0);
+ checkArgument("STRSM", 6, n >= 0);
+ checkArgument("STRSM", 9, lda >= Math.max(1, lsame("L", side) ? m : n));
+ checkArgument("STRSM", 11, ldb >= Math.max(1, m));
+ if (n == 0) {
+ return;
+ }
+ requireNonNull(a);
+ requireNonNull(b);
+ checkIndex(offseta + (lsame("L", side) ? m : n) * lda - 1, a.length);
+ checkIndex(offsetb + n * ldb - 1, b.length);
+ strsmK(side, uplo, transa, diag, m, n, alpha, a, offseta, lda, b, offsetb, ldb);
+ }
+
+ protected abstract void strsmK(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb);
+
+ public void dtrsv(String uplo, String trans, String diag, int n, double[] a, int lda, double[] x, int incx) {
+ if (debug) System.err.println("dtrsv");
+ dtrsv(uplo, trans, diag, n, a, 0, lda, x, 0, incx);
+ }
+
+ public void dtrsv(String uplo, String trans, String diag, int n, double[] a, int offseta, int lda, double[] x, int offsetx, int incx) {
+ if (debug) System.err.println("dtrsv");
+ checkArgument("DTRSV", 1, lsame("U", uplo) || lsame("L", uplo));
+ checkArgument("DTRSV", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans));
+ checkArgument("DTRSV", 3, lsame("U", diag) || lsame("N", diag));
+ checkArgument("DTRSV", 4, n >= 0);
+ checkArgument("DTRSV", 6, lda >= Math.max(1, n));
+ checkArgument("DTRSV", 8, incx != 0);
+ if (n == 0) {
+ return;
+ }
+ requireNonNull(a);
+ requireNonNull(x);
+ checkIndex(offseta + n * lda - 1, a.length);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ dtrsvK(uplo, trans, diag, n, a, offseta, lda, x, offsetx, incx);
+ }
+
+ protected abstract void dtrsvK(String uplo, String trans, String diag, int n, double[] a, int offseta, int lda, double[] x, int offsetx, int incx);
+
+ public void strsv(String uplo, String trans, String diag, int n, float[] a, int lda, float[] x, int incx) {
+ if (debug) System.err.println("strsv");
+ strsv(uplo, trans, diag, n, a, 0, lda, x, 0, incx);
+ }
+
+ public void strsv(String uplo, String trans, String diag, int n, float[] a, int offseta, int lda, float[] x, int offsetx, int incx) {
+ if (debug) System.err.println("strsv");
+ checkArgument("STRSV", 1, lsame("U", uplo) || lsame("L", uplo));
+ checkArgument("STRSV", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans));
+ checkArgument("STRSV", 3, lsame("U", diag) || lsame("N", diag));
+ checkArgument("STRSV", 4, n >= 0);
+ checkArgument("STRSV", 6, lda >= Math.max(1, n));
+ checkArgument("STRSV", 8, incx != 0);
+ if (n == 0) {
+ return;
+ }
+ requireNonNull(a);
+ requireNonNull(x);
+ checkIndex(offseta + n * lda - 1, a.length);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ strsvK(uplo, trans, diag, n, a, offseta, lda, x, offsetx, incx);
+ }
+
+ protected abstract void strsvK(String uplo, String trans, String diag, int n, float[] a, int offseta, int lda, float[] x, int offsetx, int incx);
+
+ public int idamax(int n, double[] x, int incx) {
+ if (debug) System.err.println("idamax");
+ return idamax(n, x, 0, incx);
+ }
+
+ public int idamax(int n, double[] x, int offsetx, int incx) {
+ if (debug) System.err.println("idamax");
+ if (n <= 0) {
+ return -1;
+ }
+ if (incx <= 0) {
+ return -1;
+ }
+ if (n == 1) {
+ return 0;
+ }
+ requireNonNull(x);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ // Fortran arrays use 1-based index
+ return idamaxK(n, x, offsetx, incx) - 1;
+ }
+
+ protected abstract int idamaxK(int n, double[] x, int offsetx, int incx);
+
+ public int isamax(int n, float[] x, int incx) {
+ if (debug) System.err.println("isamax");
+ return isamax(n, x, 0, incx);
+ }
+
+ public int isamax(int n, float[] x, int offsetx, int incx) {
+ if (debug) System.err.println("isamax");
+ if (n <= 0) {
+ return -1;
+ }
+ if (incx <= 0) {
+ return -1;
+ }
+ if (n == 1) {
+ return 0;
+ }
+ requireNonNull(x);
+ checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
+ // Fortran arrays use 1-based index
+ return isamaxK(n, x, offsetx, incx) - 1;
+ }
+
+ protected abstract int isamaxK(int n, float[] x, int offsetx, int incx);
+
+ public boolean lsame(String ca, String cb) {
+ if (debug) System.err.println("lsame");
+ return ca != null && ca.regionMatches(true, 0, cb, 0, ca.length());
+ }
+}
diff --git a/ml-core/src/main/java/dev/ludovic/netlib/blas/F2jBLAS.java b/ml-core/src/main/java/dev/ludovic/netlib/blas/F2jBLAS.java
new file mode 100644
index 0000000..d0a8dab
--- /dev/null
+++ b/ml-core/src/main/java/dev/ludovic/netlib/blas/F2jBLAS.java
@@ -0,0 +1,241 @@
+/*
+ * Copyright 2020, 2021, Ludovic Henry
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ *
+ * Please contact git@ludovic.dev or visit ludovic.dev if you need additional
+ * information or have any questions.
+ */
+
+package dev.ludovic.netlib.blas;
+
+import dev.ludovic.netlib.BLAS;
+
+public final class F2jBLAS extends AbstractBLAS implements dev.ludovic.netlib.JavaBLAS {
+
+ private static final F2jBLAS instance = new F2jBLAS();
+
+ protected F2jBLAS() {}
+
+ public static dev.ludovic.netlib.JavaBLAS getInstance() {
+ return instance;
+ }
+
+ protected double dasumK(int n, double[] x, int offsetx, int incx) {
+ return org.netlib.blas.Dasum.dasum(n, x, offsetx, incx);
+ }
+ protected float sasumK(int n, float[] x, int offsetx, int incx) {
+ return org.netlib.blas.Sasum.sasum(n, x, offsetx, incx);
+ }
+ protected void daxpyK(int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy) {
+ org.netlib.blas.Daxpy.daxpy(n, alpha, x, offsetx, incx, y, offsety, incy);
+ }
+ protected void saxpyK(int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) {
+ org.netlib.blas.Saxpy.saxpy(n, alpha, x, offsetx, incx, y, offsety, incy);
+ }
+ protected void dcopyK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy) {
+ org.netlib.blas.Dcopy.dcopy(n, x, offsetx, incx, y, offsety, incy);
+ }
+ protected void scopyK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) {
+ org.netlib.blas.Scopy.scopy(n, x, offsetx, incx, y, offsety, incy);
+ }
+ protected double ddotK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy) {
+ return org.netlib.blas.Ddot.ddot(n, x, offsetx, incx, y, offsety, incy);
+ }
+ protected float sdotK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) {
+ return org.netlib.blas.Sdot.sdot(n, x, offsetx, incx, y, offsety, incy);
+ }
+ protected float sdsdotK(int n, float sb, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) {
+ return org.netlib.blas.Sdsdot.sdsdot(n, sb, x, offsetx, incx, y, offsety, incy);
+ }
+ protected void dgbmvK(String trans, int m, int n, int kl, int ku, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
+ org.netlib.blas.Dgbmv.dgbmv(trans, m, n, kl, ku, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
+ }
+ protected void sgbmvK(String trans, int m, int n, int kl, int ku, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
+ org.netlib.blas.Sgbmv.sgbmv(trans, m, n, kl, ku, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
+ }
+ protected void dgemmK(String transa, String transb, int m, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
+ org.netlib.blas.Dgemm.dgemm(transa, transb, m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
+ }
+ protected void sgemmK(String transa, String transb, int m, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
+ org.netlib.blas.Sgemm.sgemm(transa, transb, m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
+ }
+ protected void dgemvK(String trans, int m, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
+ org.netlib.blas.Dgemv.dgemv(trans, m, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
+ }
+ protected void sgemvK(String trans, int m, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
+ org.netlib.blas.Sgemv.sgemv(trans, m, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
+ }
+ protected void dgerK(int m, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta, int lda) {
+ org.netlib.blas.Dger.dger(m, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta, lda);
+ }
+ protected void sgerK(int m, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta, int lda) {
+ org.netlib.blas.Sger.sger(m, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta, lda);
+ }
+ protected double dnrm2K(int n, double[] x, int offsetx, int incx) {
+ return org.netlib.blas.Dnrm2.dnrm2(n, x, offsetx, incx);
+ }
+ protected float snrm2K(int n, float[] x, int offsetx, int incx) {
+ return org.netlib.blas.Snrm2.snrm2(n, x, offsetx, incx);
+ }
+ protected void drotK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double c, double s) {
+ org.netlib.blas.Drot.drot(n, x, offsetx, incx, y, offsety, incy, c, s);
+ }
+ protected void srotK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float c, float s) {
+ org.netlib.blas.Srot.srot(n, x, offsetx, incx, y, offsety, incy, c, s);
+ }
+ protected void drotmK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] param, int offsetparam) {
+ org.netlib.blas.Drotm.drotm(n, x, offsetx, incx, y, offsety, incy, param, offsetparam);
+ }
+ protected void srotmK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] param, int offsetparam) {
+ org.netlib.blas.Srotm.srotm(n, x, offsetx, incx, y, offsety, incy, param, offsetparam);
+ }
+ protected void drotmgK(org.netlib.util.doubleW dd1, org.netlib.util.doubleW dd2, org.netlib.util.doubleW dx1, double dy1, double[] param, int offsetparam) {
+ org.netlib.blas.Drotmg.drotmg(dd1, dd2, dx1, dy1, param, offsetparam);
+ }
+ protected void srotmgK(org.netlib.util.floatW sd1, org.netlib.util.floatW sd2, org.netlib.util.floatW sx1, float sy1, float[] param, int offsetparam) {
+ org.netlib.blas.Srotmg.srotmg(sd1, sd2, sx1, sy1, param, offsetparam);
+ }
+ protected void dsbmvK(String uplo, int n, int k, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
+ org.netlib.blas.Dsbmv.dsbmv(uplo, n, k, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
+ }
+ protected void ssbmvK(String uplo, int n, int k, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
+ org.netlib.blas.Ssbmv.ssbmv(uplo, n, k, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
+ }
+ protected void dscalK(int n, double alpha, double[] x, int offsetx, int incx) {
+ org.netlib.blas.Dscal.dscal(n, alpha, x, offsetx, incx);
+ }
+ protected void sscalK(int n, float alpha, float[] x, int offsetx, int incx) {
+ org.netlib.blas.Sscal.sscal(n, alpha, x, offsetx, incx);
+ }
+ protected void dspmvK(String uplo, int n, double alpha, double[] a, int offseta, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
+ org.netlib.blas.Dspmv.dspmv(uplo, n, alpha, a, offseta, x, offsetx, incx, beta, y, offsety, incy);
+ }
+ protected void sspmvK(String uplo, int n, float alpha, float[] a, int offseta, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
+ org.netlib.blas.Sspmv.sspmv(uplo, n, alpha, a, offseta, x, offsetx, incx, beta, y, offsety, incy);
+ }
+ protected void dsprK(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] a, int offseta) {
+ org.netlib.blas.Dspr.dspr(uplo, n, alpha, x, offsetx, incx, a, offseta);
+ }
+ protected void ssprK(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] a, int offseta) {
+ org.netlib.blas.Sspr.sspr(uplo, n, alpha, x, offsetx, incx, a, offseta);
+ }
+ protected void dspr2K(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta) {
+ org.netlib.blas.Dspr2.dspr2(uplo, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta);
+ }
+ protected void sspr2K(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta) {
+ org.netlib.blas.Sspr2.sspr2(uplo, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta);
+ }
+ protected void dswapK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy) {
+ org.netlib.blas.Dswap.dswap(n, x, offsetx, incx, y, offsety, incy);
+ }
+ protected void sswapK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) {
+ org.netlib.blas.Sswap.sswap(n, x, offsetx, incx, y, offsety, incy);
+ }
+ protected void dsymmK(String side, String uplo, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
+ org.netlib.blas.Dsymm.dsymm(side, uplo, m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
+ }
+ protected void ssymmK(String side, String uplo, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
+ org.netlib.blas.Ssymm.ssymm(side, uplo, m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
+ }
+ protected void dsymvK(String uplo, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
+ org.netlib.blas.Dsymv.dsymv(uplo, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
+ }
+ protected void ssymvK(String uplo, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
+ org.netlib.blas.Ssymv.ssymv(uplo, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
+ }
+ protected void dsyrK(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] a, int offseta, int lda) {
+ org.netlib.blas.Dsyr.dsyr(uplo, n, alpha, x, offsetx, incx, a, offseta, lda);
+ }
+ protected void ssyrK(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] a, int offseta, int lda) {
+ org.netlib.blas.Ssyr.ssyr(uplo, n, alpha, x, offsetx, incx, a, offseta, lda);
+ }
+ protected void dsyr2K(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta, int lda) {
+ org.netlib.blas.Dsyr2.dsyr2(uplo, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta, lda);
+ }
+ protected void ssyr2K(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta, int lda) {
+ org.netlib.blas.Ssyr2.ssyr2(uplo, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta, lda);
+ }
+ protected void dsyr2kK(String uplo, String trans, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
+ org.netlib.blas.Dsyr2k.dsyr2k(uplo, trans, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
+ }
+ protected void ssyr2kK(String uplo, String trans, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
+ org.netlib.blas.Ssyr2k.ssyr2k(uplo, trans, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
+ }
+ protected void dsyrkK(String uplo, String trans, int n, int k, double alpha, double[] a, int offseta, int lda, double beta, double[] c, int offsetc, int ldc) {
+ org.netlib.blas.Dsyrk.dsyrk(uplo, trans, n, k, alpha, a, offseta, lda, beta, c, offsetc, ldc);
+ }
+ protected void ssyrkK(String uplo, String trans, int n, int k, float alpha, float[] a, int offseta, int lda, float beta, float[] c, int offsetc, int ldc) {
+ org.netlib.blas.Ssyrk.ssyrk(uplo, trans, n, k, alpha, a, offseta, lda, beta, c, offsetc, ldc);
+ }
+ protected void dtbmvK(String uplo, String trans, String diag, int n, int k, double[] a, int offseta, int lda, double[] x, int offsetx, int incx) {
+ org.netlib.blas.Dtbmv.dtbmv(uplo, trans, diag, n, k, a, offseta, lda, x, offsetx, incx);
+ }
+ protected void stbmvK(String uplo, String trans, String diag, int n, int k, float[] a, int offseta, int lda, float[] x, int offsetx, int incx) {
+ org.netlib.blas.Stbmv.stbmv(uplo, trans, diag, n, k, a, offseta, lda, x, offsetx, incx);
+ }
+ protected void dtbsvK(String uplo, String trans, String diag, int n, int k, double[] a, int offseta, int lda, double[] x, int offsetx, int incx) {
+ org.netlib.blas.Dtbsv.dtbsv(uplo, trans, diag, n, k, a, offseta, lda, x, offsetx, incx);
+ }
+ protected void stbsvK(String uplo, String trans, String diag, int n, int k, float[] a, int offseta, int lda, float[] x, int offsetx, int incx) {
+ org.netlib.blas.Stbsv.stbsv(uplo, trans, diag, n, k, a, offseta, lda, x, offsetx, incx);
+ }
+ protected void dtpmvK(String uplo, String trans, String diag, int n, double[] a, int offseta, double[] x, int offsetx, int incx) {
+ org.netlib.blas.Dtpmv.dtpmv(uplo, trans, diag, n, a, offseta, x, offsetx, incx);
+ }
+ protected void stpmvK(String uplo, String trans, String diag, int n, float[] a, int offseta, float[] x, int offsetx, int incx) {
+ org.netlib.blas.Stpmv.stpmv(uplo, trans, diag, n, a, offseta, x, offsetx, incx);
+ }
+ protected void dtpsvK(String uplo, String trans, String diag, int n, double[] a, int offseta, double[] x, int offsetx, int incx) {
+ org.netlib.blas.Dtpsv.dtpsv(uplo, trans, diag, n, a, offseta, x, offsetx, incx);
+ }
+ protected void stpsvK(String uplo, String trans, String diag, int n, float[] a, int offseta, float[] x, int offsetx, int incx) {
+ org.netlib.blas.Stpsv.stpsv(uplo, trans, diag, n, a, offseta, x, offsetx, incx);
+ }
+ protected void dtrmmK(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb) {
+ org.netlib.blas.Dtrmm.dtrmm(side, uplo, transa, diag, m, n, alpha, a, offseta, lda, b, offsetb, ldb);
+ }
+ protected void strmmK(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb) {
+ org.netlib.blas.Strmm.strmm(side, uplo, transa, diag, m, n, alpha, a, offseta, lda, b, offsetb, ldb);
+ }
+ protected void dtrmvK(String uplo, String trans, String diag, int n, double[] a, int offseta, int lda, double[] x, int offsetx, int incx) {
+ org.netlib.blas.Dtrmv.dtrmv(uplo, trans, diag, n, a, offseta, lda, x, offsetx, incx);
+ }
+ protected void strmvK(String uplo, String trans, String diag, int n, float[] a, int offseta, int lda, float[] x, int offsetx, int incx) {
+ org.netlib.blas.Strmv.strmv(uplo, trans, diag, n, a, offseta, lda, x, offsetx, incx);
+ }
+ protected void dtrsmK(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb) {
+ org.netlib.blas.Dtrsm.dtrsm(side, uplo, transa, diag, m, n, alpha, a, offseta, lda, b, offsetb, ldb);
+ }
+ protected void strsmK(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb) {
+ org.netlib.blas.Strsm.strsm(side, uplo, transa, diag, m, n, alpha, a, offseta, lda, b, offsetb, ldb);
+ }
+ protected void dtrsvK(String uplo, String trans, String diag, int n, double[] a, int offseta, int lda, double[] x, int offsetx, int incx) {
+ org.netlib.blas.Dtrsv.dtrsv(uplo, trans, diag, n, a, offseta, lda, x, offsetx, incx);
+ }
+ protected void strsvK(String uplo, String trans, String diag, int n, float[] a, int offseta, int lda, float[] x, int offsetx, int incx) {
+ org.netlib.blas.Strsv.strsv(uplo, trans, diag, n, a, offseta, lda, x, offsetx, incx);
+ }
+ protected int idamaxK(int n, double[] x, int offsetx, int incx) {
+ return org.netlib.blas.Idamax.idamax(n, x, offsetx, incx);
+ }
+ protected int isamaxK(int n, float[] x, int offsetx, int incx) {
+ return org.netlib.blas.Isamax.isamax(n, x, offsetx, incx);
+ }
+}
diff --git a/ml-core/src/main/java/dev/ludovic/netlib/blas/JNIBLAS.java b/ml-core/src/main/java/dev/ludovic/netlib/blas/JNIBLAS.java
new file mode 100644
index 0000000..3437885
--- /dev/null
+++ b/ml-core/src/main/java/dev/ludovic/netlib/blas/JNIBLAS.java
@@ -0,0 +1,201 @@
+/*
+ * Copyright 2020, 2021, Ludovic Henry
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ *
+ * Please contact git@ludovic.dev or visit ludovic.dev if you need additional
+ * information or have any questions.
+ */
+
+package dev.ludovic.netlib.blas;
+
+import java.io.InputStream;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.StandardCopyOption;
+import java.nio.file.attribute.PosixFilePermissions;
+
+public final class JNIBLAS extends AbstractBLAS implements dev.ludovic.netlib.NativeBLAS {
+
+ private static final JNIBLAS instance = new JNIBLAS();
+
+ protected JNIBLAS() {
+ String osName = System.getProperty("os.name");
+ if (osName == null || osName.isEmpty()) {
+ throw new RuntimeException("Unable to load native implementation");
+ }
+ String osArch = System.getProperty("os.arch");
+ if (osArch == null || osArch.isEmpty()) {
+ throw new RuntimeException("Unable to load native implementation");
+ }
+
+ Path temp;
+ try (InputStream resource = this.getClass().getClassLoader().getResourceAsStream(
+ String.format("resources/native/%s-%s/libnetlibblasjni.so", osName, osArch))) {
+ assert resource != null;
+ Files.copy(resource, temp = Files.createTempFile("libnetlibblasjni.so", "",
+ PosixFilePermissions.asFileAttribute(PosixFilePermissions.fromString("rwxr-x---"))),
+ StandardCopyOption.REPLACE_EXISTING);
+ temp.toFile().deleteOnExit();
+ } catch (IOException e) {
+ throw new RuntimeException("Unable to load native implementation", e);
+ }
+
+ System.load(temp.toString());
+ }
+
+ public static dev.ludovic.netlib.NativeBLAS getInstance() {
+ return instance;
+ }
+
+ protected native double dasumK(int n, double[] x, int offsetx, int incx);
+
+ protected native float sasumK(int n, float[] x, int offsetx, int incx);
+
+ protected native void daxpyK(int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy);
+
+ protected native void saxpyK(int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy);
+
+ protected native void dcopyK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy);
+
+ protected native void scopyK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy);
+
+ protected native double ddotK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy);
+
+ protected native float sdotK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy);
+
+ protected native float sdsdotK(int n, float sb, float[] sx, int offsetsx, int incsx, float[] sy, int offsetsy, int incsy);
+
+ protected native void dgbmvK(String trans, int m, int n, int kl, int ku, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy);
+
+ protected native void sgbmvK(String trans, int m, int n, int kl, int ku, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy);
+
+ protected native void dgemmK(String transa, String transb, int m, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc);
+
+ protected native void sgemmK(String transa, String transb, int m, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc);
+
+ protected native void dgemvK(String trans, int m, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy);
+
+ protected native void sgemvK(String trans, int m, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy);
+
+ protected native void dgerK(int m, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta, int lda);
+
+ protected native void sgerK(int m, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta, int lda);
+
+ protected native double dnrm2K(int n, double[] x, int offsetx, int incx);
+
+ protected native float snrm2K(int n, float[] x, int offsetx, int incx);
+
+ protected native void drotK(int n, double[] dx, int offsetdx, int incx, double[] dy, int offsetdy, int incy, double c, double s);
+
+ protected native void srotK(int n, float[] sx, int offsetsx, int incx, float[] sy, int offsetsy, int incy, float c, float s);
+
+ protected native void drotmK(int n, double[] dx, int offsetdx, int incx, double[] dy, int offsetdy, int incy, double[] dparam, int offsetdparam);
+
+ protected native void srotmK(int n, float[] sx, int offsetsx, int incx, float[] sy, int offsetsy, int incy, float[] sparam, int offsetsparam);
+
+ protected native void drotmgK(org.netlib.util.doubleW dd1, org.netlib.util.doubleW dd2, org.netlib.util.doubleW dx1, double dy1, double[] dparam, int offsetdparam);
+
+ protected native void srotmgK(org.netlib.util.floatW sd1, org.netlib.util.floatW sd2, org.netlib.util.floatW sx1, float sy1, float[] sparam, int offsetsparam);
+
+ protected native void dsbmvK(String uplo, int n, int k, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy);
+
+ protected native void ssbmvK(String uplo, int n, int k, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy);
+
+ protected native void dscalK(int n, double alpha, double[] x, int offsetx, int incx);
+
+ protected native void sscalK(int n, float alpha, float[] x, int offsetx, int incx);
+
+ protected native void dspmvK(String uplo, int n, double alpha, double[] a, int offseta, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy);
+
+ protected native void sspmvK(String uplo, int n, float alpha, float[] a, int offseta, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy);
+
+ protected native void dsprK(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] a, int offseta);
+
+ protected native void ssprK(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] a, int offseta);
+
+ protected native void dspr2K(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta);
+
+ protected native void sspr2K(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta);
+
+ protected native void dswapK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy);
+
+ protected native void sswapK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy);
+
+ protected native void dsymmK(String side, String uplo, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc);
+
+ protected native void ssymmK(String side, String uplo, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc);
+
+ protected native void dsymvK(String uplo, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy);
+
+ protected native void ssymvK(String uplo, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy);
+
+ protected native void dsyrK(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] a, int offseta, int lda);
+
+ protected native void ssyrK(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] a, int offseta, int lda);
+
+ protected native void dsyr2K(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta, int lda);
+
+ protected native void ssyr2K(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta, int lda);
+
+ protected native void dsyr2kK(String uplo, String trans, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc);
+
+ protected native void ssyr2kK(String uplo, String trans, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc);
+
+ protected native void dsyrkK(String uplo, String trans, int n, int k, double alpha, double[] a, int offseta, int lda, double beta, double[] c, int offsetc, int ldc);
+
+ protected native void ssyrkK(String uplo, String trans, int n, int k, float alpha, float[] a, int offseta, int lda, float beta, float[] c, int offsetc, int ldc);
+
+ protected native void dtbmvK(String uplo, String trans, String diag, int n, int k, double[] a, int offseta, int lda, double[] x, int offsetx, int incx);
+
+ protected native void stbmvK(String uplo, String trans, String diag, int n, int k, float[] a, int offseta, int lda, float[] x, int offsetx, int incx);
+
+ protected native void dtbsvK(String uplo, String trans, String diag, int n, int k, double[] a, int offseta, int lda, double[] x, int offsetx, int incx);
+
+ protected native void stbsvK(String uplo, String trans, String diag, int n, int k, float[] a, int offseta, int lda, float[] x, int offsetx, int incx);
+
+ protected native void dtpmvK(String uplo, String trans, String diag, int n, double[] a, int offseta, double[] x, int offsetx, int incx);
+
+ protected native void stpmvK(String uplo, String trans, String diag, int n, float[] a, int offseta, float[] x, int offsetx, int incx);
+
+ protected native void dtpsvK(String uplo, String trans, String diag, int n, double[] a, int offseta, double[] x, int offsetx, int incx);
+
+ protected native void stpsvK(String uplo, String trans, String diag, int n, float[] a, int offseta, float[] x, int offsetx, int incx);
+
+ protected native void dtrmmK(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb);
+
+ protected native void strmmK(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb);
+
+ protected native void dtrmvK(String uplo, String trans, String diag, int n, double[] a, int offseta, int lda, double[] x, int offsetx, int incx);
+
+ protected native void strmvK(String uplo, String trans, String diag, int n, float[] a, int offseta, int lda, float[] x, int offsetx, int incx);
+
+ protected native void dtrsmK(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb);
+
+ protected native void strsmK(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb);
+
+ protected native void dtrsvK(String uplo, String trans, String diag, int n, double[] a, int offseta, int lda, double[] x, int offsetx, int incx);
+
+ protected native void strsvK(String uplo, String trans, String diag, int n, float[] a, int offseta, int lda, float[] x, int offsetx, int incx);
+
+ protected native int idamaxK(int n, double[] dx, int offsetdx, int incdx);
+
+ protected native int isamaxK(int n, float[] sx, int offsetsx, int incx);
+}
diff --git a/ml-core/src/main/java/dev/ludovic/netlib/blas/Java8BLAS.java b/ml-core/src/main/java/dev/ludovic/netlib/blas/Java8BLAS.java
new file mode 100644
index 0000000..443a632
--- /dev/null
+++ b/ml-core/src/main/java/dev/ludovic/netlib/blas/Java8BLAS.java
@@ -0,0 +1,5157 @@
+/*
+ * Copyright 2020, 2021, Ludovic Henry
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ *
+ * Please contact git@ludovic.dev or visit ludovic.dev if you need additional
+ * information or have any questions.
+ */
+
+package dev.ludovic.netlib.blas;
+
+import dev.ludovic.netlib.BLAS;
+
+public class Java8BLAS extends AbstractBLAS implements dev.ludovic.netlib.JavaBLAS {
+
+ private static final Java8BLAS instance = new Java8BLAS();
+
+ protected Java8BLAS() {}
+
+ public static dev.ludovic.netlib.JavaBLAS getInstance() {
+ return instance;
+ }
+
+ protected double dasumK(int n, double[] x, int offsetx, int incx) {
+ double sum = 0.0;
+ if (incx == 1) {
+ int ix = 0;
+ double sum0 = 0.0;
+ double sum1 = 0.0;
+ double sum2 = 0.0;
+ double sum3 = 0.0;
+ for (; ix < loopBound(n, 4); ix += 4) {
+ sum0 += Math.abs(x[offsetx + ix + 0]);
+ sum1 += Math.abs(x[offsetx + ix + 1]);
+ sum2 += Math.abs(x[offsetx + ix + 2]);
+ sum3 += Math.abs(x[offsetx + ix + 3]);
+ }
+ sum += sum0 + sum1 + sum2 + sum3;
+ for (; ix < n; ix += 1) {
+ sum += Math.abs(x[offsetx + ix]);
+ }
+ } else {
+ for (int ix = incx < 0 ? (n - 1) * -incx : 0; incx < 0 ? ix >= 0 : ix < n * incx; ix += incx) {
+ sum += Math.abs(x[offsetx + ix]);
+ }
+ }
+ return sum;
+ }
+
+ protected float sasumK(int n, float[] x, int offsetx, int incx) {
+ float sum = 0.0f;
+ if (incx == 1) {
+ int ix = 0;
+ float sum0 = 0.0f;
+ float sum1 = 0.0f;
+ float sum2 = 0.0f;
+ float sum3 = 0.0f;
+ for (; ix < loopBound(n, 4); ix += 4) {
+ sum0 += Math.abs(x[offsetx + ix + 0]);
+ sum1 += Math.abs(x[offsetx + ix + 1]);
+ sum2 += Math.abs(x[offsetx + ix + 2]);
+ sum3 += Math.abs(x[offsetx + ix + 3]);
+ }
+ sum += sum0 + sum1 + sum2 + sum3;
+ for (; ix < n; ix += 1) {
+ sum += Math.abs(x[offsetx + ix]);
+ }
+ } else {
+ for (int ix = incx < 0 ? (n - 1) * -incx : 0; incx < 0 ? ix >= 0 : ix < n * incx; ix += incx) {
+ sum += Math.abs(x[offsetx + ix]);
+ }
+ }
+ return sum;
+ }
+
+ protected void daxpyK(int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy) {
+ if (incx == 1 && incy == 1) {
+ for (int ix = 0, iy = 0; ix < n && iy < n; ix += 1, iy += 1) {
+ y[offsety + iy] += alpha * x[offsetx + ix];
+ }
+ } else {
+ for (int ix = incx < 0 ? (n - 1) * -incx : 0,
+ iy = incy < 0 ? (n - 1) * -incy : 0;
+ (incx < 0 ? ix >= 0 : ix < n * incx)
+ && (incy < 0 ? iy >= 0 : iy < n * incy);
+ ix += incx, iy += incy) {
+ y[offsety + iy] += alpha * x[offsetx + ix];
+ }
+ }
+ }
+
+ protected void saxpyK(int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) {
+ if (incx == 1 && incy == 1) {
+ for (int ix = 0, iy = 0; ix < n && iy < n; ix += 1, iy += 1) {
+ y[offsety + iy] += alpha * x[offsetx + ix];
+ }
+ } else {
+ for (int ix = incx < 0 ? (n - 1) * -incx : 0,
+ iy = incy < 0 ? (n - 1) * -incy : 0;
+ (incx < 0 ? ix >= 0 : ix < n * incx)
+ && (incy < 0 ? iy >= 0 : iy < n * incy);
+ ix += incx, iy += incy) {
+ y[offsety + iy] += alpha * x[offsetx + ix];
+ }
+ }
+ }
+
+ protected void dcopyK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy) {
+ if (incx == 1 && incy == 1) {
+ System.arraycopy(x, offsetx, y, offsety, n);
+ } else {
+ for (int ix = incx < 0 ? (n - 1) * -incx : 0,
+ iy = incy < 0 ? (n - 1) * -incy : 0;
+ (incx < 0 ? ix >= 0 : ix < n * incx)
+ && (incy < 0 ? iy >= 0 : iy < n * incy);
+ ix += incx, iy += incy) {
+ y[offsety + iy] = x[offsetx + ix];
+ }
+ }
+ }
+
+ protected void scopyK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) {
+ if (incx == 1 && incy == 1) {
+ System.arraycopy(x, offsetx, y, offsety, n);
+ } else {
+ for (int ix = incx < 0 ? (n - 1) * -incx : 0,
+ iy = incy < 0 ? (n - 1) * -incy : 0;
+ (incx < 0 ? ix >= 0 : ix < n * incx)
+ && (incy < 0 ? iy >= 0 : iy < n * incy);
+ ix += incx, iy += incy) {
+ y[offsety + iy] = x[offsetx + ix];
+ }
+ }
+ }
+
+ protected double ddotK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy) {
+ double sum = 0.0;
+ if (incx == 1 && incy == 1) {
+ int ix = 0, iy = 0;
+ double sum0 = 0.0;
+ double sum1 = 0.0;
+ double sum2 = 0.0;
+ double sum3 = 0.0;
+ for (; ix < loopBound(n, 4) && iy < loopBound(n, 4); ix += 4, iy += 4) {
+ sum0 += x[offsetx + ix + 0] * y[offsety + iy + 0];
+ sum1 += x[offsetx + ix + 1] * y[offsety + iy + 1];
+ sum2 += x[offsetx + ix + 2] * y[offsety + iy + 2];
+ sum3 += x[offsetx + ix + 3] * y[offsety + iy + 3];
+ }
+ sum += sum0 + sum1 + sum2 + sum3;
+ for (; ix < n && iy < n; ix += 1, iy += 1) {
+ sum += x[offsetx + ix] * y[offsety + iy];
+ }
+ } else {
+ for (int ix = incx < 0 ? (n - 1) * -incx : 0,
+ iy = incy < 0 ? (n - 1) * -incy : 0;
+ (incx < 0 ? ix >= 0 : ix < n * incx)
+ && (incy < 0 ? iy >= 0 : iy < n * incy);
+ ix += incx, iy += incy) {
+ sum += x[offsetx + ix] * y[offsety + iy];
+ }
+ }
+ return sum;
+ }
+
+ protected float sdotK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) {
+ float sum = 0.0f;
+ if (incx == 1 && incy == 1) {
+ int ix = 0, iy = 0;
+ float sum0 = 0.0f;
+ float sum1 = 0.0f;
+ float sum2 = 0.0f;
+ float sum3 = 0.0f;
+ for (; ix < loopBound(n, 4) && iy < loopBound(n, 4); ix += 4, iy += 4) {
+ sum0 += x[offsetx + ix + 0] * y[offsety + iy + 0];
+ sum1 += x[offsetx + ix + 1] * y[offsety + iy + 1];
+ sum2 += x[offsetx + ix + 2] * y[offsety + iy + 2];
+ sum3 += x[offsetx + ix + 3] * y[offsety + iy + 3];
+ }
+ sum += sum0 + sum1 + sum2 + sum3;
+ for (; ix < n && iy < n; ix += 1, iy += 1) {
+ sum += x[offsetx + ix] * y[offsety + iy];
+ }
+ } else {
+ for (int ix = incx < 0 ? (n - 1) * -incx : 0,
+ iy = incy < 0 ? (n - 1) * -incy : 0;
+ (incx < 0 ? ix >= 0 : ix < n * incx)
+ && (incy < 0 ? iy >= 0 : iy < n * incy);
+ ix += incx, iy += incy) {
+ sum += x[offsetx + ix] * y[offsety + iy];
+ }
+ }
+ return sum;
+ }
+
+ protected float sdsdotK(int n, float sb, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) {
+ double sum = sb;
+ if (incx == 1 && incy == 1) {
+ int ix = 0, iy = 0;
+ double sum0 = 0.0;
+ double sum1 = 0.0;
+ double sum2 = 0.0;
+ double sum3 = 0.0;
+ for (; ix < loopBound(n, 4) && iy < loopBound(n, 4); ix += 4, iy += 4) {
+ sum0 += (double)x[offsetx + ix + 0] * (double)y[offsety + iy + 0];
+ sum1 += (double)x[offsetx + ix + 1] * (double)y[offsety + iy + 1];
+ sum2 += (double)x[offsetx + ix + 2] * (double)y[offsety + iy + 2];
+ sum3 += (double)x[offsetx + ix + 3] * (double)y[offsety + iy + 3];
+ }
+ sum += sum0 + sum1 + sum2 + sum3;
+ for (; ix < n && iy < n; ix += 1, iy += 1) {
+ sum += (double)(x[offsetx + ix]) * (double)(y[offsety + iy]);
+ }
+ } else {
+ for (int ix = incx < 0 ? (n - 1) * -incx : 0,
+ iy = incy < 0 ? (n - 1) * -incy : 0;
+ (incx < 0 ? ix >= 0 : ix < n * incx)
+ && (incy < 0 ? iy >= 0 : iy < n * incy);
+ ix += incx, iy += incy) {
+ sum += (double)(x[offsetx + ix]) * (double)(y[offsety + iy]);
+ }
+ }
+ return (float)sum;
+ }
+
+ protected void dgbmvK(String trans, int m, int n, int kl, int ku, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
+ org.netlib.blas.Dgbmv.dgbmv(trans, m, n, kl, ku, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
+ }
+ protected void sgbmvK(String trans, int m, int n, int kl, int ku, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
+ org.netlib.blas.Sgbmv.sgbmv(trans, m, n, kl, ku, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
+ }
+
+ protected void dgemmK(String transa, String transb, int m, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
+ if (alpha == 0.0) {
+ dgemmBeta(0, m, 0, n, beta, c, offsetc, ldc);
+ } else if (m * n * k < 100 * 100 * 100) {
+ // The matrices are small and it's faster to do the non-copying version
+ if (lsame("N", transa) && lsame("N", transb)) {
+ dgemmNN(m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
+ } else if (lsame("N", transa)) {
+ dgemmNT(m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
+ } else if (lsame("N", transb)) {
+ dgemmTN(m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
+ } else {
+ dgemmTT(m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
+ }
+ } else {
+ final int Krow = (int)(Math.ceil((double)(Math.min(60, m)) / 3) * 3),
+ Kcol = (int)(Math.ceil((double)(Math.min(1000, n)) / 3) * 3),
+ Ki = (int)(Math.ceil((double)(Math.min(500, k)) / 4) * 4);
+
+ assert Krow > 0;
+ assert Kcol > 0;
+ assert Ki > 0;
+
+ double[] packeda = new double[Krow * Ki];
+ double[] packedb = new double[Kcol * Ki];
+ double[] packedc = new double[Kcol * Krow];
+
+ // c = beta * c
+ dgemmBeta(0, m, 0, n, beta, c, offsetc, ldc);
+ // c += alpha * a * b
+ for (int col = 0; col < n; col += Kcol) {
+ int cols = col, cole = Math.min(col + Kcol, n);
+ for (int i = 0; i < k; i += Ki) {
+ int is = i, ie = Math.min(i + Ki, k);
+ // pack b
+ if (lsame("N", transb)) {
+ dgecpyNN(ie - is, cole - cols, b, offsetb, ldb, is, cols, packedb, 0, Ki, 0, 0);
+ } else {
+ dgecpyTN(ie - is, cole - cols, b, offsetb, ldb, is, cols, packedb, 0, Ki, 0, 0);
+ }
+ // GEPP
+ for (int row = 0; row < m; row += Krow) {
+ int rows = row, rowe = Math.min(row + Krow, m);
+ // pack A
+ if (lsame("N", transa)) {
+ dgecpyNT(rowe - rows, ie - is, a, offseta, lda, rows, is, packeda, 0, Ki, 0, 0);
+ } else {
+ dgecpyTT(rowe - rows, ie - is, a, offseta, lda, rows, is, packeda, 0, Ki, 0, 0);
+ }
+ // pack C
+ dgecpyNN(rowe - rows, cole - cols, c, offsetc, ldc, rows, cols, packedc, 0, Krow, 0, 0);
+ // GEBP
+ dgebpTN(Krow, 0, rowe - rows, Kcol, 0, cole - cols, Ki, 0, ie - is,
+ alpha, packeda, 0, Ki, packedb, 0, Ki, beta, packedc, 0, Krow);
+ // unpack C
+ dgecpyNN(rowe - rows, cole - cols, packedc, 0, Krow, 0, 0, c, offsetc, ldc, rows, cols);
+ }
+ }
+ }
+ }
+ }
+
+ protected void dgemmBeta(int rows, int rowe, int cols, int cole, double beta, double[] c, int offsetc, int ldc) {
+ if (beta != 1.0) {
+ int col = cols;
+ for (; col < loopAlign(cols, cole, 4); col += 1) {
+ int row = rows;
+ for (; row < rowe; row += 1) {
+ if (beta != 0.0) {
+ c[offsetc + row + (col + 0) * ldc] = beta * c[offsetc + row + (col + 0) * ldc];
+ } else {
+ c[offsetc + row + (col + 0) * ldc] = 0.0;
+ }
+ }
+ }
+ for (; col < loopBound(cole, 4); col += 4) {
+ int row = rows;
+ for (; row < rowe; row += 1) {
+ if (beta != 0.0) {
+ c[offsetc + row + (col + 0) * ldc] = beta * c[offsetc + row + (col + 0) * ldc];
+ c[offsetc + row + (col + 1) * ldc] = beta * c[offsetc + row + (col + 1) * ldc];
+ c[offsetc + row + (col + 2) * ldc] = beta * c[offsetc + row + (col + 2) * ldc];
+ c[offsetc + row + (col + 3) * ldc] = beta * c[offsetc + row + (col + 3) * ldc];
+ } else {
+ c[offsetc + row + (col + 0) * ldc] = 0.0;
+ c[offsetc + row + (col + 1) * ldc] = 0.0;
+ c[offsetc + row + (col + 2) * ldc] = 0.0;
+ c[offsetc + row + (col + 3) * ldc] = 0.0;
+ }
+ }
+ }
+ for (; col < cole; col += 1) {
+ int row = rows;
+ for (; row < rowe; row += 1) {
+ if (beta != 0.0) {
+ c[offsetc + row + (col + 0) * ldc] = beta * c[offsetc + row + (col + 0) * ldc];
+ } else {
+ c[offsetc + row + (col + 0) * ldc] = 0.0;
+ }
+ }
+ }
+ }
+ }
+
+ protected void dgecpyNN(int m, int n, double[] src, int offsetsrc, int ldsrc, int rowssrc, int colssrc, double[] dst, int offsetdst, int lddst, int rowsdst, int colsdst) {
+ int col = 0;
+ for (; col < loopBound(n, 4); col += 4) {
+ System.arraycopy(src, offsetsrc + rowssrc + (colssrc + col + 0) * ldsrc, dst, offsetdst + rowsdst + (colsdst + col + 0) * lddst, m);
+ System.arraycopy(src, offsetsrc + rowssrc + (colssrc + col + 1) * ldsrc, dst, offsetdst + rowsdst + (colsdst + col + 1) * lddst, m);
+ System.arraycopy(src, offsetsrc + rowssrc + (colssrc + col + 2) * ldsrc, dst, offsetdst + rowsdst + (colsdst + col + 2) * lddst, m);
+ System.arraycopy(src, offsetsrc + rowssrc + (colssrc + col + 3) * ldsrc, dst, offsetdst + rowsdst + (colsdst + col + 3) * lddst, m);
+ }
+ for (; col < n; col += 1) {
+ System.arraycopy(src, offsetsrc + rowssrc + (colssrc + col + 0) * ldsrc, dst, offsetdst + rowsdst + (colsdst + col + 0) * lddst, m);
+ }
+ }
+
+ protected void dgecpyNT(int m, int n, double[] src, int offsetsrc, int ldsrc, int rowssrc, int colssrc, double[] dst, int offsetdst, int lddst, int rowsdst, int colsdst) {
+ int col = 0;
+ for (; col < loopBound(n, 3); col += 3) {
+ int row = 0;
+ for (; row < loopBound(m, 3); row += 3) {
+ dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 0) * ldsrc];
+ dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 1) * lddst] = src[offsetsrc + (rowssrc + row + 1) + (colssrc + col + 0) * ldsrc];
+ dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 2) * lddst] = src[offsetsrc + (rowssrc + row + 2) + (colssrc + col + 0) * ldsrc];
+ dst[offsetdst + (colsdst + col + 1) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 1) * ldsrc];
+ dst[offsetdst + (colsdst + col + 1) + (rowsdst + row + 1) * lddst] = src[offsetsrc + (rowssrc + row + 1) + (colssrc + col + 1) * ldsrc];
+ dst[offsetdst + (colsdst + col + 1) + (rowsdst + row + 2) * lddst] = src[offsetsrc + (rowssrc + row + 2) + (colssrc + col + 1) * ldsrc];
+ dst[offsetdst + (colsdst + col + 2) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 2) * ldsrc];
+ dst[offsetdst + (colsdst + col + 2) + (rowsdst + row + 1) * lddst] = src[offsetsrc + (rowssrc + row + 1) + (colssrc + col + 2) * ldsrc];
+ dst[offsetdst + (colsdst + col + 2) + (rowsdst + row + 2) * lddst] = src[offsetsrc + (rowssrc + row + 2) + (colssrc + col + 2) * ldsrc];
+ }
+ for (; row < m; row += 1) {
+ dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 0) * ldsrc];
+ dst[offsetdst + (colsdst + col + 1) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 1) * ldsrc];
+ dst[offsetdst + (colsdst + col + 2) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 2) * ldsrc];
+ }
+ }
+ for (; col < n; col += 1) {
+ int row = 0;
+ for (; row < loopBound(m, 3); row += 3) {
+ dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 0) * ldsrc];
+ dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 1) * lddst] = src[offsetsrc + (rowssrc + row + 1) + (colssrc + col + 0) * ldsrc];
+ dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 2) * lddst] = src[offsetsrc + (rowssrc + row + 2) + (colssrc + col + 0) * ldsrc];
+ }
+ for (; row < m; row += 1) {
+ dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 0) * ldsrc];
+ }
+ }
+ }
+
+ protected void dgecpyTN(int m, int n, double[] src, int offsetsrc, int ldsrc, int rowssrc, int colssrc, double[] dst, int offsetdst, int lddst, int rowsdst, int colsdst) {
+ int row = 0;
+ for (; row < loopBound(m, 3); row += 3) {
+ int col = 0;
+ for (; col < loopBound(n, 3); col += 3) {
+ dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 0) * ldsrc];
+ dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 1) * lddst] = src[offsetsrc + (colssrc + col + 1) + (rowssrc + row + 0) * ldsrc];
+ dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 2) * lddst] = src[offsetsrc + (colssrc + col + 2) + (rowssrc + row + 0) * ldsrc];
+ dst[offsetdst + (rowsdst + row + 1) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 1) * ldsrc];
+ dst[offsetdst + (rowsdst + row + 1) + (colsdst + col + 1) * lddst] = src[offsetsrc + (colssrc + col + 1) + (rowssrc + row + 1) * ldsrc];
+ dst[offsetdst + (rowsdst + row + 1) + (colsdst + col + 2) * lddst] = src[offsetsrc + (colssrc + col + 2) + (rowssrc + row + 1) * ldsrc];
+ dst[offsetdst + (rowsdst + row + 2) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 2) * ldsrc];
+ dst[offsetdst + (rowsdst + row + 2) + (colsdst + col + 1) * lddst] = src[offsetsrc + (colssrc + col + 1) + (rowssrc + row + 2) * ldsrc];
+ dst[offsetdst + (rowsdst + row + 2) + (colsdst + col + 2) * lddst] = src[offsetsrc + (colssrc + col + 2) + (rowssrc + row + 2) * ldsrc];
+ }
+ for (; col < n; col += 1) {
+ dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 0) * ldsrc];
+ dst[offsetdst + (rowsdst + row + 1) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 1) * ldsrc];
+ dst[offsetdst + (rowsdst + row + 2) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 2) * ldsrc];
+ }
+ }
+ for (; row < m; row += 1) {
+ int col = 0;
+ for (; col < loopBound(n, 3); col += 3) {
+ dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 0) * ldsrc];
+ dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 1) * lddst] = src[offsetsrc + (colssrc + col + 1) + (rowssrc + row + 0) * ldsrc];
+ dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 2) * lddst] = src[offsetsrc + (colssrc + col + 2) + (rowssrc + row + 0) * ldsrc];
+ }
+ for (; col < n; col += 1) {
+ dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 0) * ldsrc];
+ }
+ }
+ }
+
+ protected void dgecpyTT(int m, int n, double[] src, int offsetsrc, int ldsrc, int rowssrc, int colssrc, double[] dst, int offsetdst, int lddst, int rowsdst, int colsdst) {
+ int row = 0;
+ for (; row < loopBound(m, 4); row += 4) {
+ System.arraycopy(src, offsetsrc + colssrc + (rowssrc + row + 0) * ldsrc, dst, offsetdst + colsdst + (rowsdst + row + 0) * lddst, n);
+ System.arraycopy(src, offsetsrc + colssrc + (rowssrc + row + 1) * ldsrc, dst, offsetdst + colsdst + (rowsdst + row + 1) * lddst, n);
+ System.arraycopy(src, offsetsrc + colssrc + (rowssrc + row + 2) * ldsrc, dst, offsetdst + colsdst + (rowsdst + row + 2) * lddst, n);
+ System.arraycopy(src, offsetsrc + colssrc + (rowssrc + row + 3) * ldsrc, dst, offsetdst + colsdst + (rowsdst + row + 3) * lddst, n);
+ }
+ for (; row < m; row += 1) {
+ System.arraycopy(src, offsetsrc + colssrc + (rowssrc + row + 0) * ldsrc, dst, offsetdst + colsdst + (rowsdst + row + 0) * lddst, n);
+ }
+ }
+
+ protected void dgebpTN(int m, int rows, int rowe, int n, int cols, int cole, int k, int is, int ie, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
+ final int Tcol = 3, Trow = 3;
+
+ int col = cols;
+ for (; col < loopAlign(cols, cole, Tcol); col += 1) {
+ int row = rows;
+ for (; row < loopAlign(rows, rowe, Trow); row += 1) {
+ double sum00 = 0.0;
+ for (int i = is; i < ie; i += 1) {
+ double a0 = a[offseta + i + (row + 0) * lda];
+ double b0 = b[offsetb + i + (col + 0) * ldb];
+ sum00 = a0 * b0 + sum00;
+ }
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc];
+ }
+ for (; row < loopBound(rowe, Trow); row += Trow) {
+ double sum00 = 0.0;
+ double sum10 = 0.0;
+ double sum20 = 0.0;
+ for (int i = is; i < ie; i += 1) {
+ double a0 = a[offseta + i + (row + 0) * lda];
+ double a1 = a[offseta + i + (row + 1) * lda];
+ double a2 = a[offseta + i + (row + 2) * lda];
+ double b0 = b[offsetb + i + (col + 0) * ldb];
+ sum00 = a0 * b0 + sum00;
+ sum10 = a1 * b0 + sum10;
+ sum20 = a2 * b0 + sum20;
+ }
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + c[offsetc + (row + 1) + (col + 0) * ldc];
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + c[offsetc + (row + 2) + (col + 0) * ldc];
+ }
+ for (; row < rowe; row += 1) {
+ double sum00 = 0.0;
+ for (int i = is; i < ie; i += 1) {
+ double a0 = a[offseta + i + (row + 0) * lda];
+ double b0 = b[offsetb + i + (col + 0) * ldb];
+ sum00 = a0 * b0 + sum00;
+ }
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc];
+ }
+ }
+ for (; col < loopBound(cole, Tcol); col += Tcol) {
+ int row = rows;
+ for (; row < loopAlign(rows, rowe, Trow); row += 1) {
+ double sum00 = 0.0;
+ double sum01 = 0.0;
+ double sum02 = 0.0;
+ double sum03 = 0.0;
+ for (int i = is; i < ie; i += 1) {
+ double a0 = a[offseta + i + (row + 0) * lda];
+ double b0 = b[offsetb + i + (col + 0) * ldb];
+ double b1 = b[offsetb + i + (col + 1) * ldb];
+ double b2 = b[offsetb + i + (col + 2) * ldb];
+ sum00 = a0 * b0 + sum00;
+ sum01 = a0 * b1 + sum01;
+ sum02 = a0 * b2 + sum02;
+ }
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + c[offsetc + (row + 0) + (col + 1) * ldc];
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + c[offsetc + (row + 0) + (col + 2) * ldc];
+ }
+ for (; row < loopBound(rowe, Trow); row += Trow) {
+ dgepdotTN(m, row, row + Trow, n, col, col + Tcol, k, is, ie, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
+ }
+ for (; row < rowe; row += 1) {
+ double sum00 = 0.0;
+ double sum01 = 0.0;
+ double sum02 = 0.0;
+ for (int i = is; i < ie; i += 1) {
+ double a0 = a[offseta + i + (row + 0) * lda];
+ double b0 = b[offsetb + i + (col + 0) * ldb];
+ double b1 = b[offsetb + i + (col + 1) * ldb];
+ double b2 = b[offsetb + i + (col + 2) * ldb];
+ sum00 = a0 * b0 + sum00;
+ sum01 = a0 * b1 + sum01;
+ sum02 = a0 * b2 + sum02;
+ }
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + c[offsetc + (row + 0) + (col + 1) * ldc];
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + c[offsetc + (row + 0) + (col + 2) * ldc];
+ }
+ }
+ for (; col < cole; col += 1) {
+ int row = rows;
+ for (; row < loopAlign(rows, rowe, Trow); row += 1) {
+ double sum00 = 0.0;
+ for (int i = is; i < ie; i += 1) {
+ double a0 = a[offseta + i + (row + 0) * lda];
+ double b0 = b[offsetb + i + (col + 0) * ldb];
+ sum00 = a0 * b0 + sum00;
+ }
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc];
+ }
+ for (; row < loopBound(rowe, Trow); row += Trow) {
+ double sum00 = 0.0;
+ double sum10 = 0.0;
+ double sum20 = 0.0;
+ for (int i = is; i < ie; i += 1) {
+ double a0 = a[offseta + i + (row + 0) * lda];
+ double a1 = a[offseta + i + (row + 1) * lda];
+ double a2 = a[offseta + i + (row + 2) * lda];
+ double b0 = b[offsetb + i + (col + 0) * ldb];
+ sum00 = a0 * b0 + sum00;
+ sum10 = a1 * b0 + sum10;
+ sum20 = a2 * b0 + sum20;
+ }
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + c[offsetc + (row + 1) + (col + 0) * ldc];
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + c[offsetc + (row + 2) + (col + 0) * ldc];
+ }
+ for (; row < rowe; row += 1) {
+ double sum00 = 0.0;
+ for (int i = is; i < ie; i += 1) {
+ double a0 = a[offseta + i + (row + 0) * lda];
+ double b0 = b[offsetb + i + (col + 0) * ldb];
+ sum00 = a0 * b0 + sum00;
+ }
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc];
+ }
+ }
+ }
+
+ protected void dgepdotTN(int m, int rows, int rowe, int n, int cols, int cole, int k, int is, int ie, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
+ final int Ti = 2;
+
+ assert rowe - rows == 3;
+ assert cole - cols == 3;
+
+ int row = rows;
+ int col = cols;
+ int i = is;
+ double sum00 = 0.0;
+ double sum01 = 0.0;
+ double sum02 = 0.0;
+ double sum10 = 0.0;
+ double sum11 = 0.0;
+ double sum12 = 0.0;
+ double sum20 = 0.0;
+ double sum21 = 0.0;
+ double sum22 = 0.0;
+ for (; i < loopAlign(is, ie, Ti); i += 1) {
+ double a0 = a[offseta + i + (row + 0) * lda];
+ double a1 = a[offseta + i + (row + 1) * lda];
+ double a2 = a[offseta + i + (row + 2) * lda];
+ double b0 = b[offsetb + i + (col + 0) * ldb];
+ sum00 = a0 * b0 + sum00;
+ sum10 = a1 * b0 + sum10;
+ sum20 = a2 * b0 + sum20;
+ double b1 = b[offsetb + i + (col + 1) * ldb];
+ sum01 = a0 * b1 + sum01;
+ sum11 = a1 * b1 + sum11;
+ sum21 = a2 * b1 + sum21;
+ double b2 = b[offsetb + i + (col + 2) * ldb];
+ sum02 = a0 * b2 + sum02;
+ sum12 = a1 * b2 + sum12;
+ sum22 = a2 * b2 + sum22;
+ }
+ for (; i < loopBound(ie, Ti); i += Ti) {
+ double a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ double a01 = a[offseta + (i + 0) + (row + 1) * lda];
+ double a02 = a[offseta + (i + 0) + (row + 2) * lda];
+ double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum10 = a01 * b00 + sum10;
+ sum20 = a02 * b00 + sum20;
+ double b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
+ sum01 = a00 * b01 + sum01;
+ sum11 = a01 * b01 + sum11;
+ sum21 = a02 * b01 + sum21;
+ double b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
+ sum02 = a00 * b02 + sum02;
+ sum12 = a01 * b02 + sum12;
+ sum22 = a02 * b02 + sum22;
+ double a10 = a[offseta + (i + 1) + (row + 0) * lda];
+ double a11 = a[offseta + (i + 1) + (row + 1) * lda];
+ double a12 = a[offseta + (i + 1) + (row + 2) * lda];
+ double b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
+ sum00 = a10 * b10 + sum00;
+ sum10 = a11 * b10 + sum10;
+ sum20 = a12 * b10 + sum20;
+ double b11 = b[offsetb + (i + 1) + (col + 1) * ldb];
+ sum01 = a10 * b11 + sum01;
+ sum11 = a11 * b11 + sum11;
+ sum21 = a12 * b11 + sum21;
+ double b12 = b[offsetb + (i + 1) + (col + 2) * ldb];
+ sum02 = a10 * b12 + sum02;
+ sum12 = a11 * b12 + sum12;
+ sum22 = a12 * b12 + sum22;
+ }
+ for (; i < ie; i += 1) {
+ double a0 = a[offseta + i + (row + 0) * lda];
+ double a1 = a[offseta + i + (row + 1) * lda];
+ double a2 = a[offseta + i + (row + 2) * lda];
+ double b0 = b[offsetb + i + (col + 0) * ldb];
+ sum00 = a0 * b0 + sum00;
+ sum10 = a1 * b0 + sum10;
+ sum20 = a2 * b0 + sum20;
+ double b1 = b[offsetb + i + (col + 1) * ldb];
+ sum01 = a0 * b1 + sum01;
+ sum11 = a1 * b1 + sum11;
+ sum21 = a2 * b1 + sum21;
+ double b2 = b[offsetb + i + (col + 2) * ldb];
+ sum02 = a0 * b2 + sum02;
+ sum12 = a1 * b2 + sum12;
+ sum22 = a2 * b2 + sum22;
+ }
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + c[offsetc + (row + 0) + (col + 1) * ldc];
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + c[offsetc + (row + 0) + (col + 2) * ldc];
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + c[offsetc + (row + 1) + (col + 0) * ldc];
+ c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + c[offsetc + (row + 1) + (col + 1) * ldc];
+ c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + c[offsetc + (row + 1) + (col + 2) * ldc];
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + c[offsetc + (row + 2) + (col + 0) * ldc];
+ c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + c[offsetc + (row + 2) + (col + 1) * ldc];
+ c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + c[offsetc + (row + 2) + (col + 2) * ldc];
+ }
+
+ protected void dgemmNN(int m, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
+ final int Trow = 3, Tcol = 3, Ti = 2;
+
+ int col = 0;
+ for (; col < loopBound(n, Tcol); col += Tcol) {
+ int row = 0;
+ for (; row < loopBound(m, Trow); row += Trow) {
+ int i = 0;
+ double sum00 = 0.0;
+ double sum01 = 0.0;
+ double sum02 = 0.0;
+ double sum10 = 0.0;
+ double sum11 = 0.0;
+ double sum12 = 0.0;
+ double sum20 = 0.0;
+ double sum21 = 0.0;
+ double sum22 = 0.0;
+ for (; i < loopBound(k, Ti); i += Ti) {
+ double a00 = a[offseta + (row + 0) + (i + 0) * lda];
+ double a10 = a[offseta + (row + 1) + (i + 0) * lda];
+ double a20 = a[offseta + (row + 2) + (i + 0) * lda];
+ double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
+ double b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
+ double b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum01 = a00 * b01 + sum01;
+ sum02 = a00 * b02 + sum02;
+ sum10 = a10 * b00 + sum10;
+ sum11 = a10 * b01 + sum11;
+ sum12 = a10 * b02 + sum12;
+ sum20 = a20 * b00 + sum20;
+ sum21 = a20 * b01 + sum21;
+ sum22 = a20 * b02 + sum22;
+ double a01 = a[offseta + (row + 0) + (i + 1) * lda];
+ double a11 = a[offseta + (row + 1) + (i + 1) * lda];
+ double a21 = a[offseta + (row + 2) + (i + 1) * lda];
+ double b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
+ double b11 = b[offsetb + (i + 1) + (col + 1) * ldb];
+ double b12 = b[offsetb + (i + 1) + (col + 2) * ldb];
+ sum00 = a01 * b10 + sum00;
+ sum01 = a01 * b11 + sum01;
+ sum02 = a01 * b12 + sum02;
+ sum10 = a11 * b10 + sum10;
+ sum11 = a11 * b11 + sum11;
+ sum12 = a11 * b12 + sum12;
+ sum20 = a21 * b10 + sum20;
+ sum21 = a21 * b11 + sum21;
+ sum22 = a21 * b12 + sum22;
+ }
+ for (; i < k; i += 1) {
+ double a00 = a[offseta + (row + 0) + (i + 0) * lda];
+ double a10 = a[offseta + (row + 1) + (i + 0) * lda];
+ double a20 = a[offseta + (row + 2) + (i + 0) * lda];
+ double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
+ double b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
+ double b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum01 = a00 * b01 + sum01;
+ sum02 = a00 * b02 + sum02;
+ sum10 = a10 * b00 + sum10;
+ sum11 = a10 * b01 + sum11;
+ sum12 = a10 * b02 + sum12;
+ sum20 = a20 * b00 + sum20;
+ sum21 = a20 * b01 + sum21;
+ sum22 = a20 * b02 + sum22;
+ }
+ if (beta != 0.0) {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
+ c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + beta * c[offsetc + (row + 1) + (col + 1) * ldc];
+ c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + beta * c[offsetc + (row + 1) + (col + 2) * ldc];
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
+ c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + beta * c[offsetc + (row + 2) + (col + 1) * ldc];
+ c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + beta * c[offsetc + (row + 2) + (col + 2) * ldc];
+ } else {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
+ c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11;
+ c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12;
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
+ c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21;
+ c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22;
+ }
+ }
+ for (; row < m; row += 1) {
+ int i = 0;
+ double sum00 = 0.0;
+ double sum01 = 0.0;
+ double sum02 = 0.0;
+ for (; i < loopBound(k, Ti); i += Ti) {
+ double a00 = a[offseta + (row + 0) + (i + 0) * lda];
+ double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
+ double b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
+ double b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum01 = a00 * b01 + sum01;
+ sum02 = a00 * b02 + sum02;
+ double a01 = a[offseta + (row + 0) + (i + 1) * lda];
+ double b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
+ double b11 = b[offsetb + (i + 1) + (col + 1) * ldb];
+ double b12 = b[offsetb + (i + 1) + (col + 2) * ldb];
+ sum00 = a01 * b10 + sum00;
+ sum01 = a01 * b11 + sum01;
+ sum02 = a01 * b12 + sum02;
+ }
+ for (; i < k; i += 1) {
+ double a00 = a[offseta + (row + 0) + (i + 0) * lda];
+ double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
+ double b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
+ double b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum01 = a00 * b01 + sum01;
+ sum02 = a00 * b02 + sum02;
+ }
+ if (beta != 0.0) {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
+ } else {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
+ }
+ }
+ }
+ for (; col < n; col += 1) {
+ int row = 0;
+ for (; row < loopBound(m, Trow); row += Trow) {
+ int i = 0;
+ double sum00 = 0.0;
+ double sum10 = 0.0;
+ double sum20 = 0.0;
+ for (; i < loopBound(k, Ti); i += Ti) {
+ double a00 = a[offseta + (row + 0) + (i + 0) * lda];
+ double a10 = a[offseta + (row + 1) + (i + 0) * lda];
+ double a20 = a[offseta + (row + 2) + (i + 0) * lda];
+ double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum10 = a10 * b00 + sum10;
+ sum20 = a20 * b00 + sum20;
+ double a01 = a[offseta + (row + 0) + (i + 1) * lda];
+ double a11 = a[offseta + (row + 1) + (i + 1) * lda];
+ double a21 = a[offseta + (row + 2) + (i + 1) * lda];
+ double b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
+ sum00 = a01 * b10 + sum00;
+ sum10 = a11 * b10 + sum10;
+ sum20 = a21 * b10 + sum20;
+ }
+ for (; i < k; i += 1) {
+ double a00 = a[offseta + (row + 0) + (i + 0) * lda];
+ double a10 = a[offseta + (row + 1) + (i + 0) * lda];
+ double a20 = a[offseta + (row + 2) + (i + 0) * lda];
+ double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum10 = a10 * b00 + sum10;
+ sum20 = a20 * b00 + sum20;
+ }
+ if (beta != 0.0) {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
+ } else {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
+ }
+ }
+ for (; row < m; row += 1) {
+ int i = 0;
+ double sum00 = 0.0;
+ for (; i < loopBound(k, Ti); i += Ti) {
+ double a00 = a[offseta + (row + 0) + (i + 0) * lda];
+ double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ double a01 = a[offseta + (row + 0) + (i + 1) * lda];
+ double b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
+ sum00 = a01 * b10 + sum00;
+ }
+ for (; i < k; i += 1) {
+ double a00 = a[offseta + (row + 0) + (i + 0) * lda];
+ double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ }
+ if (beta != 0.0) {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
+ } else {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
+ }
+ }
+ }
+ }
+
+ protected void dgemmNT(int m, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
+ final int Trow = 3, Tcol = 3, Ti = 2;
+ int col = 0;
+ for (; col < loopBound(n, Tcol); col += Tcol) {
+ int row = 0;
+ for (; row < loopBound(m, Trow); row += Trow) {
+ int i = 0;
+ double sum00 = 0.0;
+ double sum01 = 0.0;
+ double sum02 = 0.0;
+ double sum10 = 0.0;
+ double sum11 = 0.0;
+ double sum12 = 0.0;
+ double sum20 = 0.0;
+ double sum21 = 0.0;
+ double sum22 = 0.0;
+ for (; i < loopBound(k, Ti); i += Ti) {
+ double a00 = a[offseta + (row + 0) + (i + 0) * lda];
+ double a10 = a[offseta + (row + 1) + (i + 0) * lda];
+ double a20 = a[offseta + (row + 2) + (i + 0) * lda];
+ double b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
+ double b01 = b[offsetb + (col + 1) + (i + 0) * ldb];
+ double b02 = b[offsetb + (col + 2) + (i + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum01 = a00 * b01 + sum01;
+ sum02 = a00 * b02 + sum02;
+ sum10 = a10 * b00 + sum10;
+ sum11 = a10 * b01 + sum11;
+ sum12 = a10 * b02 + sum12;
+ sum20 = a20 * b00 + sum20;
+ sum21 = a20 * b01 + sum21;
+ sum22 = a20 * b02 + sum22;
+ double a01 = a[offseta + (row + 0) + (i + 1) * lda];
+ double a11 = a[offseta + (row + 1) + (i + 1) * lda];
+ double a21 = a[offseta + (row + 2) + (i + 1) * lda];
+ double b10 = b[offsetb + (col + 0) + (i + 1) * ldb];
+ double b11 = b[offsetb + (col + 1) + (i + 1) * ldb];
+ double b12 = b[offsetb + (col + 2) + (i + 1) * ldb];
+ sum00 = a01 * b10 + sum00;
+ sum01 = a01 * b11 + sum01;
+ sum02 = a01 * b12 + sum02;
+ sum10 = a11 * b10 + sum10;
+ sum11 = a11 * b11 + sum11;
+ sum12 = a11 * b12 + sum12;
+ sum20 = a21 * b10 + sum20;
+ sum21 = a21 * b11 + sum21;
+ sum22 = a21 * b12 + sum22;
+ }
+ for (; i < k; i += 1) {
+ double a00 = a[offseta + (row + 0) + (i + 0) * lda];
+ double a10 = a[offseta + (row + 1) + (i + 0) * lda];
+ double a20 = a[offseta + (row + 2) + (i + 0) * lda];
+ double b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
+ double b01 = b[offsetb + (col + 1) + (i + 0) * ldb];
+ double b02 = b[offsetb + (col + 2) + (i + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum01 = a00 * b01 + sum01;
+ sum02 = a00 * b02 + sum02;
+ sum10 = a10 * b00 + sum10;
+ sum11 = a10 * b01 + sum11;
+ sum12 = a10 * b02 + sum12;
+ sum20 = a20 * b00 + sum20;
+ sum21 = a20 * b01 + sum21;
+ sum22 = a20 * b02 + sum22;
+ }
+ if (beta != 0.0) {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
+ c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + beta * c[offsetc + (row + 1) + (col + 1) * ldc];
+ c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + beta * c[offsetc + (row + 1) + (col + 2) * ldc];
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
+ c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + beta * c[offsetc + (row + 2) + (col + 1) * ldc];
+ c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + beta * c[offsetc + (row + 2) + (col + 2) * ldc];
+ } else {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
+ c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11;
+ c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12;
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
+ c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21;
+ c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22;
+ }
+ }
+ for (; row < m; row += 1) {
+ int i = 0;
+ double sum00 = 0.0;
+ double sum01 = 0.0;
+ double sum02 = 0.0;
+ for (; i < loopBound(k, Ti); i += Ti) {
+ double a00 = a[offseta + (row + 0) + (i + 0) * lda];
+ double b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
+ double b01 = b[offsetb + (col + 1) + (i + 0) * ldb];
+ double b02 = b[offsetb + (col + 2) + (i + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum01 = a00 * b01 + sum01;
+ sum02 = a00 * b02 + sum02;
+ double a01 = a[offseta + (row + 0) + (i + 1) * lda];
+ double b10 = b[offsetb + (col + 0) + (i + 1) * ldb];
+ double b11 = b[offsetb + (col + 1) + (i + 1) * ldb];
+ double b12 = b[offsetb + (col + 2) + (i + 1) * ldb];
+ sum00 = a01 * b10 + sum00;
+ sum01 = a01 * b11 + sum01;
+ sum02 = a01 * b12 + sum02;
+ }
+ for (; i < k; i += 1) {
+ double a00 = a[offseta + (row + 0) + (i + 0) * lda];
+ double b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
+ double b01 = b[offsetb + (col + 1) + (i + 0) * ldb];
+ double b02 = b[offsetb + (col + 2) + (i + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum01 = a00 * b01 + sum01;
+ sum02 = a00 * b02 + sum02;
+ }
+ if (beta != 0.0) {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
+ } else {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
+ }
+ }
+ }
+ for (; col < n; col += 1) {
+ int row = 0;
+ for (; row < loopBound(m, Trow); row += Trow) {
+ int i = 0;
+ double sum00 = 0.0;
+ double sum10 = 0.0;
+ double sum20 = 0.0;
+ for (; i < loopBound(k, Ti); i += Ti) {
+ double a00 = a[offseta + (row + 0) + (i + 0) * lda];
+ double a10 = a[offseta + (row + 1) + (i + 0) * lda];
+ double a20 = a[offseta + (row + 2) + (i + 0) * lda];
+ double b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum10 = a10 * b00 + sum10;
+ sum20 = a20 * b00 + sum20;
+ double a01 = a[offseta + (row + 0) + (i + 1) * lda];
+ double a11 = a[offseta + (row + 1) + (i + 1) * lda];
+ double a21 = a[offseta + (row + 2) + (i + 1) * lda];
+ double b10 = b[offsetb + (col + 0) + (i + 1) * ldb];
+ sum00 = a01 * b10 + sum00;
+ sum10 = a11 * b10 + sum10;
+ sum20 = a21 * b10 + sum20;
+ }
+ for (; i < k; i += 1) {
+ double a00 = a[offseta + (row + 0) + (i + 0) * lda];
+ double a10 = a[offseta + (row + 1) + (i + 0) * lda];
+ double a20 = a[offseta + (row + 2) + (i + 0) * lda];
+ double b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum10 = a10 * b00 + sum10;
+ sum20 = a20 * b00 + sum20;
+ }
+ if (beta != 0.0) {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
+ } else {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
+ }
+ }
+ for (; row < m; row += 1) {
+ int i = 0;
+ double sum00 = 0.0;
+ for (; i < loopBound(k, Ti); i += Ti) {
+ double a00 = a[offseta + (row + 0) + (i + 0) * lda];
+ double b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ double a01 = a[offseta + (row + 0) + (i + 1) * lda];
+ double b10 = b[offsetb + (col + 0) + (i + 1) * ldb];
+ sum00 = a01 * b10 + sum00;
+ }
+ for (; i < k; i += 1) {
+ double a00 = a[offseta + (row + 0) + (i + 0) * lda];
+ double b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ }
+ if (beta != 0.0) {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
+ } else {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
+ }
+ }
+ }
+ }
+
+ protected void dgemmTN(int m, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
+ final int Trow = 3, Tcol = 3, Ti = 2;
+
+ int col = 0;
+ for (; col < loopBound(n, Tcol); col += Tcol) {
+ int row = 0;
+ for (; row < loopBound(m, Trow); row += Trow) {
+ int i = 0;
+ double sum00 = 0.0;
+ double sum01 = 0.0;
+ double sum02 = 0.0;
+ double sum10 = 0.0;
+ double sum11 = 0.0;
+ double sum12 = 0.0;
+ double sum20 = 0.0;
+ double sum21 = 0.0;
+ double sum22 = 0.0;
+ for (; i < loopBound(k, Ti); i += Ti) {
+ double a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ double a10 = a[offseta + (i + 0) + (row + 1) * lda];
+ double a20 = a[offseta + (i + 0) + (row + 2) * lda];
+ double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
+ double b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
+ double b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum01 = a00 * b01 + sum01;
+ sum02 = a00 * b02 + sum02;
+ sum10 = a10 * b00 + sum10;
+ sum11 = a10 * b01 + sum11;
+ sum12 = a10 * b02 + sum12;
+ sum20 = a20 * b00 + sum20;
+ sum21 = a20 * b01 + sum21;
+ sum22 = a20 * b02 + sum22;
+ double a01 = a[offseta + (i + 1) + (row + 0) * lda];
+ double a11 = a[offseta + (i + 1) + (row + 1) * lda];
+ double a21 = a[offseta + (i + 1) + (row + 2) * lda];
+ double b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
+ double b11 = b[offsetb + (i + 1) + (col + 1) * ldb];
+ double b12 = b[offsetb + (i + 1) + (col + 2) * ldb];
+ sum00 = a01 * b10 + sum00;
+ sum01 = a01 * b11 + sum01;
+ sum02 = a01 * b12 + sum02;
+ sum10 = a11 * b10 + sum10;
+ sum11 = a11 * b11 + sum11;
+ sum12 = a11 * b12 + sum12;
+ sum20 = a21 * b10 + sum20;
+ sum21 = a21 * b11 + sum21;
+ sum22 = a21 * b12 + sum22;
+ }
+ for (; i < k; i += 1) {
+ double a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ double a10 = a[offseta + (i + 0) + (row + 1) * lda];
+ double a20 = a[offseta + (i + 0) + (row + 2) * lda];
+ double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
+ double b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
+ double b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum01 = a00 * b01 + sum01;
+ sum02 = a00 * b02 + sum02;
+ sum10 = a10 * b00 + sum10;
+ sum11 = a10 * b01 + sum11;
+ sum12 = a10 * b02 + sum12;
+ sum20 = a20 * b00 + sum20;
+ sum21 = a20 * b01 + sum21;
+ sum22 = a20 * b02 + sum22;
+ }
+ if (beta != 0.0) {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
+ c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + beta * c[offsetc + (row + 1) + (col + 1) * ldc];
+ c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + beta * c[offsetc + (row + 1) + (col + 2) * ldc];
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
+ c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + beta * c[offsetc + (row + 2) + (col + 1) * ldc];
+ c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + beta * c[offsetc + (row + 2) + (col + 2) * ldc];
+ } else {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
+ c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11;
+ c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12;
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
+ c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21;
+ c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22;
+ }
+ }
+ for (; row < m; row += 1) {
+ int i = 0;
+ double sum00 = 0.0;
+ double sum01 = 0.0;
+ double sum02 = 0.0;
+ for (; i < loopBound(k, Ti); i += Ti) {
+ double a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
+ double b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
+ double b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum01 = a00 * b01 + sum01;
+ sum02 = a00 * b02 + sum02;
+ double a01 = a[offseta + (i + 1) + (row + 0) * lda];
+ double b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
+ double b11 = b[offsetb + (i + 1) + (col + 1) * ldb];
+ double b12 = b[offsetb + (i + 1) + (col + 2) * ldb];
+ sum00 = a01 * b10 + sum00;
+ sum01 = a01 * b11 + sum01;
+ sum02 = a01 * b12 + sum02;
+ }
+ for (; i < k; i += 1) {
+ double a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
+ double b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
+ double b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum01 = a00 * b01 + sum01;
+ sum02 = a00 * b02 + sum02;
+ }
+ if (beta != 0.0) {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
+ } else {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
+ }
+ }
+ }
+ for (; col < n; col += 1) {
+ int row = 0;
+ for (; row < loopBound(m, Trow); row += Trow) {
+ int i = 0;
+ double sum00 = 0.0;
+ double sum10 = 0.0;
+ double sum20 = 0.0;
+ for (; i < loopBound(k, Ti); i += Ti) {
+ double a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ double a10 = a[offseta + (i + 0) + (row + 1) * lda];
+ double a20 = a[offseta + (i + 0) + (row + 2) * lda];
+ double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum10 = a10 * b00 + sum10;
+ sum20 = a20 * b00 + sum20;
+ double a01 = a[offseta + (i + 1) + (row + 0) * lda];
+ double a11 = a[offseta + (i + 1) + (row + 1) * lda];
+ double a21 = a[offseta + (i + 1) + (row + 2) * lda];
+ double b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
+ sum00 = a01 * b10 + sum00;
+ sum10 = a11 * b10 + sum10;
+ sum20 = a21 * b10 + sum20;
+ }
+ for (; i < k; i += 1) {
+ double a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ double a10 = a[offseta + (i + 0) + (row + 1) * lda];
+ double a20 = a[offseta + (i + 0) + (row + 2) * lda];
+ double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum10 = a10 * b00 + sum10;
+ sum20 = a20 * b00 + sum20;
+ }
+ if (beta != 0.0) {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
+ } else {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
+ }
+ }
+ for (; row < m; row += 1) {
+ int i = 0;
+ double sum00 = 0.0;
+ for (; i < loopBound(k, Ti); i += Ti) {
+ double a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ double a01 = a[offseta + (i + 1) + (row + 0) * lda];
+ double b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
+ sum00 = a01 * b10 + sum00;
+ }
+ for (; i < k; i += 1) {
+ double a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ }
+ if (beta != 0.0) {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
+ } else {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
+ }
+ }
+ }
+ }
+
+ protected void dgemmTT(int m, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
+ final int Trow = 3, Tcol = 3, Ti = 2;
+
+ int col = 0;
+ for (; col < loopBound(n, Tcol); col += Tcol) {
+ int row = 0;
+ for (; row < loopBound(m, Trow); row += Trow) {
+ int i = 0;
+ double sum00 = 0.0;
+ double sum01 = 0.0;
+ double sum02 = 0.0;
+ double sum10 = 0.0;
+ double sum11 = 0.0;
+ double sum12 = 0.0;
+ double sum20 = 0.0;
+ double sum21 = 0.0;
+ double sum22 = 0.0;
+ for (; i < loopBound(k, Ti); i += Ti) {
+ double a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ double a10 = a[offseta + (i + 0) + (row + 1) * lda];
+ double a20 = a[offseta + (i + 0) + (row + 2) * lda];
+ double b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
+ double b01 = b[offsetb + (col + 1) + (i + 0) * ldb];
+ double b02 = b[offsetb + (col + 2) + (i + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum01 = a00 * b01 + sum01;
+ sum02 = a00 * b02 + sum02;
+ sum10 = a10 * b00 + sum10;
+ sum11 = a10 * b01 + sum11;
+ sum12 = a10 * b02 + sum12;
+ sum20 = a20 * b00 + sum20;
+ sum21 = a20 * b01 + sum21;
+ sum22 = a20 * b02 + sum22;
+ double a01 = a[offseta + (i + 1) + (row + 0) * lda];
+ double a11 = a[offseta + (i + 1) + (row + 1) * lda];
+ double a21 = a[offseta + (i + 1) + (row + 2) * lda];
+ double b10 = b[offsetb + (col + 0) + (i + 1) * ldb];
+ double b11 = b[offsetb + (col + 1) + (i + 1) * ldb];
+ double b12 = b[offsetb + (col + 2) + (i + 1) * ldb];
+ sum00 = a01 * b10 + sum00;
+ sum01 = a01 * b11 + sum01;
+ sum02 = a01 * b12 + sum02;
+ sum10 = a11 * b10 + sum10;
+ sum11 = a11 * b11 + sum11;
+ sum12 = a11 * b12 + sum12;
+ sum20 = a21 * b10 + sum20;
+ sum21 = a21 * b11 + sum21;
+ sum22 = a21 * b12 + sum22;
+ }
+ for (; i < k; i += 1) {
+ double a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ double a10 = a[offseta + (i + 0) + (row + 1) * lda];
+ double a20 = a[offseta + (i + 0) + (row + 2) * lda];
+ double b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
+ double b01 = b[offsetb + (col + 1) + (i + 0) * ldb];
+ double b02 = b[offsetb + (col + 2) + (i + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum01 = a00 * b01 + sum01;
+ sum02 = a00 * b02 + sum02;
+ sum10 = a10 * b00 + sum10;
+ sum11 = a10 * b01 + sum11;
+ sum12 = a10 * b02 + sum12;
+ sum20 = a20 * b00 + sum20;
+ sum21 = a20 * b01 + sum21;
+ sum22 = a20 * b02 + sum22;
+ }
+ if (beta != 0.0) {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
+ c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + beta * c[offsetc + (row + 1) + (col + 1) * ldc];
+ c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + beta * c[offsetc + (row + 1) + (col + 2) * ldc];
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
+ c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + beta * c[offsetc + (row + 2) + (col + 1) * ldc];
+ c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + beta * c[offsetc + (row + 2) + (col + 2) * ldc];
+ } else {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
+ c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11;
+ c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12;
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
+ c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21;
+ c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22;
+ }
+ }
+ for (; row < m; row += 1) {
+ int i = 0;
+ double sum00 = 0.0;
+ double sum01 = 0.0;
+ double sum02 = 0.0;
+ for (; i < loopBound(k, Ti); i += Ti) {
+ double a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ double b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
+ double b01 = b[offsetb + (col + 1) + (i + 0) * ldb];
+ double b02 = b[offsetb + (col + 2) + (i + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum01 = a00 * b01 + sum01;
+ sum02 = a00 * b02 + sum02;
+ double a01 = a[offseta + (i + 1) + (row + 0) * lda];
+ double b10 = b[offsetb + (col + 0) + (i + 1) * ldb];
+ double b11 = b[offsetb + (col + 1) + (i + 1) * ldb];
+ double b12 = b[offsetb + (col + 2) + (i + 1) * ldb];
+ sum00 = a01 * b10 + sum00;
+ sum01 = a01 * b11 + sum01;
+ sum02 = a01 * b12 + sum02;
+ }
+ for (; i < k; i += 1) {
+ double a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ double b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
+ double b01 = b[offsetb + (col + 1) + (i + 0) * ldb];
+ double b02 = b[offsetb + (col + 2) + (i + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum01 = a00 * b01 + sum01;
+ sum02 = a00 * b02 + sum02;
+ }
+ if (beta != 0.0) {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
+ } else {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
+ }
+ }
+ }
+ for (; col < n; col += 1) {
+ int row = 0;
+ for (; row < loopBound(m, Trow); row += Trow) {
+ int i = 0;
+ double sum00 = 0.0;
+ double sum10 = 0.0;
+ double sum20 = 0.0;
+ for (; i < loopBound(k, Ti); i += Ti) {
+ double a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ double a10 = a[offseta + (i + 0) + (row + 1) * lda];
+ double a20 = a[offseta + (i + 0) + (row + 2) * lda];
+ double b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum10 = a10 * b00 + sum10;
+ sum20 = a20 * b00 + sum20;
+ double a01 = a[offseta + (i + 1) + (row + 0) * lda];
+ double a11 = a[offseta + (i + 1) + (row + 1) * lda];
+ double a21 = a[offseta + (i + 1) + (row + 2) * lda];
+ double b10 = b[offsetb + (col + 0) + (i + 1) * ldb];
+ sum00 = a01 * b10 + sum00;
+ sum10 = a11 * b10 + sum10;
+ sum20 = a21 * b10 + sum20;
+ }
+ for (; i < k; i += 1) {
+ double a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ double a10 = a[offseta + (i + 0) + (row + 1) * lda];
+ double a20 = a[offseta + (i + 0) + (row + 2) * lda];
+ double b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum10 = a10 * b00 + sum10;
+ sum20 = a20 * b00 + sum20;
+ }
+ if (beta != 0.0) {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
+ } else {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
+ }
+ }
+ for (; row < m; row += 1) {
+ int i = 0;
+ double sum00 = 0.0;
+ for (; i < loopBound(k, Ti); i += Ti) {
+ double a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ double b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ double a01 = a[offseta + (i + 1) + (row + 0) * lda];
+ double b10 = b[offsetb + (col + 0) + (i + 1) * ldb];
+ sum00 = a01 * b10 + sum00;
+ }
+ for (; i < k; i += 1) {
+ double a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ double b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ }
+ if (beta != 0.0) {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
+ } else {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
+ }
+ }
+ }
+ }
+
+ protected void sgemmK(String transa, String transb, int m, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
+ if (alpha == 0.0f) {
+ sgemmBeta(0, m, 0, n, beta, c, offsetc, ldc);
+ } else if (m * n * k < 100 * 100 * 100) {
+ // The matrices are small and it's faster to do the non-copying version
+ if (lsame("N", transa) && lsame("N", transb)) {
+ sgemmNN(m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
+ } else if (lsame("N", transa)) {
+ sgemmNT(m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
+ } else if (lsame("N", transb)) {
+ sgemmTN(m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
+ } else {
+ sgemmTT(m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
+ }
+ } else {
+ final int Krow = (int)(Math.ceil((double)(Math.min(60, m)) / 3) * 3),
+ Kcol = (int)(Math.ceil((double)(Math.min(1000, n)) / 3) * 3),
+ Ki = (int)(Math.ceil((double)(Math.min(500, k)) / 4) * 4);
+
+ assert Krow > 0;
+ assert Kcol > 0;
+ assert Ki > 0;
+
+ float[] packeda = new float[Krow * Ki];
+ float[] packedb = new float[Kcol * Ki];
+ float[] packedc = new float[Kcol * Krow];
+
+ // c = beta * c
+ sgemmBeta(0, m, 0, n, beta, c, offsetc, ldc);
+ // c += alpha * a * b
+ for (int col = 0; col < n; col += Kcol) {
+ int cols = col, cole = Math.min(col + Kcol, n);
+ for (int i = 0; i < k; i += Ki) {
+ int is = i, ie = Math.min(i + Ki, k);
+ // pack b
+ if (lsame("N", transb)) {
+ sgecpyNN(ie - is, cole - cols, b, offsetb, ldb, is, cols, packedb, 0, Ki, 0, 0);
+ } else {
+ sgecpyTN(ie - is, cole - cols, b, offsetb, ldb, is, cols, packedb, 0, Ki, 0, 0);
+ }
+ // GEPP
+ for (int row = 0; row < m; row += Krow) {
+ int rows = row, rowe = Math.min(row + Krow, m);
+ // pack A
+ if (lsame("N", transa)) {
+ sgecpyNT(rowe - rows, ie - is, a, offseta, lda, rows, is, packeda, 0, Ki, 0, 0);
+ } else {
+ sgecpyTT(rowe - rows, ie - is, a, offseta, lda, rows, is, packeda, 0, Ki, 0, 0);
+ }
+ // pack C
+ sgecpyNN(rowe - rows, cole - cols, c, offsetc, ldc, rows, cols, packedc, 0, Krow, 0, 0);
+ // GEBP
+ sgebpTN(Krow, 0, rowe - rows, Kcol, 0, cole - cols, Ki, 0, ie - is,
+ alpha, packeda, 0, Ki, packedb, 0, Ki, beta, packedc, 0, Krow);
+ // unpack C
+ sgecpyNN(rowe - rows, cole - cols, packedc, 0, Krow, 0, 0, c, offsetc, ldc, rows, cols);
+ }
+ }
+ }
+ }
+ }
+
+ protected void sgemmBeta(int rows, int rowe, int cols, int cole, float beta, float[] c, int offsetc, int ldc) {
+ if (beta != 1.0f) {
+ int col = cols;
+ for (; col < loopAlign(cols, cole, 4); col += 1) {
+ int row = rows;
+ for (; row < rowe; row += 1) {
+ if (beta != 0.0f) {
+ c[offsetc + row + (col + 0) * ldc] = beta * c[offsetc + row + (col + 0) * ldc];
+ } else {
+ c[offsetc + row + (col + 0) * ldc] = 0.0f;
+ }
+ }
+ }
+ for (; col < loopBound(cole, 4); col += 4) {
+ int row = rows;
+ for (; row < rowe; row += 1) {
+ if (beta != 0.0f) {
+ c[offsetc + row + (col + 0) * ldc] = beta * c[offsetc + row + (col + 0) * ldc];
+ c[offsetc + row + (col + 1) * ldc] = beta * c[offsetc + row + (col + 1) * ldc];
+ c[offsetc + row + (col + 2) * ldc] = beta * c[offsetc + row + (col + 2) * ldc];
+ c[offsetc + row + (col + 3) * ldc] = beta * c[offsetc + row + (col + 3) * ldc];
+ } else {
+ c[offsetc + row + (col + 0) * ldc] = 0.0f;
+ c[offsetc + row + (col + 1) * ldc] = 0.0f;
+ c[offsetc + row + (col + 2) * ldc] = 0.0f;
+ c[offsetc + row + (col + 3) * ldc] = 0.0f;
+ }
+ }
+ }
+ for (; col < cole; col += 1) {
+ int row = rows;
+ for (; row < rowe; row += 1) {
+ if (beta != 0.0f) {
+ c[offsetc + row + (col + 0) * ldc] = beta * c[offsetc + row + (col + 0) * ldc];
+ } else {
+ c[offsetc + row + (col + 0) * ldc] = 0.0f;
+ }
+ }
+ }
+ }
+ }
+
+ protected void sgecpyNN(int m, int n, float[] src, int offsetsrc, int ldsrc, int rowssrc, int colssrc, float[] dst, int offsetdst, int lddst, int rowsdst, int colsdst) {
+ int col = 0;
+ for (; col < loopBound(n, 4); col += 4) {
+ System.arraycopy(src, offsetsrc + rowssrc + (colssrc + col + 0) * ldsrc, dst, offsetdst + rowsdst + (colsdst + col + 0) * lddst, m);
+ System.arraycopy(src, offsetsrc + rowssrc + (colssrc + col + 1) * ldsrc, dst, offsetdst + rowsdst + (colsdst + col + 1) * lddst, m);
+ System.arraycopy(src, offsetsrc + rowssrc + (colssrc + col + 2) * ldsrc, dst, offsetdst + rowsdst + (colsdst + col + 2) * lddst, m);
+ System.arraycopy(src, offsetsrc + rowssrc + (colssrc + col + 3) * ldsrc, dst, offsetdst + rowsdst + (colsdst + col + 3) * lddst, m);
+ }
+ for (; col < n; col += 1) {
+ System.arraycopy(src, offsetsrc + rowssrc + (colssrc + col + 0) * ldsrc, dst, offsetdst + rowsdst + (colsdst + col + 0) * lddst, m);
+ }
+ }
+
+ protected void sgecpyNT(int m, int n, float[] src, int offsetsrc, int ldsrc, int rowssrc, int colssrc, float[] dst, int offsetdst, int lddst, int rowsdst, int colsdst) {
+ int col = 0;
+ for (; col < loopBound(n, 3); col += 3) {
+ int row = 0;
+ for (; row < loopBound(m, 3); row += 3) {
+ dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 0) * ldsrc];
+ dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 1) * lddst] = src[offsetsrc + (rowssrc + row + 1) + (colssrc + col + 0) * ldsrc];
+ dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 2) * lddst] = src[offsetsrc + (rowssrc + row + 2) + (colssrc + col + 0) * ldsrc];
+ dst[offsetdst + (colsdst + col + 1) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 1) * ldsrc];
+ dst[offsetdst + (colsdst + col + 1) + (rowsdst + row + 1) * lddst] = src[offsetsrc + (rowssrc + row + 1) + (colssrc + col + 1) * ldsrc];
+ dst[offsetdst + (colsdst + col + 1) + (rowsdst + row + 2) * lddst] = src[offsetsrc + (rowssrc + row + 2) + (colssrc + col + 1) * ldsrc];
+ dst[offsetdst + (colsdst + col + 2) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 2) * ldsrc];
+ dst[offsetdst + (colsdst + col + 2) + (rowsdst + row + 1) * lddst] = src[offsetsrc + (rowssrc + row + 1) + (colssrc + col + 2) * ldsrc];
+ dst[offsetdst + (colsdst + col + 2) + (rowsdst + row + 2) * lddst] = src[offsetsrc + (rowssrc + row + 2) + (colssrc + col + 2) * ldsrc];
+ }
+ for (; row < m; row += 1) {
+ dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 0) * ldsrc];
+ dst[offsetdst + (colsdst + col + 1) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 1) * ldsrc];
+ dst[offsetdst + (colsdst + col + 2) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 2) * ldsrc];
+ }
+ }
+ for (; col < n; col += 1) {
+ int row = 0;
+ for (; row < loopBound(m, 3); row += 3) {
+ dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 0) * ldsrc];
+ dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 1) * lddst] = src[offsetsrc + (rowssrc + row + 1) + (colssrc + col + 0) * ldsrc];
+ dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 2) * lddst] = src[offsetsrc + (rowssrc + row + 2) + (colssrc + col + 0) * ldsrc];
+ }
+ for (; row < m; row += 1) {
+ dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 0) * ldsrc];
+ }
+ }
+ }
+
+ protected void sgecpyTN(int m, int n, float[] src, int offsetsrc, int ldsrc, int rowssrc, int colssrc, float[] dst, int offsetdst, int lddst, int rowsdst, int colsdst) {
+ int row = 0;
+ for (; row < loopBound(m, 3); row += 3) {
+ int col = 0;
+ for (; col < loopBound(n, 3); col += 3) {
+ dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 0) * ldsrc];
+ dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 1) * lddst] = src[offsetsrc + (colssrc + col + 1) + (rowssrc + row + 0) * ldsrc];
+ dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 2) * lddst] = src[offsetsrc + (colssrc + col + 2) + (rowssrc + row + 0) * ldsrc];
+ dst[offsetdst + (rowsdst + row + 1) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 1) * ldsrc];
+ dst[offsetdst + (rowsdst + row + 1) + (colsdst + col + 1) * lddst] = src[offsetsrc + (colssrc + col + 1) + (rowssrc + row + 1) * ldsrc];
+ dst[offsetdst + (rowsdst + row + 1) + (colsdst + col + 2) * lddst] = src[offsetsrc + (colssrc + col + 2) + (rowssrc + row + 1) * ldsrc];
+ dst[offsetdst + (rowsdst + row + 2) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 2) * ldsrc];
+ dst[offsetdst + (rowsdst + row + 2) + (colsdst + col + 1) * lddst] = src[offsetsrc + (colssrc + col + 1) + (rowssrc + row + 2) * ldsrc];
+ dst[offsetdst + (rowsdst + row + 2) + (colsdst + col + 2) * lddst] = src[offsetsrc + (colssrc + col + 2) + (rowssrc + row + 2) * ldsrc];
+ }
+ for (; col < n; col += 1) {
+ dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 0) * ldsrc];
+ dst[offsetdst + (rowsdst + row + 1) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 1) * ldsrc];
+ dst[offsetdst + (rowsdst + row + 2) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 2) * ldsrc];
+ }
+ }
+ for (; row < m; row += 1) {
+ int col = 0;
+ for (; col < loopBound(n, 3); col += 3) {
+ dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 0) * ldsrc];
+ dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 1) * lddst] = src[offsetsrc + (colssrc + col + 1) + (rowssrc + row + 0) * ldsrc];
+ dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 2) * lddst] = src[offsetsrc + (colssrc + col + 2) + (rowssrc + row + 0) * ldsrc];
+ }
+ for (; col < n; col += 1) {
+ dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 0) * ldsrc];
+ }
+ }
+ }
+
+ protected void sgecpyTT(int m, int n, float[] src, int offsetsrc, int ldsrc, int rowssrc, int colssrc, float[] dst, int offsetdst, int lddst, int rowsdst, int colsdst) {
+ int row = 0;
+ for (; row < loopBound(m, 4); row += 4) {
+ System.arraycopy(src, offsetsrc + colssrc + (rowssrc + row + 0) * ldsrc, dst, offsetdst + colsdst + (rowsdst + row + 0) * lddst, n);
+ System.arraycopy(src, offsetsrc + colssrc + (rowssrc + row + 1) * ldsrc, dst, offsetdst + colsdst + (rowsdst + row + 1) * lddst, n);
+ System.arraycopy(src, offsetsrc + colssrc + (rowssrc + row + 2) * ldsrc, dst, offsetdst + colsdst + (rowsdst + row + 2) * lddst, n);
+ System.arraycopy(src, offsetsrc + colssrc + (rowssrc + row + 3) * ldsrc, dst, offsetdst + colsdst + (rowsdst + row + 3) * lddst, n);
+ }
+ for (; row < m; row += 1) {
+ System.arraycopy(src, offsetsrc + colssrc + (rowssrc + row + 0) * ldsrc, dst, offsetdst + colsdst + (rowsdst + row + 0) * lddst, n);
+ }
+ }
+
+ protected void sgebpTN(int m, int rows, int rowe, int n, int cols, int cole, int k, int is, int ie, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
+ final int Tcol = 3, Trow = 3, Ti = 2;
+
+ int col = cols;
+ for (; col < loopAlign(cols, cole, Tcol); col += 1) {
+ int row = rows;
+ for (; row < loopAlign(rows, rowe, Trow); row += 1) {
+ float sum00 = 0.0f;
+ for (int i = is; i < ie; i += 1) {
+ float a0 = a[offseta + i + (row + 0) * lda];
+ float b0 = b[offsetb + i + (col + 0) * ldb];
+ sum00 = a0 * b0 + sum00;
+ }
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc];
+ }
+ for (; row < loopBound(rowe, Trow); row += Trow) {
+ float sum00 = 0.0f;
+ float sum10 = 0.0f;
+ float sum20 = 0.0f;
+ for (int i = is; i < ie; i += 1) {
+ float a0 = a[offseta + i + (row + 0) * lda];
+ float a1 = a[offseta + i + (row + 1) * lda];
+ float a2 = a[offseta + i + (row + 2) * lda];
+ float b0 = b[offsetb + i + (col + 0) * ldb];
+ sum00 = a0 * b0 + sum00;
+ sum10 = a1 * b0 + sum10;
+ sum20 = a2 * b0 + sum20;
+ }
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + c[offsetc + (row + 1) + (col + 0) * ldc];
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + c[offsetc + (row + 2) + (col + 0) * ldc];
+ }
+ for (; row < rowe; row += 1) {
+ float sum00 = 0.0f;
+ for (int i = is; i < ie; i += 1) {
+ float a0 = a[offseta + i + (row + 0) * lda];
+ float b0 = b[offsetb + i + (col + 0) * ldb];
+ sum00 = a0 * b0 + sum00;
+ }
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc];
+ }
+ }
+ for (; col < loopBound(cole, Tcol); col += Tcol) {
+ int row = rows;
+ for (; row < loopAlign(rows, rowe, Trow); row += 1) {
+ float sum00 = 0.0f;
+ float sum01 = 0.0f;
+ float sum02 = 0.0f;
+ float sum03 = 0.0f;
+ for (int i = is; i < ie; i += 1) {
+ float a0 = a[offseta + i + (row + 0) * lda];
+ float b0 = b[offsetb + i + (col + 0) * ldb];
+ float b1 = b[offsetb + i + (col + 1) * ldb];
+ float b2 = b[offsetb + i + (col + 2) * ldb];
+ sum00 = a0 * b0 + sum00;
+ sum01 = a0 * b1 + sum01;
+ sum02 = a0 * b2 + sum02;
+ }
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + c[offsetc + (row + 0) + (col + 1) * ldc];
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + c[offsetc + (row + 0) + (col + 2) * ldc];
+ }
+ for (; row < loopBound(rowe, Trow); row += Trow) {
+ sgepdotTN(m, row, row + Trow, n, col, col + Tcol, k, is, ie, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
+ }
+ for (; row < rowe; row += 1) {
+ float sum00 = 0.0f;
+ float sum01 = 0.0f;
+ float sum02 = 0.0f;
+ for (int i = is; i < ie; i += 1) {
+ float a0 = a[offseta + i + (row + 0) * lda];
+ float b0 = b[offsetb + i + (col + 0) * ldb];
+ float b1 = b[offsetb + i + (col + 1) * ldb];
+ float b2 = b[offsetb + i + (col + 2) * ldb];
+ sum00 = a0 * b0 + sum00;
+ sum01 = a0 * b1 + sum01;
+ sum02 = a0 * b2 + sum02;
+ }
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + c[offsetc + (row + 0) + (col + 1) * ldc];
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + c[offsetc + (row + 0) + (col + 2) * ldc];
+ }
+ }
+ for (; col < cole; col += 1) {
+ int row = rows;
+ for (; row < loopAlign(rows, rowe, Trow); row += 1) {
+ float sum00 = 0.0f;
+ for (int i = is; i < ie; i += 1) {
+ float a0 = a[offseta + i + (row + 0) * lda];
+ float b0 = b[offsetb + i + (col + 0) * ldb];
+ sum00 = a0 * b0 + sum00;
+ }
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc];
+ }
+ for (; row < loopBound(rowe, Trow); row += Trow) {
+ float sum00 = 0.0f;
+ float sum10 = 0.0f;
+ float sum20 = 0.0f;
+ for (int i = is; i < ie; i += 1) {
+ float a0 = a[offseta + i + (row + 0) * lda];
+ float a1 = a[offseta + i + (row + 1) * lda];
+ float a2 = a[offseta + i + (row + 2) * lda];
+ float b0 = b[offsetb + i + (col + 0) * ldb];
+ sum00 = a0 * b0 + sum00;
+ sum10 = a1 * b0 + sum10;
+ sum20 = a2 * b0 + sum20;
+ }
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + c[offsetc + (row + 1) + (col + 0) * ldc];
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + c[offsetc + (row + 2) + (col + 0) * ldc];
+ }
+ for (; row < rowe; row += 1) {
+ float sum00 = 0.0f;
+ for (int i = is; i < ie; i += 1) {
+ float a0 = a[offseta + i + (row + 0) * lda];
+ float b0 = b[offsetb + i + (col + 0) * ldb];
+ sum00 = a0 * b0 + sum00;
+ }
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc];
+ }
+ }
+ }
+
+ protected void sgepdotTN(int m, int rows, int rowe, int n, int cols, int cole, int k, int is, int ie, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
+ final int Ti = 2;
+
+ assert rowe - rows == 3;
+ assert cole - cols == 3;
+
+ int row = rows;
+ int col = cols;
+ int i = is;
+ float sum00 = 0.0f;
+ float sum01 = 0.0f;
+ float sum02 = 0.0f;
+ float sum10 = 0.0f;
+ float sum11 = 0.0f;
+ float sum12 = 0.0f;
+ float sum20 = 0.0f;
+ float sum21 = 0.0f;
+ float sum22 = 0.0f;
+ for (; i < loopAlign(is, ie, Ti); i += 1) {
+ float a0 = a[offseta + i + (row + 0) * lda];
+ float a1 = a[offseta + i + (row + 1) * lda];
+ float a2 = a[offseta + i + (row + 2) * lda];
+ float b0 = b[offsetb + i + (col + 0) * ldb];
+ sum00 = a0 * b0 + sum00;
+ sum10 = a1 * b0 + sum10;
+ sum20 = a2 * b0 + sum20;
+ float b1 = b[offsetb + i + (col + 1) * ldb];
+ sum01 = a0 * b1 + sum01;
+ sum11 = a1 * b1 + sum11;
+ sum21 = a2 * b1 + sum21;
+ float b2 = b[offsetb + i + (col + 2) * ldb];
+ sum02 = a0 * b2 + sum02;
+ sum12 = a1 * b2 + sum12;
+ sum22 = a2 * b2 + sum22;
+ }
+ for (; i < loopBound(ie, Ti); i += Ti) {
+ float a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ float a01 = a[offseta + (i + 0) + (row + 1) * lda];
+ float a02 = a[offseta + (i + 0) + (row + 2) * lda];
+ float b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum10 = a01 * b00 + sum10;
+ sum20 = a02 * b00 + sum20;
+ float b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
+ sum01 = a00 * b01 + sum01;
+ sum11 = a01 * b01 + sum11;
+ sum21 = a02 * b01 + sum21;
+ float b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
+ sum02 = a00 * b02 + sum02;
+ sum12 = a01 * b02 + sum12;
+ sum22 = a02 * b02 + sum22;
+ float a10 = a[offseta + (i + 1) + (row + 0) * lda];
+ float a11 = a[offseta + (i + 1) + (row + 1) * lda];
+ float a12 = a[offseta + (i + 1) + (row + 2) * lda];
+ float b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
+ sum00 = a10 * b10 + sum00;
+ sum10 = a11 * b10 + sum10;
+ sum20 = a12 * b10 + sum20;
+ float b11 = b[offsetb + (i + 1) + (col + 1) * ldb];
+ sum01 = a10 * b11 + sum01;
+ sum11 = a11 * b11 + sum11;
+ sum21 = a12 * b11 + sum21;
+ float b12 = b[offsetb + (i + 1) + (col + 2) * ldb];
+ sum02 = a10 * b12 + sum02;
+ sum12 = a11 * b12 + sum12;
+ sum22 = a12 * b12 + sum22;
+ }
+ for (; i < ie; i += 1) {
+ float a0 = a[offseta + i + (row + 0) * lda];
+ float a1 = a[offseta + i + (row + 1) * lda];
+ float a2 = a[offseta + i + (row + 2) * lda];
+ float b0 = b[offsetb + i + (col + 0) * ldb];
+ sum00 = a0 * b0 + sum00;
+ sum10 = a1 * b0 + sum10;
+ sum20 = a2 * b0 + sum20;
+ float b1 = b[offsetb + i + (col + 1) * ldb];
+ sum01 = a0 * b1 + sum01;
+ sum11 = a1 * b1 + sum11;
+ sum21 = a2 * b1 + sum21;
+ float b2 = b[offsetb + i + (col + 2) * ldb];
+ sum02 = a0 * b2 + sum02;
+ sum12 = a1 * b2 + sum12;
+ sum22 = a2 * b2 + sum22;
+ }
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + c[offsetc + (row + 0) + (col + 1) * ldc];
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + c[offsetc + (row + 0) + (col + 2) * ldc];
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + c[offsetc + (row + 1) + (col + 0) * ldc];
+ c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + c[offsetc + (row + 1) + (col + 1) * ldc];
+ c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + c[offsetc + (row + 1) + (col + 2) * ldc];
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + c[offsetc + (row + 2) + (col + 0) * ldc];
+ c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + c[offsetc + (row + 2) + (col + 1) * ldc];
+ c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + c[offsetc + (row + 2) + (col + 2) * ldc];
+ }
+
+ protected void sgemmNN(int m, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
+ final int Trow = 3, Tcol = 3, Ti = 2;
+
+ int col = 0;
+ for (; col < loopBound(n, Tcol); col += Tcol) {
+ int row = 0;
+ for (; row < loopBound(m, Trow); row += Trow) {
+ int i = 0;
+ float sum00 = 0.0f;
+ float sum01 = 0.0f;
+ float sum02 = 0.0f;
+ float sum10 = 0.0f;
+ float sum11 = 0.0f;
+ float sum12 = 0.0f;
+ float sum20 = 0.0f;
+ float sum21 = 0.0f;
+ float sum22 = 0.0f;
+ for (; i < loopBound(k, Ti); i += Ti) {
+ float a00 = a[offseta + (row + 0) + (i + 0) * lda];
+ float a10 = a[offseta + (row + 1) + (i + 0) * lda];
+ float a20 = a[offseta + (row + 2) + (i + 0) * lda];
+ float b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
+ float b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
+ float b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum01 = a00 * b01 + sum01;
+ sum02 = a00 * b02 + sum02;
+ sum10 = a10 * b00 + sum10;
+ sum11 = a10 * b01 + sum11;
+ sum12 = a10 * b02 + sum12;
+ sum20 = a20 * b00 + sum20;
+ sum21 = a20 * b01 + sum21;
+ sum22 = a20 * b02 + sum22;
+ float a01 = a[offseta + (row + 0) + (i + 1) * lda];
+ float a11 = a[offseta + (row + 1) + (i + 1) * lda];
+ float a21 = a[offseta + (row + 2) + (i + 1) * lda];
+ float b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
+ float b11 = b[offsetb + (i + 1) + (col + 1) * ldb];
+ float b12 = b[offsetb + (i + 1) + (col + 2) * ldb];
+ sum00 = a01 * b10 + sum00;
+ sum01 = a01 * b11 + sum01;
+ sum02 = a01 * b12 + sum02;
+ sum10 = a11 * b10 + sum10;
+ sum11 = a11 * b11 + sum11;
+ sum12 = a11 * b12 + sum12;
+ sum20 = a21 * b10 + sum20;
+ sum21 = a21 * b11 + sum21;
+ sum22 = a21 * b12 + sum22;
+ }
+ for (; i < k; i += 1) {
+ float a00 = a[offseta + (row + 0) + (i + 0) * lda];
+ float a10 = a[offseta + (row + 1) + (i + 0) * lda];
+ float a20 = a[offseta + (row + 2) + (i + 0) * lda];
+ float b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
+ float b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
+ float b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum01 = a00 * b01 + sum01;
+ sum02 = a00 * b02 + sum02;
+ sum10 = a10 * b00 + sum10;
+ sum11 = a10 * b01 + sum11;
+ sum12 = a10 * b02 + sum12;
+ sum20 = a20 * b00 + sum20;
+ sum21 = a20 * b01 + sum21;
+ sum22 = a20 * b02 + sum22;
+ }
+ if (beta != 0.0f) {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
+ c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + beta * c[offsetc + (row + 1) + (col + 1) * ldc];
+ c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + beta * c[offsetc + (row + 1) + (col + 2) * ldc];
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
+ c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + beta * c[offsetc + (row + 2) + (col + 1) * ldc];
+ c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + beta * c[offsetc + (row + 2) + (col + 2) * ldc];
+ } else {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
+ c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11;
+ c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12;
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
+ c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21;
+ c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22;
+ }
+ }
+ for (; row < m; row += 1) {
+ int i = 0;
+ float sum00 = 0.0f;
+ float sum01 = 0.0f;
+ float sum02 = 0.0f;
+ for (; i < loopBound(k, Ti); i += Ti) {
+ float a00 = a[offseta + (row + 0) + (i + 0) * lda];
+ float b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
+ float b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
+ float b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum01 = a00 * b01 + sum01;
+ sum02 = a00 * b02 + sum02;
+ float a01 = a[offseta + (row + 0) + (i + 1) * lda];
+ float b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
+ float b11 = b[offsetb + (i + 1) + (col + 1) * ldb];
+ float b12 = b[offsetb + (i + 1) + (col + 2) * ldb];
+ sum00 = a01 * b10 + sum00;
+ sum01 = a01 * b11 + sum01;
+ sum02 = a01 * b12 + sum02;
+ }
+ for (; i < k; i += 1) {
+ float a00 = a[offseta + (row + 0) + (i + 0) * lda];
+ float b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
+ float b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
+ float b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum01 = a00 * b01 + sum01;
+ sum02 = a00 * b02 + sum02;
+ }
+ if (beta != 0.0f) {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
+ } else {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
+ }
+ }
+ }
+ for (; col < n; col += 1) {
+ int row = 0;
+ for (; row < loopBound(m, Trow); row += Trow) {
+ int i = 0;
+ float sum00 = 0.0f;
+ float sum10 = 0.0f;
+ float sum20 = 0.0f;
+ for (; i < loopBound(k, Ti); i += Ti) {
+ float a00 = a[offseta + (row + 0) + (i + 0) * lda];
+ float a10 = a[offseta + (row + 1) + (i + 0) * lda];
+ float a20 = a[offseta + (row + 2) + (i + 0) * lda];
+ float b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum10 = a10 * b00 + sum10;
+ sum20 = a20 * b00 + sum20;
+ float a01 = a[offseta + (row + 0) + (i + 1) * lda];
+ float a11 = a[offseta + (row + 1) + (i + 1) * lda];
+ float a21 = a[offseta + (row + 2) + (i + 1) * lda];
+ float b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
+ sum00 = a01 * b10 + sum00;
+ sum10 = a11 * b10 + sum10;
+ sum20 = a21 * b10 + sum20;
+ }
+ for (; i < k; i += 1) {
+ float a00 = a[offseta + (row + 0) + (i + 0) * lda];
+ float a10 = a[offseta + (row + 1) + (i + 0) * lda];
+ float a20 = a[offseta + (row + 2) + (i + 0) * lda];
+ float b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum10 = a10 * b00 + sum10;
+ sum20 = a20 * b00 + sum20;
+ }
+ if (beta != 0.0f) {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
+ } else {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
+ }
+ }
+ for (; row < m; row += 1) {
+ int i = 0;
+ float sum00 = 0.0f;
+ for (; i < loopBound(k, Ti); i += Ti) {
+ float a00 = a[offseta + (row + 0) + (i + 0) * lda];
+ float b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ float a01 = a[offseta + (row + 0) + (i + 1) * lda];
+ float b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
+ sum00 = a01 * b10 + sum00;
+ }
+ for (; i < k; i += 1) {
+ float a00 = a[offseta + (row + 0) + (i + 0) * lda];
+ float b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ }
+ if (beta != 0.0f) {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
+ } else {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
+ }
+ }
+ }
+ }
+
+ protected void sgemmNT(int m, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
+ final int Trow = 3, Tcol = 3, Ti = 2;
+ int col = 0;
+ for (; col < loopBound(n, Tcol); col += Tcol) {
+ int row = 0;
+ for (; row < loopBound(m, Trow); row += Trow) {
+ int i = 0;
+ float sum00 = 0.0f;
+ float sum01 = 0.0f;
+ float sum02 = 0.0f;
+ float sum10 = 0.0f;
+ float sum11 = 0.0f;
+ float sum12 = 0.0f;
+ float sum20 = 0.0f;
+ float sum21 = 0.0f;
+ float sum22 = 0.0f;
+ for (; i < loopBound(k, Ti); i += Ti) {
+ float a00 = a[offseta + (row + 0) + (i + 0) * lda];
+ float a10 = a[offseta + (row + 1) + (i + 0) * lda];
+ float a20 = a[offseta + (row + 2) + (i + 0) * lda];
+ float b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
+ float b01 = b[offsetb + (col + 1) + (i + 0) * ldb];
+ float b02 = b[offsetb + (col + 2) + (i + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum01 = a00 * b01 + sum01;
+ sum02 = a00 * b02 + sum02;
+ sum10 = a10 * b00 + sum10;
+ sum11 = a10 * b01 + sum11;
+ sum12 = a10 * b02 + sum12;
+ sum20 = a20 * b00 + sum20;
+ sum21 = a20 * b01 + sum21;
+ sum22 = a20 * b02 + sum22;
+ float a01 = a[offseta + (row + 0) + (i + 1) * lda];
+ float a11 = a[offseta + (row + 1) + (i + 1) * lda];
+ float a21 = a[offseta + (row + 2) + (i + 1) * lda];
+ float b10 = b[offsetb + (col + 0) + (i + 1) * ldb];
+ float b11 = b[offsetb + (col + 1) + (i + 1) * ldb];
+ float b12 = b[offsetb + (col + 2) + (i + 1) * ldb];
+ sum00 = a01 * b10 + sum00;
+ sum01 = a01 * b11 + sum01;
+ sum02 = a01 * b12 + sum02;
+ sum10 = a11 * b10 + sum10;
+ sum11 = a11 * b11 + sum11;
+ sum12 = a11 * b12 + sum12;
+ sum20 = a21 * b10 + sum20;
+ sum21 = a21 * b11 + sum21;
+ sum22 = a21 * b12 + sum22;
+ }
+ for (; i < k; i += 1) {
+ float a00 = a[offseta + (row + 0) + (i + 0) * lda];
+ float a10 = a[offseta + (row + 1) + (i + 0) * lda];
+ float a20 = a[offseta + (row + 2) + (i + 0) * lda];
+ float b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
+ float b01 = b[offsetb + (col + 1) + (i + 0) * ldb];
+ float b02 = b[offsetb + (col + 2) + (i + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum01 = a00 * b01 + sum01;
+ sum02 = a00 * b02 + sum02;
+ sum10 = a10 * b00 + sum10;
+ sum11 = a10 * b01 + sum11;
+ sum12 = a10 * b02 + sum12;
+ sum20 = a20 * b00 + sum20;
+ sum21 = a20 * b01 + sum21;
+ sum22 = a20 * b02 + sum22;
+ }
+ if (beta != 0.0f) {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
+ c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + beta * c[offsetc + (row + 1) + (col + 1) * ldc];
+ c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + beta * c[offsetc + (row + 1) + (col + 2) * ldc];
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
+ c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + beta * c[offsetc + (row + 2) + (col + 1) * ldc];
+ c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + beta * c[offsetc + (row + 2) + (col + 2) * ldc];
+ } else {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
+ c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11;
+ c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12;
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
+ c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21;
+ c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22;
+ }
+ }
+ for (; row < m; row += 1) {
+ int i = 0;
+ float sum00 = 0.0f;
+ float sum01 = 0.0f;
+ float sum02 = 0.0f;
+ for (; i < loopBound(k, Ti); i += Ti) {
+ float a00 = a[offseta + (row + 0) + (i + 0) * lda];
+ float b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
+ float b01 = b[offsetb + (col + 1) + (i + 0) * ldb];
+ float b02 = b[offsetb + (col + 2) + (i + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum01 = a00 * b01 + sum01;
+ sum02 = a00 * b02 + sum02;
+ float a01 = a[offseta + (row + 0) + (i + 1) * lda];
+ float b10 = b[offsetb + (col + 0) + (i + 1) * ldb];
+ float b11 = b[offsetb + (col + 1) + (i + 1) * ldb];
+ float b12 = b[offsetb + (col + 2) + (i + 1) * ldb];
+ sum00 = a01 * b10 + sum00;
+ sum01 = a01 * b11 + sum01;
+ sum02 = a01 * b12 + sum02;
+ }
+ for (; i < k; i += 1) {
+ float a00 = a[offseta + (row + 0) + (i + 0) * lda];
+ float b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
+ float b01 = b[offsetb + (col + 1) + (i + 0) * ldb];
+ float b02 = b[offsetb + (col + 2) + (i + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum01 = a00 * b01 + sum01;
+ sum02 = a00 * b02 + sum02;
+ }
+ if (beta != 0.0f) {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
+ } else {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
+ }
+ }
+ }
+ for (; col < n; col += 1) {
+ int row = 0;
+ for (; row < loopBound(m, Trow); row += Trow) {
+ int i = 0;
+ float sum00 = 0.0f;
+ float sum10 = 0.0f;
+ float sum20 = 0.0f;
+ for (; i < loopBound(k, Ti); i += Ti) {
+ float a00 = a[offseta + (row + 0) + (i + 0) * lda];
+ float a10 = a[offseta + (row + 1) + (i + 0) * lda];
+ float a20 = a[offseta + (row + 2) + (i + 0) * lda];
+ float b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum10 = a10 * b00 + sum10;
+ sum20 = a20 * b00 + sum20;
+ float a01 = a[offseta + (row + 0) + (i + 1) * lda];
+ float a11 = a[offseta + (row + 1) + (i + 1) * lda];
+ float a21 = a[offseta + (row + 2) + (i + 1) * lda];
+ float b10 = b[offsetb + (col + 0) + (i + 1) * ldb];
+ sum00 = a01 * b10 + sum00;
+ sum10 = a11 * b10 + sum10;
+ sum20 = a21 * b10 + sum20;
+ }
+ for (; i < k; i += 1) {
+ float a00 = a[offseta + (row + 0) + (i + 0) * lda];
+ float a10 = a[offseta + (row + 1) + (i + 0) * lda];
+ float a20 = a[offseta + (row + 2) + (i + 0) * lda];
+ float b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum10 = a10 * b00 + sum10;
+ sum20 = a20 * b00 + sum20;
+ }
+ if (beta != 0.0f) {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
+ } else {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
+ }
+ }
+ for (; row < m; row += 1) {
+ int i = 0;
+ float sum00 = 0.0f;
+ for (; i < loopBound(k, Ti); i += Ti) {
+ float a00 = a[offseta + (row + 0) + (i + 0) * lda];
+ float b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ float a01 = a[offseta + (row + 0) + (i + 1) * lda];
+ float b10 = b[offsetb + (col + 0) + (i + 1) * ldb];
+ sum00 = a01 * b10 + sum00;
+ }
+ for (; i < k; i += 1) {
+ float a00 = a[offseta + (row + 0) + (i + 0) * lda];
+ float b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ }
+ if (beta != 0.0f) {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
+ } else {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
+ }
+ }
+ }
+ }
+
+ protected void sgemmTN(int m, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
+ final int Trow = 3, Tcol = 3, Ti = 2;
+
+ int col = 0;
+ for (; col < loopBound(n, Tcol); col += Tcol) {
+ int row = 0;
+ for (; row < loopBound(m, Trow); row += Trow) {
+ int i = 0;
+ float sum00 = 0.0f;
+ float sum01 = 0.0f;
+ float sum02 = 0.0f;
+ float sum10 = 0.0f;
+ float sum11 = 0.0f;
+ float sum12 = 0.0f;
+ float sum20 = 0.0f;
+ float sum21 = 0.0f;
+ float sum22 = 0.0f;
+ for (; i < loopBound(k, Ti); i += Ti) {
+ float a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ float a10 = a[offseta + (i + 0) + (row + 1) * lda];
+ float a20 = a[offseta + (i + 0) + (row + 2) * lda];
+ float b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
+ float b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
+ float b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum01 = a00 * b01 + sum01;
+ sum02 = a00 * b02 + sum02;
+ sum10 = a10 * b00 + sum10;
+ sum11 = a10 * b01 + sum11;
+ sum12 = a10 * b02 + sum12;
+ sum20 = a20 * b00 + sum20;
+ sum21 = a20 * b01 + sum21;
+ sum22 = a20 * b02 + sum22;
+ float a01 = a[offseta + (i + 1) + (row + 0) * lda];
+ float a11 = a[offseta + (i + 1) + (row + 1) * lda];
+ float a21 = a[offseta + (i + 1) + (row + 2) * lda];
+ float b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
+ float b11 = b[offsetb + (i + 1) + (col + 1) * ldb];
+ float b12 = b[offsetb + (i + 1) + (col + 2) * ldb];
+ sum00 = a01 * b10 + sum00;
+ sum01 = a01 * b11 + sum01;
+ sum02 = a01 * b12 + sum02;
+ sum10 = a11 * b10 + sum10;
+ sum11 = a11 * b11 + sum11;
+ sum12 = a11 * b12 + sum12;
+ sum20 = a21 * b10 + sum20;
+ sum21 = a21 * b11 + sum21;
+ sum22 = a21 * b12 + sum22;
+ }
+ for (; i < k; i += 1) {
+ float a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ float a10 = a[offseta + (i + 0) + (row + 1) * lda];
+ float a20 = a[offseta + (i + 0) + (row + 2) * lda];
+ float b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
+ float b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
+ float b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum01 = a00 * b01 + sum01;
+ sum02 = a00 * b02 + sum02;
+ sum10 = a10 * b00 + sum10;
+ sum11 = a10 * b01 + sum11;
+ sum12 = a10 * b02 + sum12;
+ sum20 = a20 * b00 + sum20;
+ sum21 = a20 * b01 + sum21;
+ sum22 = a20 * b02 + sum22;
+ }
+ if (beta != 0.0f) {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
+ c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + beta * c[offsetc + (row + 1) + (col + 1) * ldc];
+ c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + beta * c[offsetc + (row + 1) + (col + 2) * ldc];
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
+ c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + beta * c[offsetc + (row + 2) + (col + 1) * ldc];
+ c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + beta * c[offsetc + (row + 2) + (col + 2) * ldc];
+ } else {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
+ c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11;
+ c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12;
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
+ c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21;
+ c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22;
+ }
+ }
+ for (; row < m; row += 1) {
+ int i = 0;
+ float sum00 = 0.0f;
+ float sum01 = 0.0f;
+ float sum02 = 0.0f;
+ for (; i < loopBound(k, Ti); i += Ti) {
+ float a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ float b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
+ float b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
+ float b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum01 = a00 * b01 + sum01;
+ sum02 = a00 * b02 + sum02;
+ float a01 = a[offseta + (i + 1) + (row + 0) * lda];
+ float b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
+ float b11 = b[offsetb + (i + 1) + (col + 1) * ldb];
+ float b12 = b[offsetb + (i + 1) + (col + 2) * ldb];
+ sum00 = a01 * b10 + sum00;
+ sum01 = a01 * b11 + sum01;
+ sum02 = a01 * b12 + sum02;
+ }
+ for (; i < k; i += 1) {
+ float a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ float b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
+ float b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
+ float b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum01 = a00 * b01 + sum01;
+ sum02 = a00 * b02 + sum02;
+ }
+ if (beta != 0.0f) {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
+ } else {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
+ }
+ }
+ }
+ for (; col < n; col += 1) {
+ int row = 0;
+ for (; row < loopBound(m, Trow); row += Trow) {
+ int i = 0;
+ float sum00 = 0.0f;
+ float sum10 = 0.0f;
+ float sum20 = 0.0f;
+ for (; i < loopBound(k, Ti); i += Ti) {
+ float a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ float a10 = a[offseta + (i + 0) + (row + 1) * lda];
+ float a20 = a[offseta + (i + 0) + (row + 2) * lda];
+ float b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum10 = a10 * b00 + sum10;
+ sum20 = a20 * b00 + sum20;
+ float a01 = a[offseta + (i + 1) + (row + 0) * lda];
+ float a11 = a[offseta + (i + 1) + (row + 1) * lda];
+ float a21 = a[offseta + (i + 1) + (row + 2) * lda];
+ float b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
+ sum00 = a01 * b10 + sum00;
+ sum10 = a11 * b10 + sum10;
+ sum20 = a21 * b10 + sum20;
+ }
+ for (; i < k; i += 1) {
+ float a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ float a10 = a[offseta + (i + 0) + (row + 1) * lda];
+ float a20 = a[offseta + (i + 0) + (row + 2) * lda];
+ float b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum10 = a10 * b00 + sum10;
+ sum20 = a20 * b00 + sum20;
+ }
+ if (beta != 0.0f) {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
+ } else {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
+ }
+ }
+ for (; row < m; row += 1) {
+ int i = 0;
+ float sum00 = 0.0f;
+ for (; i < loopBound(k, Ti); i += Ti) {
+ float a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ float b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ float a01 = a[offseta + (i + 1) + (row + 0) * lda];
+ float b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
+ sum00 = a01 * b10 + sum00;
+ }
+ for (; i < k; i += 1) {
+ float a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ float b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ }
+ if (beta != 0.0f) {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
+ } else {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
+ }
+ }
+ }
+ }
+
+ protected void sgemmTT(int m, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
+ final int Trow = 3, Tcol = 3, Ti = 2;
+
+ int col = 0;
+ for (; col < loopBound(n, Tcol); col += Tcol) {
+ int row = 0;
+ for (; row < loopBound(m, Trow); row += Trow) {
+ int i = 0;
+ float sum00 = 0.0f;
+ float sum01 = 0.0f;
+ float sum02 = 0.0f;
+ float sum10 = 0.0f;
+ float sum11 = 0.0f;
+ float sum12 = 0.0f;
+ float sum20 = 0.0f;
+ float sum21 = 0.0f;
+ float sum22 = 0.0f;
+ for (; i < loopBound(k, Ti); i += Ti) {
+ float a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ float a10 = a[offseta + (i + 0) + (row + 1) * lda];
+ float a20 = a[offseta + (i + 0) + (row + 2) * lda];
+ float b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
+ float b01 = b[offsetb + (col + 1) + (i + 0) * ldb];
+ float b02 = b[offsetb + (col + 2) + (i + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum01 = a00 * b01 + sum01;
+ sum02 = a00 * b02 + sum02;
+ sum10 = a10 * b00 + sum10;
+ sum11 = a10 * b01 + sum11;
+ sum12 = a10 * b02 + sum12;
+ sum20 = a20 * b00 + sum20;
+ sum21 = a20 * b01 + sum21;
+ sum22 = a20 * b02 + sum22;
+ float a01 = a[offseta + (i + 1) + (row + 0) * lda];
+ float a11 = a[offseta + (i + 1) + (row + 1) * lda];
+ float a21 = a[offseta + (i + 1) + (row + 2) * lda];
+ float b10 = b[offsetb + (col + 0) + (i + 1) * ldb];
+ float b11 = b[offsetb + (col + 1) + (i + 1) * ldb];
+ float b12 = b[offsetb + (col + 2) + (i + 1) * ldb];
+ sum00 = a01 * b10 + sum00;
+ sum01 = a01 * b11 + sum01;
+ sum02 = a01 * b12 + sum02;
+ sum10 = a11 * b10 + sum10;
+ sum11 = a11 * b11 + sum11;
+ sum12 = a11 * b12 + sum12;
+ sum20 = a21 * b10 + sum20;
+ sum21 = a21 * b11 + sum21;
+ sum22 = a21 * b12 + sum22;
+ }
+ for (; i < k; i += 1) {
+ float a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ float a10 = a[offseta + (i + 0) + (row + 1) * lda];
+ float a20 = a[offseta + (i + 0) + (row + 2) * lda];
+ float b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
+ float b01 = b[offsetb + (col + 1) + (i + 0) * ldb];
+ float b02 = b[offsetb + (col + 2) + (i + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum01 = a00 * b01 + sum01;
+ sum02 = a00 * b02 + sum02;
+ sum10 = a10 * b00 + sum10;
+ sum11 = a10 * b01 + sum11;
+ sum12 = a10 * b02 + sum12;
+ sum20 = a20 * b00 + sum20;
+ sum21 = a20 * b01 + sum21;
+ sum22 = a20 * b02 + sum22;
+ }
+ if (beta != 0.0f) {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
+ c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + beta * c[offsetc + (row + 1) + (col + 1) * ldc];
+ c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + beta * c[offsetc + (row + 1) + (col + 2) * ldc];
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
+ c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + beta * c[offsetc + (row + 2) + (col + 1) * ldc];
+ c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + beta * c[offsetc + (row + 2) + (col + 2) * ldc];
+ } else {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
+ c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11;
+ c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12;
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
+ c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21;
+ c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22;
+ }
+ }
+ for (; row < m; row += 1) {
+ int i = 0;
+ float sum00 = 0.0f;
+ float sum01 = 0.0f;
+ float sum02 = 0.0f;
+ for (; i < loopBound(k, Ti); i += Ti) {
+ float a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ float b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
+ float b01 = b[offsetb + (col + 1) + (i + 0) * ldb];
+ float b02 = b[offsetb + (col + 2) + (i + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum01 = a00 * b01 + sum01;
+ sum02 = a00 * b02 + sum02;
+ float a01 = a[offseta + (i + 1) + (row + 0) * lda];
+ float b10 = b[offsetb + (col + 0) + (i + 1) * ldb];
+ float b11 = b[offsetb + (col + 1) + (i + 1) * ldb];
+ float b12 = b[offsetb + (col + 2) + (i + 1) * ldb];
+ sum00 = a01 * b10 + sum00;
+ sum01 = a01 * b11 + sum01;
+ sum02 = a01 * b12 + sum02;
+ }
+ for (; i < k; i += 1) {
+ float a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ float b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
+ float b01 = b[offsetb + (col + 1) + (i + 0) * ldb];
+ float b02 = b[offsetb + (col + 2) + (i + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum01 = a00 * b01 + sum01;
+ sum02 = a00 * b02 + sum02;
+ }
+ if (beta != 0.0f) {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
+ } else {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
+ }
+ }
+ }
+ for (; col < n; col += 1) {
+ int row = 0;
+ for (; row < loopBound(m, Trow); row += Trow) {
+ int i = 0;
+ float sum00 = 0.0f;
+ float sum10 = 0.0f;
+ float sum20 = 0.0f;
+ for (; i < loopBound(k, Ti); i += Ti) {
+ float a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ float a10 = a[offseta + (i + 0) + (row + 1) * lda];
+ float a20 = a[offseta + (i + 0) + (row + 2) * lda];
+ float b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum10 = a10 * b00 + sum10;
+ sum20 = a20 * b00 + sum20;
+ float a01 = a[offseta + (i + 1) + (row + 0) * lda];
+ float a11 = a[offseta + (i + 1) + (row + 1) * lda];
+ float a21 = a[offseta + (i + 1) + (row + 2) * lda];
+ float b10 = b[offsetb + (col + 0) + (i + 1) * ldb];
+ sum00 = a01 * b10 + sum00;
+ sum10 = a11 * b10 + sum10;
+ sum20 = a21 * b10 + sum20;
+ }
+ for (; i < k; i += 1) {
+ float a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ float a10 = a[offseta + (i + 0) + (row + 1) * lda];
+ float a20 = a[offseta + (i + 0) + (row + 2) * lda];
+ float b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ sum10 = a10 * b00 + sum10;
+ sum20 = a20 * b00 + sum20;
+ }
+ if (beta != 0.0f) {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
+ } else {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
+ }
+ }
+ for (; row < m; row += 1) {
+ int i = 0;
+ float sum00 = 0.0f;
+ for (; i < loopBound(k, Ti); i += Ti) {
+ float a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ float b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ float a01 = a[offseta + (i + 1) + (row + 0) * lda];
+ float b10 = b[offsetb + (col + 0) + (i + 1) * ldb];
+ sum00 = a01 * b10 + sum00;
+ }
+ for (; i < k; i += 1) {
+ float a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ float b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
+ sum00 = a00 * b00 + sum00;
+ }
+ if (beta != 0.0f) {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
+ } else {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
+ }
+ }
+ }
+ }
+
+ protected void dgemvK(String trans, int m, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
+ if (alpha == 0.0) {
+ int len = lsame("N", trans) ? m : n;
+ for (int i = 0, iy = incy < 0 ? (len - 1) * -incy : 0; i < len; i += 1, iy += incy) {
+ if (beta != 0.0) {
+ y[offsety + iy] = beta * y[offsety + iy];
+ } else {
+ y[offsety + iy] = 0.0;
+ }
+ }
+ } else if (lsame("N", trans)) {
+ dgemvN(m, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
+ } else if (lsame("T", trans) || lsame("C", trans)) {
+ dgemvT(m, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
+ }
+ }
+
+ protected void dgemvN(int m, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
+ if (beta != 1.0) {
+ int row = 0, iy = incy < 0 ? (m - 1) * -incy : 0;
+ for (; row < m; row += 1, iy += incy) {
+ if (beta != 0.0) {
+ y[offsety + iy] = beta * y[offsety + iy];
+ } else {
+ y[offsety + iy] = 0.0;
+ }
+ }
+ }
+ int col = 0, ix = incx < 0 ? (n - 1) * -incx : 0;
+ for (; col < loopBound(n, 4); col += 4, ix += incx * 4) {
+ int row = 0, iy = incy < 0 ? (m - 1) * -incy : 0;
+ double alphax0 = alpha * x[offsetx + ix + incx * 0];
+ double alphax1 = alpha * x[offsetx + ix + incx * 1];
+ double alphax2 = alpha * x[offsetx + ix + incx * 2];
+ double alphax3 = alpha * x[offsetx + ix + incx * 3];
+ for (; row < m; row += 1, iy += incy) {
+ y[offsety + iy] += alphax0 * a[offseta + row + (col + 0) * lda]
+ + alphax1 * a[offseta + row + (col + 1) * lda]
+ + alphax2 * a[offseta + row + (col + 2) * lda]
+ + alphax3 * a[offseta + row + (col + 3) * lda];
+ }
+ }
+ for (; col < n; col += 1, ix += incx) {
+ int row = 0, iy = incy < 0 ? (m - 1) * -incy : 0;
+ double alphax = alpha * x[offsetx + ix];
+ for (; row < m; row += 1, iy += incy) {
+ y[offsety + iy] += alphax * a[offseta + row + col * lda];
+ }
+ }
+ }
+
+ protected void dgemvT(int m, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
+ int col = 0, iy = incy < 0 ? (n - 1) * -incy : 0;
+ for (; col < loopBound(n, 4); col += 4, iy += incy * 4) {
+ int row = 0, ix = incx < 0 ? (m - 1) * -incx : 0;
+ double sum0 = 0.0;
+ double sum1 = 0.0;
+ double sum2 = 0.0;
+ double sum3 = 0.0;
+ for (; row < m; row += 1, ix += incx) {
+ double xix = x[offsetx + ix];
+ sum0 += xix * a[offseta + row + (col + 0) * lda];
+ sum1 += xix * a[offseta + row + (col + 1) * lda];
+ sum2 += xix * a[offseta + row + (col + 2) * lda];
+ sum3 += xix * a[offseta + row + (col + 3) * lda];
+ }
+ if (beta != 0.0) {
+ y[offsety + iy + incy * 0] = alpha * sum0 + beta * y[offsety + iy + incy * 0];
+ y[offsety + iy + incy * 1] = alpha * sum1 + beta * y[offsety + iy + incy * 1];
+ y[offsety + iy + incy * 2] = alpha * sum2 + beta * y[offsety + iy + incy * 2];
+ y[offsety + iy + incy * 3] = alpha * sum3 + beta * y[offsety + iy + incy * 3];
+ } else {
+ y[offsety + iy + incy * 0] = alpha * sum0;
+ y[offsety + iy + incy * 1] = alpha * sum1;
+ y[offsety + iy + incy * 2] = alpha * sum2;
+ y[offsety + iy + incy * 3] = alpha * sum3;
+ }
+ }
+ for (; col < n; col += 1, iy += incy) {
+ int row = 0, ix = incx < 0 ? (m - 1) * -incx : 0;
+ double sum = 0.0;
+ for (; row < m; row += 1, ix += incx) {
+ sum += x[offsetx + ix] * a[offseta + row + col * lda];
+ }
+ if (beta != 0.0) {
+ y[offsety + iy] = alpha * sum + beta * y[offsety + iy];
+ } else {
+ y[offsety + iy] = alpha * sum;
+ }
+ }
+ }
+
+ protected void sgemvK(String trans, int m, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
+ if (alpha == 0.0f) {
+ int len = lsame("N", trans) ? m : n;
+ for (int i = 0, iy = incy < 0 ? (len - 1) * -incy : 0; i < len; i += 1, iy += incy) {
+ if (beta != 0.0f) {
+ y[offsety + iy] = beta * y[offsety + iy];
+ } else {
+ y[offsety + iy] = 0.0f;
+ }
+ }
+ } else if (lsame("N", trans)) {
+ sgemvN(m, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
+ } else if (lsame("T", trans) || lsame("C", trans)) {
+ sgemvT(m, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
+ }
+ }
+
+ protected void sgemvN(int m, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
+ // y = beta * y
+ for (int row = 0, iy = incy < 0 ? (m - 1) * -incy : 0; row < m; row += 1, iy += incy) {
+ if (beta != 0.0f) {
+ y[offsety + iy] = beta * y[offsety + iy];
+ } else {
+ y[offsety + iy] = 0.0f;
+ }
+ }
+ // y += alpha * A * x
+ int col = 0, ix = incx < 0 ? (n - 1) * -incx : 0;
+ for (; col < loopBound(n, 8); col += 8, ix += incx * 8) {
+ float alphax0 = alpha * x[offsetx + ix + incx * 0];
+ float alphax1 = alpha * x[offsetx + ix + incx * 1];
+ float alphax2 = alpha * x[offsetx + ix + incx * 2];
+ float alphax3 = alpha * x[offsetx + ix + incx * 3];
+ float alphax4 = alpha * x[offsetx + ix + incx * 4];
+ float alphax5 = alpha * x[offsetx + ix + incx * 5];
+ float alphax6 = alpha * x[offsetx + ix + incx * 6];
+ float alphax7 = alpha * x[offsetx + ix + incx * 7];
+ for (int row = 0, iy = incy < 0 ? (m - 1) * -incy : 0; row < m; row += 1, iy += incy) {
+ y[offsety + iy] += alphax0 * a[offseta + row + (col + 0) * lda]
+ + alphax1 * a[offseta + row + (col + 1) * lda]
+ + alphax2 * a[offseta + row + (col + 2) * lda]
+ + alphax3 * a[offseta + row + (col + 3) * lda]
+ + alphax4 * a[offseta + row + (col + 4) * lda]
+ + alphax5 * a[offseta + row + (col + 5) * lda]
+ + alphax6 * a[offseta + row + (col + 6) * lda]
+ + alphax7 * a[offseta + row + (col + 7) * lda];
+ }
+ }
+ for (; col < n; col += 1, ix += incx) {
+ float alphax = alpha * x[offsetx + ix];
+ for (int row = 0, iy = incy < 0 ? (m - 1) * -incy : 0; row < m; row += 1, iy += incy) {
+ y[offsety + iy] += alphax * a[offseta + row + col * lda];
+ }
+ }
+ }
+
+ protected void sgemvT(int m, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
+ int col = 0, iy = incy < 0 ? (n - 1) * -incy : 0;
+ for (; col < loopBound(n, 8); col += 8, iy += incy * 8) {
+ float sum0 = 0.0f;
+ float sum1 = 0.0f;
+ float sum2 = 0.0f;
+ float sum3 = 0.0f;
+ float sum4 = 0.0f;
+ float sum5 = 0.0f;
+ float sum6 = 0.0f;
+ float sum7 = 0.0f;
+ for (int row = 0, ix = incx < 0 ? (m - 1) * -incx : 0; row < m; row += 1, ix += incx) {
+ sum0 += x[offsetx + ix] * a[offseta + row + (col + 0) * lda];
+ sum1 += x[offsetx + ix] * a[offseta + row + (col + 1) * lda];
+ sum2 += x[offsetx + ix] * a[offseta + row + (col + 2) * lda];
+ sum3 += x[offsetx + ix] * a[offseta + row + (col + 3) * lda];
+ sum4 += x[offsetx + ix] * a[offseta + row + (col + 4) * lda];
+ sum5 += x[offsetx + ix] * a[offseta + row + (col + 5) * lda];
+ sum6 += x[offsetx + ix] * a[offseta + row + (col + 6) * lda];
+ sum7 += x[offsetx + ix] * a[offseta + row + (col + 7) * lda];
+ }
+ if (beta != 0.0f) {
+ y[offsety + iy + incy * 0] = alpha * sum0 + beta * y[offsety + iy + incy * 0];
+ y[offsety + iy + incy * 1] = alpha * sum1 + beta * y[offsety + iy + incy * 1];
+ y[offsety + iy + incy * 2] = alpha * sum2 + beta * y[offsety + iy + incy * 2];
+ y[offsety + iy + incy * 3] = alpha * sum3 + beta * y[offsety + iy + incy * 3];
+ y[offsety + iy + incy * 4] = alpha * sum4 + beta * y[offsety + iy + incy * 4];
+ y[offsety + iy + incy * 5] = alpha * sum5 + beta * y[offsety + iy + incy * 5];
+ y[offsety + iy + incy * 6] = alpha * sum6 + beta * y[offsety + iy + incy * 6];
+ y[offsety + iy + incy * 7] = alpha * sum7 + beta * y[offsety + iy + incy * 7];
+ } else {
+ y[offsety + iy + incy * 0] = alpha * sum0;
+ y[offsety + iy + incy * 1] = alpha * sum1;
+ y[offsety + iy + incy * 2] = alpha * sum2;
+ y[offsety + iy + incy * 3] = alpha * sum3;
+ y[offsety + iy + incy * 4] = alpha * sum4;
+ y[offsety + iy + incy * 5] = alpha * sum5;
+ y[offsety + iy + incy * 6] = alpha * sum6;
+ y[offsety + iy + incy * 7] = alpha * sum7;
+ }
+ }
+ for (; col < n; col += 1, iy += incy) {
+ float sum = 0.0f;
+ for (int row = 0, ix = incx < 0 ? (m - 1) * -incx : 0; row < m; row += 1, ix += incx) {
+ sum += x[offsetx + ix] * a[offseta + row + col * lda];
+ }
+ if (beta != 0.0f) {
+ y[offsety + iy] = alpha * sum + beta * y[offsety + iy];
+ } else {
+ y[offsety + iy] = alpha * sum;
+ }
+ }
+ }
+
+ protected void dgerK(int m, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta, int lda) {
+ int col = 0, iy = incy < 0 ? (n - 1) * -incy : 0;
+ for (; col < loopBound(n, 4); col += 4, iy += incy * 4) {
+ double alphayiy0 = alpha * y[offsety + iy + incy * 0];
+ double alphayiy1 = alpha * y[offsety + iy + incy * 1];
+ double alphayiy2 = alpha * y[offsety + iy + incy * 2];
+ double alphayiy3 = alpha * y[offsety + iy + incy * 3];
+ int row = 0, jx = incx < 0 ? (n - 1) * -incx : 0;
+ for (; row < m; row += 1, jx += incx) {
+ double xjx = x[offsetx + jx];
+ a[offseta + row + (col + 0) * lda] += alphayiy0 * xjx;
+ a[offseta + row + (col + 1) * lda] += alphayiy1 * xjx;
+ a[offseta + row + (col + 2) * lda] += alphayiy2 * xjx;
+ a[offseta + row + (col + 3) * lda] += alphayiy3 * xjx;
+ }
+ }
+ for (; col < n; col += 1, iy += incy) {
+ double alphayiy = alpha * y[offsety + iy];
+ int row = 0, jx = incx < 0 ? (n - 1) * -incx : 0;
+ for (; row < m; row += 1, jx += incx) {
+ a[offseta + row + col * lda] += alphayiy * x[offsetx + jx];
+ }
+ }
+ }
+
+ protected void sgerK(int m, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta, int lda) {
+ int col = 0, iy = incy < 0 ? (n - 1) * -incy : 0;
+ for (; col < loopBound(n, 4); col += 4, iy += incy * 4) {
+ float alphayiy0 = alpha * y[offsety + iy + incy * 0];
+ float alphayiy1 = alpha * y[offsety + iy + incy * 1];
+ float alphayiy2 = alpha * y[offsety + iy + incy * 2];
+ float alphayiy3 = alpha * y[offsety + iy + incy * 3];
+ int row = 0, jx = incx < 0 ? (n - 1) * -incx : 0;
+ for (; row < m; row += 1, jx += incx) {
+ float xjx = x[offsetx + jx];
+ a[offseta + row + (col + 0) * lda] += alphayiy0 * xjx;
+ a[offseta + row + (col + 1) * lda] += alphayiy1 * xjx;
+ a[offseta + row + (col + 2) * lda] += alphayiy2 * xjx;
+ a[offseta + row + (col + 3) * lda] += alphayiy3 * xjx;
+ }
+ }
+ for (; col < n; col += 1, iy += incy) {
+ float alphayiy = alpha * y[offsety + iy];
+ int row = 0, jx = incx < 0 ? (n - 1) * -incx : 0;
+ for (; row < m; row += 1, jx += incx) {
+ a[offseta + row + col * lda] += alphayiy * x[offsetx + jx];
+ }
+ }
+ }
+
+ protected double dnrm2K(int n, double[] x, int offsetx, int incx) {
+ int ix = 0;
+ double sum0 = 0.0;
+ double sum1 = 0.0;
+ double sum2 = 0.0;
+ double sum3 = 0.0;
+ if (incx == 1) {
+ for (; ix < loopBound(n, 4); ix += 4) {
+ double x0 = x[offsetx + ix + 0];
+ double x1 = x[offsetx + ix + 1];
+ double x2 = x[offsetx + ix + 2];
+ double x3 = x[offsetx + ix + 3];
+ sum0 += x0 * x0;
+ sum1 += x1 * x1;
+ sum2 += x2 * x2;
+ sum3 += x3 * x3;
+ }
+ } else {
+ for (; ix < loopBound(n, 4) * incx; ix += 4 * incx) {
+ double x0 = x[offsetx + ix + (0 * incx)];
+ double x1 = x[offsetx + ix + (1 * incx)];
+ double x2 = x[offsetx + ix + (2 * incx)];
+ double x3 = x[offsetx + ix + (3 * incx)];
+ sum0 += x0 * x0;
+ sum1 += x1 * x1;
+ sum2 += x2 * x2;
+ sum3 += x3 * x3;
+ }
+ }
+ double sum = sum0 + sum1 + sum2 + sum3;
+ for (; ix < n * incx; ix += incx) {
+ double x0 = x[offsetx + ix + 0];
+ sum += x0 * x0;
+ }
+ return Math.sqrt(sum);
+ }
+
+ protected float snrm2K(int n, float[] x, int offsetx, int incx) {
+ int ix = 0;
+ float sum0 = 0.0f;
+ float sum1 = 0.0f;
+ float sum2 = 0.0f;
+ float sum3 = 0.0f;
+ if (incx == 1) {
+ for (; ix < loopBound(n, 4); ix += 4) {
+ float x0 = x[offsetx + ix + 0];
+ float x1 = x[offsetx + ix + 1];
+ float x2 = x[offsetx + ix + 2];
+ float x3 = x[offsetx + ix + 3];
+ sum0 += x0 * x0;
+ sum1 += x1 * x1;
+ sum2 += x2 * x2;
+ sum3 += x3 * x3;
+ }
+ } else {
+ for (; ix < loopBound(n, 4) * incx; ix += 4 * incx) {
+ float x0 = x[offsetx + ix + (0 * incx)];
+ float x1 = x[offsetx + ix + (1 * incx)];
+ float x2 = x[offsetx + ix + (2 * incx)];
+ float x3 = x[offsetx + ix + (3 * incx)];
+ sum0 += x0 * x0;
+ sum1 += x1 * x1;
+ sum2 += x2 * x2;
+ sum3 += x3 * x3;
+ }
+ }
+ float sum = sum0 + sum1 + sum2 + sum3;
+ for (; ix < n * incx; ix += incx) {
+ float x0 = x[offsetx + ix + 0];
+ sum += x0 * x0;
+ }
+ return (float)Math.sqrt(sum);
+ }
+
+ protected void drotK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double c, double s) {
+ if (incx == 1 && incy == 1) {
+ for (int ix = 0, iy = 0; ix < n && iy < n; ix += 1, iy += 1) {
+ double x0 = x[offsetx + ix];
+ double y0 = y[offsety + iy];
+ x[offsetx + ix] = c * x0 + s * y0;
+ y[offsety + iy] = c * y0 - s * x0;
+ }
+ } else {
+ for (int ix = incx < 0 ? (n - 1) * -incx : 0,
+ iy = incy < 0 ? (n - 1) * -incy : 0;
+ (incx < 0 ? ix >= 0 : ix < n * incx)
+ && (incy < 0 ? iy >= 0 : iy < n * incy);
+ ix += incx, iy += incy) {
+ double x0 = x[offsetx + ix];
+ double y0 = y[offsety + iy];
+ x[offsetx + ix] = c * x0 + s * y0;
+ y[offsety + iy] = c * y0 - s * x0;
+ }
+ }
+ }
+
+ protected void srotK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float c, float s) {
+ if (incx == 1 && incy == 1) {
+ for (int ix = 0, iy = 0; ix < n && iy < n; ix += 1, iy += 1) {
+ float x0 = x[offsetx + ix];
+ float y0 = y[offsety + iy];
+ x[offsetx + ix] = c * x0 + s * y0;
+ y[offsety + iy] = c * y0 - s * x0;
+ }
+ } else {
+ for (int ix = incx < 0 ? (n - 1) * -incx : 0,
+ iy = incy < 0 ? (n - 1) * -incy : 0;
+ (incx < 0 ? ix >= 0 : ix < n * incx)
+ && (incy < 0 ? iy >= 0 : iy < n * incy);
+ ix += incx, iy += incy) {
+ float x0 = x[offsetx + ix];
+ float y0 = y[offsety + iy];
+ x[offsetx + ix] = c * x0 + s * y0;
+ y[offsety + iy] = c * y0 - s * x0;
+ }
+ }
+ }
+
+ protected void drotmK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] param, int offsetparam) {
+ org.netlib.blas.Drotm.drotm(n, x, offsetx, incx, y, offsety, incy, param, offsetparam);
+ }
+
+ protected void srotmK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] param, int offsetparam) {
+ org.netlib.blas.Srotm.srotm(n, x, offsetx, incx, y, offsety, incy, param, offsetparam);
+ }
+
+ protected void drotmgK(org.netlib.util.doubleW dd1, org.netlib.util.doubleW dd2, org.netlib.util.doubleW dx1, double dy1, double[] param, int offsetparam) {
+ org.netlib.blas.Drotmg.drotmg(dd1, dd2, dx1, dy1, param, offsetparam);
+ }
+
+ protected void srotmgK(org.netlib.util.floatW sd1, org.netlib.util.floatW sd2, org.netlib.util.floatW sx1, float sy1, float[] param, int offsetparam) {
+ org.netlib.blas.Srotmg.srotmg(sd1, sd2, sx1, sy1, param, offsetparam);
+ }
+
+ protected void dsbmvK(String uplo, int n, int k, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
+ org.netlib.blas.Dsbmv.dsbmv(uplo, n, k, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
+ }
+
+ protected void ssbmvK(String uplo, int n, int k, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
+ org.netlib.blas.Ssbmv.ssbmv(uplo, n, k, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
+ }
+
+ protected void dscalK(int n, double alpha, double[] x, int offsetx, int incx) {
+ if (incx == 1) {
+ for (int ix = 0; ix < n; ix += 1) {
+ x[offsetx + ix] *= alpha;
+ }
+ } else {
+ for (int ix = incx < 0 ? (n - 1) * -incx : 0; incx < 0 ? ix >= 0 : ix < n * incx; ix += incx) {
+ x[offsetx + ix] *= alpha;
+ }
+ }
+ }
+
+ protected void sscalK(int n, float alpha, float[] x, int offsetx, int incx) {
+ if (incx == 1) {
+ for (int ix = 0; ix < n; ix += 1) {
+ x[offsetx + ix] *= alpha;
+ }
+ } else {
+ for (int ix = incx < 0 ? (n - 1) * -incx : 0; incx < 0 ? ix >= 0 : ix < n * incx; ix += incx) {
+ x[offsetx + ix] *= alpha;
+ }
+ }
+ }
+
+ protected void dspmvK(String uplo, int n, double alpha, double[] a, int offseta, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
+ if (alpha == 0.0) {
+ for (int i = 0, iy = incy < 0 ? (n - 1) * -incy : 0; i < n; i += 1, iy += incy) {
+ if (beta != 0.0) {
+ y[offsety + iy] = beta * y[offsety + iy];
+ } else {
+ y[offsety + iy] = 0.0;
+ }
+ }
+ } else if (lsame("U", uplo)) {
+ dspmvU(n, alpha, a, offseta, x, offsetx, incx, beta, y, offsety, incy);
+ } else if (lsame("L", uplo)) {
+ dspmvL(n, alpha, a, offseta, x, offsetx, incx, beta, y, offsety, incy);
+ }
+ }
+
+ protected void dspmvU(int n, double alpha, double[] a, int offseta, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
+ int col = 0, ix = incx < 0 ? (n - 1) * -incx : 0, iy = incy < 0 ? (n - 1) * -incy : 0;
+ for (; col < loopBound(n, 4); col += 4, ix += incx * 4, iy += incy * 4) {
+ double alphaxix0 = alpha * x[offsetx + ix + incx * 0];
+ double alphaxix1 = alpha * x[offsetx + ix + incx * 1];
+ double alphaxix2 = alpha * x[offsetx + ix + incx * 2];
+ double alphaxix3 = alpha * x[offsetx + ix + incx * 3];
+ double sumiy0 = 0.0;
+ double sumiy1 = 0.0;
+ double sumiy2 = 0.0;
+ double sumiy3 = 0.0;
+ int row = 0, jx = incx < 0 ? (col - 1) * -incx : 0, jy = incy < 0 ? (col - 1) * -incy : 0;
+ for (; row < col; row += 1, jx += incx, jy += incy) {
+ double a0 = a[offseta + row + (col + 0) * ((col + 0) + 1) / 2];
+ double a1 = a[offseta + row + (col + 1) * ((col + 1) + 1) / 2];
+ double a2 = a[offseta + row + (col + 2) * ((col + 2) + 1) / 2];
+ double a3 = a[offseta + row + (col + 3) * ((col + 3) + 1) / 2];
+ y[offsety + jy] += alphaxix0 * a0
+ + alphaxix1 * a1
+ + alphaxix2 * a2
+ + alphaxix3 * a3;
+ double xjx = x[offsetx + jx];
+ sumiy0 += xjx * a0;
+ sumiy1 += xjx * a1;
+ sumiy2 += xjx * a2;
+ sumiy3 += xjx * a3;
+ }
+ double a00 = a[offseta + (row + 0) + (col + 0) * ((col + 0) + 1) / 2];
+ double a01 = a[offseta + (row + 0) + (col + 1) * ((col + 1) + 1) / 2];
+ double a02 = a[offseta + (row + 0) + (col + 2) * ((col + 2) + 1) / 2];
+ double a03 = a[offseta + (row + 0) + (col + 3) * ((col + 3) + 1) / 2];
+ double a11 = a[offseta + (row + 1) + (col + 1) * ((col + 1) + 1) / 2];
+ double a12 = a[offseta + (row + 1) + (col + 2) * ((col + 2) + 1) / 2];
+ double a13 = a[offseta + (row + 1) + (col + 3) * ((col + 3) + 1) / 2];
+ double a22 = a[offseta + (row + 2) + (col + 2) * ((col + 2) + 1) / 2];
+ double a23 = a[offseta + (row + 2) + (col + 3) * ((col + 3) + 1) / 2];
+ double a33 = a[offseta + (row + 3) + (col + 3) * ((col + 3) + 1) / 2];
+ double xjx0 = x[offsetx + jx + incx * 0];
+ double xjx1 = x[offsetx + jx + incx * 1];
+ double xjx2 = x[offsetx + jx + incx * 2];
+ double xjx3 = x[offsetx + jx + incx * 3];
+ sumiy0 += xjx0 * a00
+ + xjx1 * a01
+ + xjx2 * a02
+ + xjx3 * a03;
+ sumiy1 += xjx0 * a01
+ + xjx1 * a11
+ + xjx2 * a12
+ + xjx3 * a13;
+ sumiy2 += xjx0 * a02
+ + xjx1 * a12
+ + xjx2 * a22
+ + xjx3 * a23;
+ sumiy3 += xjx0 * a03
+ + xjx1 * a13
+ + xjx2 * a23
+ + xjx3 * a33;
+ if (beta != 0.0) {
+ y[offsety + iy + incy * 0] = alpha * sumiy0 + beta * y[offsety + iy + incy * 0];
+ y[offsety + iy + incy * 1] = alpha * sumiy1 + beta * y[offsety + iy + incy * 1];
+ y[offsety + iy + incy * 2] = alpha * sumiy2 + beta * y[offsety + iy + incy * 2];
+ y[offsety + iy + incy * 3] = alpha * sumiy3 + beta * y[offsety + iy + incy * 3];
+ } else {
+ y[offsety + iy + incy * 0] = alpha * sumiy0;
+ y[offsety + iy + incy * 1] = alpha * sumiy1;
+ y[offsety + iy + incy * 2] = alpha * sumiy2;
+ y[offsety + iy + incy * 3] = alpha * sumiy3;
+ }
+ }
+ for (; col < n; col += 1, ix += incx, iy += incy) {
+ double alphaxix = alpha * x[offsetx + ix];
+ double sumiy = 0.0;
+ int row = 0, jx = incx < 0 ? (col - 1) * -incx : 0, jy = incy < 0 ? (col - 1) * -incy : 0;
+ for (; row < col; row += 1, jx += incx, jy += incy) {
+ y[offsety + jy] += alphaxix * a[offseta + row + col * (col + 1) / 2];
+ sumiy += x[offsetx + jx] * a[offseta + row + col * (col + 1) / 2];
+ }
+ sumiy += x[offsetx + jx] * a[offseta + row + col * (col + 1) / 2];
+ if (beta != 0.0) {
+ y[offsety + iy] = alpha * sumiy + beta * y[offsety + iy];
+ } else {
+ y[offsety + iy] = alpha * sumiy;
+ }
+ }
+ }
+
+ protected void dspmvL(int n, double alpha, double[] a, int offseta, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
+ // y = beta * y
+ if (beta != 1.0) {
+ for (int i = 0, iy = incy < 0 ? (n - 1) * -incy : 0; i < n; i += 1, iy += incy) {
+ if (beta != 0.0) {
+ y[offsety + iy] = beta * y[offsety + iy];
+ } else {
+ y[offsety + iy] = 0.0;
+ }
+ }
+ }
+ // y += alpha * A * x
+ int col = 0, ix = incx < 0 ? (n - 1) * -incx : 0, iy = incy < 0 ? (n - 1) * -incy : 0;
+ for (; col < loopBound(n, 4); col += 4, ix += incx * 4, iy += incy * 4) {
+ double alphaxix0 = alpha * x[offsetx + ix + incx * 0];
+ double alphaxix1 = alpha * x[offsetx + ix + incx * 1];
+ double alphaxix2 = alpha * x[offsetx + ix + incx * 2];
+ double alphaxix3 = alpha * x[offsetx + ix + incx * 3];
+ double sumiy0 = 0.0;
+ double sumiy1 = 0.0;
+ double sumiy2 = 0.0;
+ double sumiy3 = 0.0;
+ double a00 = a[offseta + /*row=*/(col + 0) + (col + 0) * (2 * n - (col + 0) - 1) / 2];
+ double a10 = a[offseta + /*row=*/(col + 1) + (col + 0) * (2 * n - (col + 0) - 1) / 2];
+ double a11 = a[offseta + /*row=*/(col + 1) + (col + 1) * (2 * n - (col + 1) - 1) / 2];
+ double a20 = a[offseta + /*row=*/(col + 2) + (col + 0) * (2 * n - (col + 0) - 1) / 2];
+ double a21 = a[offseta + /*row=*/(col + 2) + (col + 1) * (2 * n - (col + 1) - 1) / 2];
+ double a22 = a[offseta + /*row=*/(col + 2) + (col + 2) * (2 * n - (col + 2) - 1) / 2];
+ double a30 = a[offseta + /*row=*/(col + 3) + (col + 0) * (2 * n - (col + 0) - 1) / 2];
+ double a31 = a[offseta + /*row=*/(col + 3) + (col + 1) * (2 * n - (col + 1) - 1) / 2];
+ double a32 = a[offseta + /*row=*/(col + 3) + (col + 2) * (2 * n - (col + 2) - 1) / 2];
+ double a33 = a[offseta + /*row=*/(col + 3) + (col + 3) * (2 * n - (col + 3) - 1) / 2];
+ double x0 = x[offsetx + (incx < 0 ? (n - (col + 0) - 1) * -incx : (col + 0) * incx)];
+ double x1 = x[offsetx + (incx < 0 ? (n - (col + 1) - 1) * -incx : (col + 1) * incx)];
+ double x2 = x[offsetx + (incx < 0 ? (n - (col + 2) - 1) * -incx : (col + 2) * incx)];
+ double x3 = x[offsetx + (incx < 0 ? (n - (col + 3) - 1) * -incx : (col + 3) * incx)];
+ sumiy0 += x0 * a00
+ + x1 * a10
+ + x2 * a20
+ + x3 * a30;
+ sumiy1 += x0 * a10
+ + x1 * a11
+ + x2 * a21
+ + x3 * a31;
+ sumiy2 += x0 * a20
+ + x1 * a21
+ + x2 * a22
+ + x3 * a32;
+ sumiy3 += x0 * a30
+ + x1 * a31
+ + x2 * a32
+ + x3 * a33;
+ int row = col + 4, jx = incx < 0 ? (n - (col + 4) - 1) * -incx : (col + 4) * incx, jy = incy < 0 ? (n - (col + 4) - 1) * -incy : (col + 4) * incy;
+ for (; row < n; row += 1, jx += incx, jy += incy) {
+ double a0 = a[offseta + row + (col + 0) * (2 * n - (col + 0) - 1) / 2];
+ double a1 = a[offseta + row + (col + 1) * (2 * n - (col + 1) - 1) / 2];
+ double a2 = a[offseta + row + (col + 2) * (2 * n - (col + 2) - 1) / 2];
+ double a3 = a[offseta + row + (col + 3) * (2 * n - (col + 3) - 1) / 2];
+ y[offsety + jy] += alphaxix0 * a0
+ + alphaxix1 * a1
+ + alphaxix2 * a2
+ + alphaxix3 * a3;
+ double xjx = x[offsetx + jx];
+ sumiy0 += xjx * a0;
+ sumiy1 += xjx * a1;
+ sumiy2 += xjx * a2;
+ sumiy3 += xjx * a3;
+ }
+ y[offsety + iy + incy * 0] += alpha * sumiy0;
+ y[offsety + iy + incy * 1] += alpha * sumiy1;
+ y[offsety + iy + incy * 2] += alpha * sumiy2;
+ y[offsety + iy + incy * 3] += alpha * sumiy3;
+ }
+ for (; col < n; col += 1, ix += incx, iy += incy) {
+ double alphaxix = alpha * x[offsetx + ix];
+ double sumiy = 0.0;
+ sumiy += x[offsetx + (incx < 0 ? (n - col - 1) * -incx : col * incx)] * a[offseta + /*row=*/col + col * (2 * n - col - 1) / 2];
+ int row = col + 1, jx = incx < 0 ? (n - (col + 1) - 1) * -incx : (col + 1) * incx, jy = incy < 0 ? (n - (col + 1) - 1) * -incy : (col + 1) * incy;
+ for (; row < n; row += 1, jx += incx, jy += incy) {
+ y[offsety + jy] += alphaxix * a[offseta + row + col * (2 * n - col - 1) / 2];
+ sumiy += x[offsetx + jx] * a[offseta + row + col * (2 * n - col - 1) / 2];
+ }
+ y[offsety + iy] += alpha * sumiy;
+ }
+ }
+
+ protected void sspmvK(String uplo, int n, float alpha, float[] a, int offseta, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
+ if (alpha == 0.0f) {
+ for (int i = 0, iy = incy < 0 ? (n - 1) * -incy : 0; i < n; i += 1, iy += incy) {
+ if (beta != 0.0f) {
+ y[offsety + iy] = beta * y[offsety + iy];
+ } else {
+ y[offsety + iy] = 0.0f;
+ }
+ }
+ } else if (lsame("U", uplo)) {
+ sspmvU(n, alpha, a, offseta, x, offsetx, incx, beta, y, offsety, incy);
+ } else if (lsame("L", uplo)) {
+ sspmvL(n, alpha, a, offseta, x, offsetx, incx, beta, y, offsety, incy);
+ }
+ }
+
+ protected void sspmvU(int n, float alpha, float[] a, int offseta, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
+ int col = 0, ix = incx < 0 ? (n - 1) * -incx : 0, iy = incy < 0 ? (n - 1) * -incy : 0;
+ for (; col < loopBound(n, 4); col += 4, ix += incx * 4, iy += incy * 4) {
+ float alphaxix0 = alpha * x[offsetx + ix + incx * 0];
+ float alphaxix1 = alpha * x[offsetx + ix + incx * 1];
+ float alphaxix2 = alpha * x[offsetx + ix + incx * 2];
+ float alphaxix3 = alpha * x[offsetx + ix + incx * 3];
+ float sumiy0 = 0.0f;
+ float sumiy1 = 0.0f;
+ float sumiy2 = 0.0f;
+ float sumiy3 = 0.0f;
+ int row = 0, jx = incx < 0 ? (col - 1) * -incx : 0, jy = incy < 0 ? (col - 1) * -incy : 0;
+ for (; row < col; row += 1, jx += incx, jy += incy) {
+ float a0 = a[offseta + row + (col + 0) * ((col + 0) + 1) / 2];
+ float a1 = a[offseta + row + (col + 1) * ((col + 1) + 1) / 2];
+ float a2 = a[offseta + row + (col + 2) * ((col + 2) + 1) / 2];
+ float a3 = a[offseta + row + (col + 3) * ((col + 3) + 1) / 2];
+ y[offsety + jy] += alphaxix0 * a0
+ + alphaxix1 * a1
+ + alphaxix2 * a2
+ + alphaxix3 * a3;
+ float xjx = x[offsetx + jx];
+ sumiy0 += xjx * a0;
+ sumiy1 += xjx * a1;
+ sumiy2 += xjx * a2;
+ sumiy3 += xjx * a3;
+ }
+ float a00 = a[offseta + (row + 0) + (col + 0) * ((col + 0) + 1) / 2];
+ float a01 = a[offseta + (row + 0) + (col + 1) * ((col + 1) + 1) / 2];
+ float a02 = a[offseta + (row + 0) + (col + 2) * ((col + 2) + 1) / 2];
+ float a03 = a[offseta + (row + 0) + (col + 3) * ((col + 3) + 1) / 2];
+ float a11 = a[offseta + (row + 1) + (col + 1) * ((col + 1) + 1) / 2];
+ float a12 = a[offseta + (row + 1) + (col + 2) * ((col + 2) + 1) / 2];
+ float a13 = a[offseta + (row + 1) + (col + 3) * ((col + 3) + 1) / 2];
+ float a22 = a[offseta + (row + 2) + (col + 2) * ((col + 2) + 1) / 2];
+ float a23 = a[offseta + (row + 2) + (col + 3) * ((col + 3) + 1) / 2];
+ float a33 = a[offseta + (row + 3) + (col + 3) * ((col + 3) + 1) / 2];
+ float xjx0 = x[offsetx + jx + incx * 0];
+ float xjx1 = x[offsetx + jx + incx * 1];
+ float xjx2 = x[offsetx + jx + incx * 2];
+ float xjx3 = x[offsetx + jx + incx * 3];
+ sumiy0 += xjx0 * a00
+ + xjx1 * a01
+ + xjx2 * a02
+ + xjx3 * a03;
+ sumiy1 += xjx0 * a01
+ + xjx1 * a11
+ + xjx2 * a12
+ + xjx3 * a13;
+ sumiy2 += xjx0 * a02
+ + xjx1 * a12
+ + xjx2 * a22
+ + xjx3 * a23;
+ sumiy3 += xjx0 * a03
+ + xjx1 * a13
+ + xjx2 * a23
+ + xjx3 * a33;
+ if (beta != 0.0f) {
+ y[offsety + iy + incy * 0] = alpha * sumiy0 + beta * y[offsety + iy + incy * 0];
+ y[offsety + iy + incy * 1] = alpha * sumiy1 + beta * y[offsety + iy + incy * 1];
+ y[offsety + iy + incy * 2] = alpha * sumiy2 + beta * y[offsety + iy + incy * 2];
+ y[offsety + iy + incy * 3] = alpha * sumiy3 + beta * y[offsety + iy + incy * 3];
+ } else {
+ y[offsety + iy + incy * 0] = alpha * sumiy0;
+ y[offsety + iy + incy * 1] = alpha * sumiy1;
+ y[offsety + iy + incy * 2] = alpha * sumiy2;
+ y[offsety + iy + incy * 3] = alpha * sumiy3;
+ }
+ }
+ for (; col < n; col += 1, ix += incx, iy += incy) {
+ float alphaxix = alpha * x[offsetx + ix];
+ float sumiy = 0.0f;
+ int row = 0, jx = incx < 0 ? (col - 1) * -incx : 0, jy = incy < 0 ? (col - 1) * -incy : 0;
+ for (; row < col; row += 1, jx += incx, jy += incy) {
+ y[offsety + jy] += alphaxix * a[offseta + row + col * (col + 1) / 2];
+ sumiy += x[offsetx + jx] * a[offseta + row + col * (col + 1) / 2];
+ }
+ sumiy += x[offsetx + jx] * a[offseta + row + col * (col + 1) / 2];
+ if (beta != 0.0f) {
+ y[offsety + iy] = alpha * sumiy + beta * y[offsety + iy];
+ } else {
+ y[offsety + iy] = alpha * sumiy;
+ }
+ }
+ }
+
+ protected void sspmvL(int n, float alpha, float[] a, int offseta, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
+ // y = beta * y
+ if (beta != 1.0f) {
+ for (int i = 0, iy = incy < 0 ? (n - 1) * -incy : 0; i < n; i += 1, iy += incy) {
+ if (beta != 0.0f) {
+ y[offsety + iy] = beta * y[offsety + iy];
+ } else {
+ y[offsety + iy] = 0.0f;
+ }
+ }
+ }
+ // y += alpha * A * x
+ int col = 0, ix = incx < 0 ? (n - 1) * -incx : 0, iy = incy < 0 ? (n - 1) * -incy : 0;
+ for (; col < loopBound(n, 4); col += 4, ix += incx * 4, iy += incy * 4) {
+ float alphaxix0 = alpha * x[offsetx + ix + incx * 0];
+ float alphaxix1 = alpha * x[offsetx + ix + incx * 1];
+ float alphaxix2 = alpha * x[offsetx + ix + incx * 2];
+ float alphaxix3 = alpha * x[offsetx + ix + incx * 3];
+ float sumiy0 = 0.0f;
+ float sumiy1 = 0.0f;
+ float sumiy2 = 0.0f;
+ float sumiy3 = 0.0f;
+ float a00 = a[offseta + /*row=*/(col + 0) + (col + 0) * (2 * n - (col + 0) - 1) / 2];
+ float a10 = a[offseta + /*row=*/(col + 1) + (col + 0) * (2 * n - (col + 0) - 1) / 2];
+ float a11 = a[offseta + /*row=*/(col + 1) + (col + 1) * (2 * n - (col + 1) - 1) / 2];
+ float a20 = a[offseta + /*row=*/(col + 2) + (col + 0) * (2 * n - (col + 0) - 1) / 2];
+ float a21 = a[offseta + /*row=*/(col + 2) + (col + 1) * (2 * n - (col + 1) - 1) / 2];
+ float a22 = a[offseta + /*row=*/(col + 2) + (col + 2) * (2 * n - (col + 2) - 1) / 2];
+ float a30 = a[offseta + /*row=*/(col + 3) + (col + 0) * (2 * n - (col + 0) - 1) / 2];
+ float a31 = a[offseta + /*row=*/(col + 3) + (col + 1) * (2 * n - (col + 1) - 1) / 2];
+ float a32 = a[offseta + /*row=*/(col + 3) + (col + 2) * (2 * n - (col + 2) - 1) / 2];
+ float a33 = a[offseta + /*row=*/(col + 3) + (col + 3) * (2 * n - (col + 3) - 1) / 2];
+ float x0 = x[offsetx + (incx < 0 ? (n - (col + 0) - 1) * -incx : (col + 0) * incx)];
+ float x1 = x[offsetx + (incx < 0 ? (n - (col + 1) - 1) * -incx : (col + 1) * incx)];
+ float x2 = x[offsetx + (incx < 0 ? (n - (col + 2) - 1) * -incx : (col + 2) * incx)];
+ float x3 = x[offsetx + (incx < 0 ? (n - (col + 3) - 1) * -incx : (col + 3) * incx)];
+ sumiy0 += x0 * a00
+ + x1 * a10
+ + x2 * a20
+ + x3 * a30;
+ sumiy1 += x0 * a10
+ + x1 * a11
+ + x2 * a21
+ + x3 * a31;
+ sumiy2 += x0 * a20
+ + x1 * a21
+ + x2 * a22
+ + x3 * a32;
+ sumiy3 += x0 * a30
+ + x1 * a31
+ + x2 * a32
+ + x3 * a33;
+ int row = col + 4, jx = incx < 0 ? (n - (col + 4) - 1) * -incx : (col + 4) * incx, jy = incy < 0 ? (n - (col + 4) - 1) * -incy : (col + 4) * incy;
+ for (; row < n; row += 1, jx += incx, jy += incy) {
+ float a0 = a[offseta + row + (col + 0) * (2 * n - (col + 0) - 1) / 2];
+ float a1 = a[offseta + row + (col + 1) * (2 * n - (col + 1) - 1) / 2];
+ float a2 = a[offseta + row + (col + 2) * (2 * n - (col + 2) - 1) / 2];
+ float a3 = a[offseta + row + (col + 3) * (2 * n - (col + 3) - 1) / 2];
+ y[offsety + jy] += alphaxix0 * a0
+ + alphaxix1 * a1
+ + alphaxix2 * a2
+ + alphaxix3 * a3;
+ float xjx = x[offsetx + jx];
+ sumiy0 += xjx * a0;
+ sumiy1 += xjx * a1;
+ sumiy2 += xjx * a2;
+ sumiy3 += xjx * a3;
+ }
+ y[offsety + iy + incy * 0] += alpha * sumiy0;
+ y[offsety + iy + incy * 1] += alpha * sumiy1;
+ y[offsety + iy + incy * 2] += alpha * sumiy2;
+ y[offsety + iy + incy * 3] += alpha * sumiy3;
+ }
+ for (; col < n; col += 1, ix += incx, iy += incy) {
+ float alphaxix = alpha * x[offsetx + ix];
+ float sumiy = 0.0f;
+ sumiy += x[offsetx + (incx < 0 ? (n - col - 1) * -incx : col * incx)] * a[offseta + /*row=*/col + col * (2 * n - col - 1) / 2];
+ int row = col + 1, jx = incx < 0 ? (n - (col + 1) - 1) * -incx : (col + 1) * incx, jy = incy < 0 ? (n - (col + 1) - 1) * -incy : (col + 1) * incy;
+ for (; row < n; row += 1, jx += incx, jy += incy) {
+ y[offsety + jy] += alphaxix * a[offseta + row + col * (2 * n - col - 1) / 2];
+ sumiy += x[offsetx + jx] * a[offseta + row + col * (2 * n - col - 1) / 2];
+ }
+ y[offsety + iy] += alpha * sumiy;
+ }
+ }
+
+ protected void dsprK(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] a, int offseta) {
+ org.netlib.blas.Dspr.dspr(uplo, n, alpha, x, offsetx, incx, a, offseta);
+ }
+
+ protected void ssprK(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] a, int offseta) {
+ org.netlib.blas.Sspr.sspr(uplo, n, alpha, x, offsetx, incx, a, offseta);
+ }
+
+ protected void dspr2K(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta) {
+ org.netlib.blas.Dspr2.dspr2(uplo, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta);
+ }
+
+ protected void sspr2K(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta) {
+ org.netlib.blas.Sspr2.sspr2(uplo, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta);
+ }
+
+ protected void dswapK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy) {
+ if (incx == 1 && incy == 1) {
+ for (int ix = 0, iy = 0; ix < n && iy < n; ix += 1, iy += 1) {
+ double tmp = y[offsety + iy];
+ y[offsety + iy] = x[offsetx + ix];
+ x[offsetx + ix] = tmp;
+ }
+ } else {
+ for (int ix = incx < 0 ? (n - 1) * -incx : 0,
+ iy = incy < 0 ? (n - 1) * -incy : 0;
+ (incx < 0 ? ix >= 0 : ix < n * incx)
+ && (incy < 0 ? iy >= 0 : iy < n * incy);
+ ix += incx, iy += incy) {
+ double tmp = y[offsety + iy];
+ y[offsety + iy] = x[offsetx + ix];
+ x[offsetx + ix] = tmp;
+ }
+ }
+ }
+
+ protected void sswapK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) {
+ if (incx == 1 && incy == 1) {
+ for (int ix = 0, iy = 0; ix < n && iy < n; ix += 1, iy += 1) {
+ float tmp = y[offsety + iy];
+ y[offsety + iy] = x[offsetx + ix];
+ x[offsetx + ix] = tmp;
+ }
+ } else {
+ for (int ix = incx < 0 ? (n - 1) * -incx : 0,
+ iy = incy < 0 ? (n - 1) * -incy : 0;
+ (incx < 0 ? ix >= 0 : ix < n * incx)
+ && (incy < 0 ? iy >= 0 : iy < n * incy);
+ ix += incx, iy += incy) {
+ float tmp = y[offsety + iy];
+ y[offsety + iy] = x[offsetx + ix];
+ x[offsetx + ix] = tmp;
+ }
+ }
+ }
+
+ protected void dsymmK(String side, String uplo, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
+ if (alpha == 0.0) {
+ // C := beta*C
+ int col = 0;
+ for (; col < loopBound(n, 4); col += 4) {
+ int row = 0;
+ for (; row < m; row += 1) {
+ if (beta != 0.0) {
+ c[offsetc + row + (col + 0) * ldc] = beta * c[offsetc + row + (col + 0) * ldc];
+ c[offsetc + row + (col + 1) * ldc] = beta * c[offsetc + row + (col + 1) * ldc];
+ c[offsetc + row + (col + 2) * ldc] = beta * c[offsetc + row + (col + 2) * ldc];
+ c[offsetc + row + (col + 3) * ldc] = beta * c[offsetc + row + (col + 3) * ldc];
+ } else {
+ c[offsetc + row + (col + 0) * ldc] = 0.0;
+ c[offsetc + row + (col + 1) * ldc] = 0.0;
+ c[offsetc + row + (col + 2) * ldc] = 0.0;
+ c[offsetc + row + (col + 3) * ldc] = 0.0;
+ }
+ }
+ }
+ for (; col < n; col += 1) {
+ int row = 0;
+ for (; row < m; row += 1) {
+ if (beta != 0.0) {
+ c[offsetc + row + col * ldc] = beta * c[offsetc + row + col * ldc];
+ } else {
+ c[offsetc + row + col * ldc] = 0.0;
+ }
+ }
+ }
+ } else if (lsame("L", side) && lsame("U", uplo)) {
+ dsymmLU(m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
+ } else if (lsame("L", side) && lsame("L", uplo)) {
+ dsymmLL(m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
+ } else if (lsame("R", side) && lsame("U", uplo)) {
+ dsymmRU(m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
+ } else if (lsame("R", side) && lsame("L", uplo)) {
+ dsymmRL(m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
+ }
+ }
+
+ protected void dsymmLU(int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
+ // C := alpha*A*B + beta*C
+ int col = 0;
+ for (; col < loopBound(n, 4); col += 4) {
+ int row = 0;
+ for (; row < loopBound(m, 4); row += 4) {
+ double sum00 = 0.0;
+ double sum10 = 0.0;
+ double sum20 = 0.0;
+ double sum30 = 0.0;
+ double sum01 = 0.0;
+ double sum11 = 0.0;
+ double sum21 = 0.0;
+ double sum31 = 0.0;
+ double sum02 = 0.0;
+ double sum12 = 0.0;
+ double sum22 = 0.0;
+ double sum32 = 0.0;
+ double sum03 = 0.0;
+ double sum13 = 0.0;
+ double sum23 = 0.0;
+ double sum33 = 0.0;
+ double alphab00 = alpha * b[offsetb + (row + 0) + (col + 0) * ldb];
+ double alphab10 = alpha * b[offsetb + (row + 1) + (col + 0) * ldb];
+ double alphab20 = alpha * b[offsetb + (row + 2) + (col + 0) * ldb];
+ double alphab30 = alpha * b[offsetb + (row + 3) + (col + 0) * ldb];
+ double alphab01 = alpha * b[offsetb + (row + 0) + (col + 1) * ldb];
+ double alphab11 = alpha * b[offsetb + (row + 1) + (col + 1) * ldb];
+ double alphab21 = alpha * b[offsetb + (row + 2) + (col + 1) * ldb];
+ double alphab31 = alpha * b[offsetb + (row + 3) + (col + 1) * ldb];
+ double alphab02 = alpha * b[offsetb + (row + 0) + (col + 2) * ldb];
+ double alphab12 = alpha * b[offsetb + (row + 1) + (col + 2) * ldb];
+ double alphab22 = alpha * b[offsetb + (row + 2) + (col + 2) * ldb];
+ double alphab32 = alpha * b[offsetb + (row + 3) + (col + 2) * ldb];
+ double alphab03 = alpha * b[offsetb + (row + 0) + (col + 3) * ldb];
+ double alphab13 = alpha * b[offsetb + (row + 1) + (col + 3) * ldb];
+ double alphab23 = alpha * b[offsetb + (row + 2) + (col + 3) * ldb];
+ double alphab33 = alpha * b[offsetb + (row + 3) + (col + 3) * ldb];
+ int i = 0;
+ for (; i < row; i += 1) {
+ double a0 = a[offseta + i + (row + 0) * lda];
+ double a1 = a[offseta + i + (row + 1) * lda];
+ double a2 = a[offseta + i + (row + 2) * lda];
+ double a3 = a[offseta + i + (row + 3) * lda];
+ c[offsetc + i + (col + 0) * ldc] += alphab00 * a0
+ + alphab10 * a1
+ + alphab20 * a2
+ + alphab30 * a3;
+ c[offsetc + i + (col + 1) * ldc] += alphab01 * a0
+ + alphab11 * a1
+ + alphab21 * a2
+ + alphab31 * a3;
+ c[offsetc + i + (col + 2) * ldc] += alphab02 * a0
+ + alphab12 * a1
+ + alphab22 * a2
+ + alphab32 * a3;
+ c[offsetc + i + (col + 3) * ldc] += alphab03 * a0
+ + alphab13 * a1
+ + alphab23 * a2
+ + alphab33 * a3;
+ double b0 = b[offsetb + i + (col + 0) * ldb];
+ double b1 = b[offsetb + i + (col + 1) * ldb];
+ double b2 = b[offsetb + i + (col + 2) * ldb];
+ double b3 = b[offsetb + i + (col + 3) * ldb];
+ sum00 += a0 * b0;
+ sum10 += a1 * b0;
+ sum20 += a2 * b0;
+ sum30 += a3 * b0;
+ sum01 += a0 * b1;
+ sum11 += a1 * b1;
+ sum21 += a2 * b1;
+ sum31 += a3 * b1;
+ sum02 += a0 * b2;
+ sum12 += a1 * b2;
+ sum22 += a2 * b2;
+ sum32 += a3 * b2;
+ sum03 += a0 * b3;
+ sum13 += a1 * b3;
+ sum23 += a2 * b3;
+ sum33 += a3 * b3;
+ }
+ double a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ double a01 = a[offseta + (i + 0) + (row + 1) * lda];
+ double a02 = a[offseta + (i + 0) + (row + 2) * lda];
+ double a03 = a[offseta + (i + 0) + (row + 3) * lda];
+ double a11 = a[offseta + (i + 1) + (row + 1) * lda];
+ double a12 = a[offseta + (i + 1) + (row + 2) * lda];
+ double a13 = a[offseta + (i + 1) + (row + 3) * lda];
+ double a22 = a[offseta + (i + 2) + (row + 2) * lda];
+ double a23 = a[offseta + (i + 2) + (row + 3) * lda];
+ double a33 = a[offseta + (i + 3) + (row + 3) * lda];
+ double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
+ double b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
+ double b20 = b[offsetb + (i + 2) + (col + 0) * ldb];
+ double b30 = b[offsetb + (i + 3) + (col + 0) * ldb];
+ double b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
+ double b11 = b[offsetb + (i + 1) + (col + 1) * ldb];
+ double b21 = b[offsetb + (i + 2) + (col + 1) * ldb];
+ double b31 = b[offsetb + (i + 3) + (col + 1) * ldb];
+ double b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
+ double b12 = b[offsetb + (i + 1) + (col + 2) * ldb];
+ double b22 = b[offsetb + (i + 2) + (col + 2) * ldb];
+ double b32 = b[offsetb + (i + 3) + (col + 2) * ldb];
+ double b03 = b[offsetb + (i + 0) + (col + 3) * ldb];
+ double b13 = b[offsetb + (i + 1) + (col + 3) * ldb];
+ double b23 = b[offsetb + (i + 2) + (col + 3) * ldb];
+ double b33 = b[offsetb + (i + 3) + (col + 3) * ldb];
+ sum00 += a00 * b00 + a01 * b10 + a02 * b20 + a03 * b30;
+ sum10 += a01 * b00 + a11 * b10 + a12 * b20 + a13 * b30;
+ sum20 += a02 * b00 + a12 * b10 + a22 * b20 + a23 * b30;
+ sum30 += a03 * b00 + a13 * b10 + a23 * b20 + a33 * b30;
+ sum01 += a00 * b01 + a01 * b11 + a02 * b21 + a03 * b31;
+ sum11 += a01 * b01 + a11 * b11 + a12 * b21 + a13 * b31;
+ sum21 += a02 * b01 + a12 * b11 + a22 * b21 + a23 * b31;
+ sum31 += a03 * b01 + a13 * b11 + a23 * b21 + a33 * b31;
+ sum02 += a00 * b02 + a01 * b12 + a02 * b22 + a03 * b32;
+ sum12 += a01 * b02 + a11 * b12 + a12 * b22 + a13 * b32;
+ sum22 += a02 * b02 + a12 * b12 + a22 * b22 + a23 * b32;
+ sum32 += a03 * b02 + a13 * b12 + a23 * b22 + a33 * b32;
+ sum03 += a00 * b03 + a01 * b13 + a02 * b23 + a03 * b33;
+ sum13 += a01 * b03 + a11 * b13 + a12 * b23 + a13 * b33;
+ sum23 += a02 * b03 + a12 * b13 + a22 * b23 + a23 * b33;
+ sum33 += a03 * b03 + a13 * b13 + a23 * b23 + a33 * b33;
+ if (beta != 0.0) {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
+ c[offsetc + (row + 3) + (col + 0) * ldc] = alpha * sum30 + beta * c[offsetc + (row + 3) + (col + 0) * ldc];
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
+ c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + beta * c[offsetc + (row + 1) + (col + 1) * ldc];
+ c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + beta * c[offsetc + (row + 2) + (col + 1) * ldc];
+ c[offsetc + (row + 3) + (col + 1) * ldc] = alpha * sum31 + beta * c[offsetc + (row + 3) + (col + 1) * ldc];
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
+ c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + beta * c[offsetc + (row + 1) + (col + 2) * ldc];
+ c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + beta * c[offsetc + (row + 2) + (col + 2) * ldc];
+ c[offsetc + (row + 3) + (col + 2) * ldc] = alpha * sum32 + beta * c[offsetc + (row + 3) + (col + 2) * ldc];
+ c[offsetc + (row + 0) + (col + 3) * ldc] = alpha * sum03 + beta * c[offsetc + (row + 0) + (col + 3) * ldc];
+ c[offsetc + (row + 1) + (col + 3) * ldc] = alpha * sum13 + beta * c[offsetc + (row + 1) + (col + 3) * ldc];
+ c[offsetc + (row + 2) + (col + 3) * ldc] = alpha * sum23 + beta * c[offsetc + (row + 2) + (col + 3) * ldc];
+ c[offsetc + (row + 3) + (col + 3) * ldc] = alpha * sum33 + beta * c[offsetc + (row + 3) + (col + 3) * ldc];
+ } else {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
+ c[offsetc + (row + 3) + (col + 0) * ldc] = alpha * sum30;
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
+ c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11;
+ c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21;
+ c[offsetc + (row + 3) + (col + 1) * ldc] = alpha * sum31;
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
+ c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12;
+ c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22;
+ c[offsetc + (row + 3) + (col + 2) * ldc] = alpha * sum32;
+ c[offsetc + (row + 0) + (col + 3) * ldc] = alpha * sum03;
+ c[offsetc + (row + 1) + (col + 3) * ldc] = alpha * sum13;
+ c[offsetc + (row + 2) + (col + 3) * ldc] = alpha * sum23;
+ c[offsetc + (row + 3) + (col + 3) * ldc] = alpha * sum33;
+ }
+ }
+ for (; row < m; row += 1) {
+ double sum0 = 0.0;
+ double sum1 = 0.0;
+ double sum2 = 0.0;
+ double sum3 = 0.0;
+ double alphab0 = alpha * b[offsetb + row + (col + 0) * ldb];
+ double alphab1 = alpha * b[offsetb + row + (col + 1) * ldb];
+ double alphab2 = alpha * b[offsetb + row + (col + 2) * ldb];
+ double alphab3 = alpha * b[offsetb + row + (col + 3) * ldb];
+ int i = 0;
+ for (; i < row; i += 1) {
+ double a0 = a[offseta + i + row * lda];
+ c[offsetc + i + (col + 0) * ldc] += alphab0 * a0;
+ c[offsetc + i + (col + 1) * ldc] += alphab1 * a0;
+ c[offsetc + i + (col + 2) * ldc] += alphab2 * a0;
+ c[offsetc + i + (col + 3) * ldc] += alphab3 * a0;
+ sum0 += b[offsetb + i + (col + 0) * ldb] * a0;
+ sum1 += b[offsetb + i + (col + 1) * ldb] * a0;
+ sum2 += b[offsetb + i + (col + 2) * ldb] * a0;
+ sum3 += b[offsetb + i + (col + 3) * ldb] * a0;
+ }
+ double a0 = a[offseta + i + row * lda];
+ sum0 += b[offsetb + i + (col + 0) * ldb] * a0;
+ sum1 += b[offsetb + i + (col + 1) * ldb] * a0;
+ sum2 += b[offsetb + i + (col + 2) * ldb] * a0;
+ sum3 += b[offsetb + i + (col + 3) * ldb] * a0;
+ if (beta != 0.0) {
+ c[offsetc + row + (col + 0) * ldc] = alpha * sum0 + beta * c[offsetc + row + (col + 0) * ldc];
+ c[offsetc + row + (col + 1) * ldc] = alpha * sum1 + beta * c[offsetc + row + (col + 1) * ldc];
+ c[offsetc + row + (col + 2) * ldc] = alpha * sum2 + beta * c[offsetc + row + (col + 2) * ldc];
+ c[offsetc + row + (col + 3) * ldc] = alpha * sum3 + beta * c[offsetc + row + (col + 3) * ldc];
+ } else {
+ c[offsetc + row + (col + 0) * ldc] = alpha * sum0;
+ c[offsetc + row + (col + 1) * ldc] = alpha * sum1;
+ c[offsetc + row + (col + 2) * ldc] = alpha * sum2;
+ c[offsetc + row + (col + 3) * ldc] = alpha * sum3;
+ }
+ }
+ }
+ for (; col < n; col += 1) {
+ int row = 0;
+ for (; row < loopBound(m, 4); row += 4) {
+ double sum0 = 0.0;
+ double sum1 = 0.0;
+ double sum2 = 0.0;
+ double sum3 = 0.0;
+ double alphab0 = alpha * b[offsetb + (row + 0) + col * ldb];
+ double alphab1 = alpha * b[offsetb + (row + 1) + col * ldb];
+ double alphab2 = alpha * b[offsetb + (row + 2) + col * ldb];
+ double alphab3 = alpha * b[offsetb + (row + 3) + col * ldb];
+ int i = 0;
+ for (; i < row; i += 1) {
+ double a0 = a[offseta + i + (row + 0) * lda];
+ double a1 = a[offseta + i + (row + 1) * lda];
+ double a2 = a[offseta + i + (row + 2) * lda];
+ double a3 = a[offseta + i + (row + 3) * lda];
+ c[offsetc + i + col * ldc] += alphab0 * a0
+ + alphab1 * a1
+ + alphab2 * a2
+ + alphab3 * a3;
+ double b0 = b[offsetb + i + col * ldb];
+ sum0 += b0 * a0;
+ sum1 += b0 * a1;
+ sum2 += b0 * a2;
+ sum3 += b0 * a3;
+ }
+ double a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ double a01 = a[offseta + (i + 0) + (row + 1) * lda];
+ double a02 = a[offseta + (i + 0) + (row + 2) * lda];
+ double a03 = a[offseta + (i + 0) + (row + 3) * lda];
+ double a11 = a[offseta + (i + 1) + (row + 1) * lda];
+ double a12 = a[offseta + (i + 1) + (row + 2) * lda];
+ double a13 = a[offseta + (i + 1) + (row + 3) * lda];
+ double a22 = a[offseta + (i + 2) + (row + 2) * lda];
+ double a23 = a[offseta + (i + 2) + (row + 3) * lda];
+ double a33 = a[offseta + (i + 3) + (row + 3) * lda];
+ double b0 = b[offsetb + (i + 0) + col * ldb];
+ double b1 = b[offsetb + (i + 1) + col * ldb];
+ double b2 = b[offsetb + (i + 2) + col * ldb];
+ double b3 = b[offsetb + (i + 3) + col * ldb];
+ sum0 += b0 * a00 + b1 * a01 + b2 * a02 + b3 * a03;
+ sum1 += b0 * a01 + b1 * a11 + b2 * a12 + b3 * a13;
+ sum2 += b0 * a02 + b1 * a12 + b2 * a22 + b3 * a23;
+ sum3 += b0 * a03 + b1 * a13 + b2 * a23 + b3 * a33;
+ if (beta != 0.0) {
+ c[offsetc + (row + 0) + col * ldc] = alpha * sum0 + beta * c[offsetc + (row + 0) + col * ldc];
+ c[offsetc + (row + 1) + col * ldc] = alpha * sum1 + beta * c[offsetc + (row + 1) + col * ldc];
+ c[offsetc + (row + 2) + col * ldc] = alpha * sum2 + beta * c[offsetc + (row + 2) + col * ldc];
+ c[offsetc + (row + 3) + col * ldc] = alpha * sum3 + beta * c[offsetc + (row + 3) + col * ldc];
+ } else {
+ c[offsetc + (row + 0) + col * ldc] = alpha * sum0;
+ c[offsetc + (row + 1) + col * ldc] = alpha * sum1;
+ c[offsetc + (row + 2) + col * ldc] = alpha * sum2;
+ c[offsetc + (row + 3) + col * ldc] = alpha * sum3;
+ }
+ }
+ for (; row < m; row += 1) {
+ double alphab = alpha * b[offsetb + row + col * ldb];
+ double sum = 0.0;
+ int i = 0;
+ for (; i < row; i += 1) {
+ double aval = a[offseta + i + row * lda];
+ c[offsetc + i + col * ldc] += alphab * aval;
+ sum += b[offsetb + i + col * ldb] * aval;
+ }
+ sum += b[offsetb + i + col * ldb] * a[offseta + i + row * lda];
+ if (beta != 0.0) {
+ c[offsetc + row + col * ldc] = alpha * sum + beta * c[offsetc + row + col * ldc];
+ } else {
+ c[offsetc + row + col * ldc] = alpha * sum;
+ }
+ }
+ }
+ }
+
+ protected void dsymmLL(int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
+ final int Srow = 4;
+ // C := alpha*A*B + beta*C
+ int col = 0;
+ for (; col < loopBound(n, 4); col += 4) {
+ int row = m - 1;
+ for (; row >= loopBound(m - 1, Srow); row -= 1) {
+ double alphab0 = alpha * b[offsetb + row + (col + 0) * ldb];
+ double alphab1 = alpha * b[offsetb + row + (col + 1) * ldb];
+ double alphab2 = alpha * b[offsetb + row + (col + 2) * ldb];
+ double alphab3 = alpha * b[offsetb + row + (col + 3) * ldb];
+ double sum0 = 0.0;
+ double sum1 = 0.0;
+ double sum2 = 0.0;
+ double sum3 = 0.0;
+ sum0 += b[offsetb + row + (col + 0) * ldb] * a[offseta + row + row * lda];
+ sum1 += b[offsetb + row + (col + 1) * ldb] * a[offseta + row + row * lda];
+ sum2 += b[offsetb + row + (col + 2) * ldb] * a[offseta + row + row * lda];
+ sum3 += b[offsetb + row + (col + 3) * ldb] * a[offseta + row + row * lda];
+ int i = row + 1;
+ for (; i < m; i += 1) {
+ double airow = a[offseta + i + row * lda];
+ c[offsetc + i + (col + 0) * ldc] += alphab0 * airow;
+ c[offsetc + i + (col + 1) * ldc] += alphab1 * airow;
+ c[offsetc + i + (col + 2) * ldc] += alphab2 * airow;
+ c[offsetc + i + (col + 3) * ldc] += alphab3 * airow;
+ sum0 += b[offsetb + i + (col + 0) * ldb] * airow;
+ sum1 += b[offsetb + i + (col + 1) * ldb] * airow;
+ sum2 += b[offsetb + i + (col + 2) * ldb] * airow;
+ sum3 += b[offsetb + i + (col + 3) * ldb] * airow;
+ }
+ if (beta != 0.0) {
+ c[offsetc + row + (col + 0) * ldc] = alpha * sum0 + beta * c[offsetc + row + (col + 0) * ldc];
+ c[offsetc + row + (col + 1) * ldc] = alpha * sum1 + beta * c[offsetc + row + (col + 1) * ldc];
+ c[offsetc + row + (col + 2) * ldc] = alpha * sum2 + beta * c[offsetc + row + (col + 2) * ldc];
+ c[offsetc + row + (col + 3) * ldc] = alpha * sum3 + beta * c[offsetc + row + (col + 3) * ldc];
+ } else {
+ c[offsetc + row + (col + 0) * ldc] = alpha * sum0;
+ c[offsetc + row + (col + 1) * ldc] = alpha * sum1;
+ c[offsetc + row + (col + 2) * ldc] = alpha * sum2;
+ c[offsetc + row + (col + 3) * ldc] = alpha * sum3;
+ }
+ }
+ for (row -= Srow - 1; row >= 0; row -= Srow) {
+ double a00 = a[offseta + (row + 0) + (row + 0) * lda];
+ double a10 = a[offseta + (row + 1) + (row + 0) * lda];
+ double a11 = a[offseta + (row + 1) + (row + 1) * lda];
+ double a20 = a[offseta + (row + 2) + (row + 0) * lda];
+ double a21 = a[offseta + (row + 2) + (row + 1) * lda];
+ double a22 = a[offseta + (row + 2) + (row + 2) * lda];
+ double a30 = a[offseta + (row + 3) + (row + 0) * lda];
+ double a31 = a[offseta + (row + 3) + (row + 1) * lda];
+ double a32 = a[offseta + (row + 3) + (row + 2) * lda];
+ double a33 = a[offseta + (row + 3) + (row + 3) * lda];
+ double b00 = b[offsetb + (row + 0) + (col + 0) * ldb];
+ double b10 = b[offsetb + (row + 1) + (col + 0) * ldb];
+ double b20 = b[offsetb + (row + 2) + (col + 0) * ldb];
+ double b30 = b[offsetb + (row + 3) + (col + 0) * ldb];
+ double b01 = b[offsetb + (row + 0) + (col + 1) * ldb];
+ double b11 = b[offsetb + (row + 1) + (col + 1) * ldb];
+ double b21 = b[offsetb + (row + 2) + (col + 1) * ldb];
+ double b31 = b[offsetb + (row + 3) + (col + 1) * ldb];
+ double b02 = b[offsetb + (row + 0) + (col + 2) * ldb];
+ double b12 = b[offsetb + (row + 1) + (col + 2) * ldb];
+ double b22 = b[offsetb + (row + 2) + (col + 2) * ldb];
+ double b32 = b[offsetb + (row + 3) + (col + 2) * ldb];
+ double b03 = b[offsetb + (row + 0) + (col + 3) * ldb];
+ double b13 = b[offsetb + (row + 1) + (col + 3) * ldb];
+ double b23 = b[offsetb + (row + 2) + (col + 3) * ldb];
+ double b33 = b[offsetb + (row + 3) + (col + 3) * ldb];
+ double alphab00 = alpha * b00;
+ double alphab10 = alpha * b10;
+ double alphab20 = alpha * b20;
+ double alphab30 = alpha * b30;
+ double alphab01 = alpha * b01;
+ double alphab11 = alpha * b11;
+ double alphab21 = alpha * b21;
+ double alphab31 = alpha * b31;
+ double alphab02 = alpha * b02;
+ double alphab12 = alpha * b12;
+ double alphab22 = alpha * b22;
+ double alphab32 = alpha * b32;
+ double alphab03 = alpha * b03;
+ double alphab13 = alpha * b13;
+ double alphab23 = alpha * b23;
+ double alphab33 = alpha * b33;
+ double sum00 = 0.0;
+ double sum10 = 0.0;
+ double sum20 = 0.0;
+ double sum30 = 0.0;
+ double sum01 = 0.0;
+ double sum11 = 0.0;
+ double sum21 = 0.0;
+ double sum31 = 0.0;
+ double sum02 = 0.0;
+ double sum12 = 0.0;
+ double sum22 = 0.0;
+ double sum32 = 0.0;
+ double sum03 = 0.0;
+ double sum13 = 0.0;
+ double sum23 = 0.0;
+ double sum33 = 0.0;
+ sum00 += b00 * a00 + b10 * a10 + b20 * a20 + b30 * a30;
+ sum10 += b00 * a10 + b10 * a11 + b20 * a21 + b30 * a31;
+ sum20 += b00 * a20 + b10 * a21 + b20 * a22 + b30 * a32;
+ sum30 += b00 * a30 + b10 * a31 + b20 * a32 + b30 * a33;
+ sum01 += b01 * a00 + b11 * a10 + b21 * a20 + b31 * a30;
+ sum11 += b01 * a10 + b11 * a11 + b21 * a21 + b31 * a31;
+ sum21 += b01 * a20 + b11 * a21 + b21 * a22 + b31 * a32;
+ sum31 += b01 * a30 + b11 * a31 + b21 * a32 + b31 * a33;
+ sum02 += b02 * a00 + b12 * a10 + b22 * a20 + b32 * a30;
+ sum12 += b02 * a10 + b12 * a11 + b22 * a21 + b32 * a31;
+ sum22 += b02 * a20 + b12 * a21 + b22 * a22 + b32 * a32;
+ sum32 += b02 * a30 + b12 * a31 + b22 * a32 + b32 * a33;
+ sum03 += b03 * a00 + b13 * a10 + b23 * a20 + b33 * a30;
+ sum13 += b03 * a10 + b13 * a11 + b23 * a21 + b33 * a31;
+ sum23 += b03 * a20 + b13 * a21 + b23 * a22 + b33 * a32;
+ sum33 += b03 * a30 + b13 * a31 + b23 * a32 + b33 * a33;
+ int i = row + 4;
+ for (; i < m; i += 1) {
+ double a0 = a[offseta + i + (row + 0) * lda];
+ double a1 = a[offseta + i + (row + 1) * lda];
+ double a2 = a[offseta + i + (row + 2) * lda];
+ double a3 = a[offseta + i + (row + 3) * lda];
+ c[offsetc + i + (col + 0) * ldc] += alphab00 * a0
+ + alphab10 * a1
+ + alphab20 * a2
+ + alphab30 * a3;
+ c[offsetc + i + (col + 1) * ldc] += alphab01 * a0
+ + alphab11 * a1
+ + alphab21 * a2
+ + alphab31 * a3;
+ c[offsetc + i + (col + 2) * ldc] += alphab02 * a0
+ + alphab12 * a1
+ + alphab22 * a2
+ + alphab32 * a3;
+ c[offsetc + i + (col + 3) * ldc] += alphab03 * a0
+ + alphab13 * a1
+ + alphab23 * a2
+ + alphab33 * a3;
+ double b0 = b[offsetb + i + (col + 0) * ldb];
+ double b1 = b[offsetb + i + (col + 1) * ldb];
+ double b2 = b[offsetb + i + (col + 2) * ldb];
+ double b3 = b[offsetb + i + (col + 3) * ldb];
+ sum00 += b0 * a0;
+ sum10 += b0 * a1;
+ sum20 += b0 * a2;
+ sum30 += b0 * a3;
+ sum01 += b1 * a0;
+ sum11 += b1 * a1;
+ sum21 += b1 * a2;
+ sum31 += b1 * a3;
+ sum02 += b2 * a0;
+ sum12 += b2 * a1;
+ sum22 += b2 * a2;
+ sum32 += b2 * a3;
+ sum03 += b3 * a0;
+ sum13 += b3 * a1;
+ sum23 += b3 * a2;
+ sum33 += b3 * a3;
+ }
+ if (beta != 0.0) {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
+ c[offsetc + (row + 3) + (col + 0) * ldc] = alpha * sum30 + beta * c[offsetc + (row + 3) + (col + 0) * ldc];
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
+ c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + beta * c[offsetc + (row + 1) + (col + 1) * ldc];
+ c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + beta * c[offsetc + (row + 2) + (col + 1) * ldc];
+ c[offsetc + (row + 3) + (col + 1) * ldc] = alpha * sum31 + beta * c[offsetc + (row + 3) + (col + 1) * ldc];
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
+ c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + beta * c[offsetc + (row + 1) + (col + 2) * ldc];
+ c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + beta * c[offsetc + (row + 2) + (col + 2) * ldc];
+ c[offsetc + (row + 3) + (col + 2) * ldc] = alpha * sum32 + beta * c[offsetc + (row + 3) + (col + 2) * ldc];
+ c[offsetc + (row + 0) + (col + 3) * ldc] = alpha * sum03 + beta * c[offsetc + (row + 0) + (col + 3) * ldc];
+ c[offsetc + (row + 1) + (col + 3) * ldc] = alpha * sum13 + beta * c[offsetc + (row + 1) + (col + 3) * ldc];
+ c[offsetc + (row + 2) + (col + 3) * ldc] = alpha * sum23 + beta * c[offsetc + (row + 2) + (col + 3) * ldc];
+ c[offsetc + (row + 3) + (col + 3) * ldc] = alpha * sum33 + beta * c[offsetc + (row + 3) + (col + 3) * ldc];
+ } else {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
+ c[offsetc + (row + 3) + (col + 0) * ldc] = alpha * sum30;
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
+ c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11;
+ c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21;
+ c[offsetc + (row + 3) + (col + 1) * ldc] = alpha * sum31;
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
+ c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12;
+ c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22;
+ c[offsetc + (row + 3) + (col + 2) * ldc] = alpha * sum32;
+ c[offsetc + (row + 0) + (col + 3) * ldc] = alpha * sum03;
+ c[offsetc + (row + 1) + (col + 3) * ldc] = alpha * sum13;
+ c[offsetc + (row + 2) + (col + 3) * ldc] = alpha * sum23;
+ c[offsetc + (row + 3) + (col + 3) * ldc] = alpha * sum33;
+ }
+ }
+ }
+ for (; col < n; col += 1) {
+ int row = m - 1;
+ for (; row >= loopBound(m - 1, Srow); row -= 1) {
+ double alphab0 = alpha * b[offsetb + row + col * ldb];
+ double sum0 = 0.0;
+ sum0 += b[offsetb + row + col * ldb] * a[offseta + row + row * lda];
+ int i = row + 1;
+ for (; i < m; i += 1) {
+ double a0 = a[offseta + i + row * lda];
+ c[offsetc + i + col * ldc] += alphab0 * a0;
+ sum0 += b[offsetb + i + col * ldb] * a0;
+ }
+ if (beta != 0.0) {
+ c[offsetc + row + col * ldc] = alpha * sum0 + beta * c[offsetc + row + col * ldc];
+ } else {
+ c[offsetc + row + col * ldc] = alpha * sum0;
+ }
+ }
+ for (row -= Srow - 1; row >= 0; row -= Srow) {
+ double alphab0 = alpha * b[offsetb + (row + 0) + col * ldb];
+ double alphab1 = alpha * b[offsetb + (row + 1) + col * ldb];
+ double alphab2 = alpha * b[offsetb + (row + 2) + col * ldb];
+ double alphab3 = alpha * b[offsetb + (row + 3) + col * ldb];
+ double a00 = a[offseta + (row + 0) + (row + 0) * lda];
+ double a10 = a[offseta + (row + 1) + (row + 0) * lda];
+ double a11 = a[offseta + (row + 1) + (row + 1) * lda];
+ double a20 = a[offseta + (row + 2) + (row + 0) * lda];
+ double a21 = a[offseta + (row + 2) + (row + 1) * lda];
+ double a22 = a[offseta + (row + 2) + (row + 2) * lda];
+ double a30 = a[offseta + (row + 3) + (row + 0) * lda];
+ double a31 = a[offseta + (row + 3) + (row + 1) * lda];
+ double a32 = a[offseta + (row + 3) + (row + 2) * lda];
+ double a33 = a[offseta + (row + 3) + (row + 3) * lda];
+ double b0 = b[offsetb + (row + 0) + col * ldb];
+ double b1 = b[offsetb + (row + 1) + col * ldb];
+ double b2 = b[offsetb + (row + 2) + col * ldb];
+ double b3 = b[offsetb + (row + 3) + col * ldb];
+ double sum0 = 0.0;
+ double sum1 = 0.0;
+ double sum2 = 0.0;
+ double sum3 = 0.0;
+ sum0 += b0 * a00 + b1 * a10 + b2 * a20 + b3 * a30;
+ sum1 += b0 * a10 + b1 * a11 + b2 * a21 + b3 * a31;
+ sum2 += b0 * a20 + b1 * a21 + b2 * a22 + b3 * a32;
+ sum3 += b0 * a30 + b1 * a31 + b2 * a32 + b3 * a33;
+ int i = row + 4;
+ for (; i < m; i += 1) {
+ double a0 = a[offseta + i + (row + 0) * lda];
+ double a1 = a[offseta + i + (row + 1) * lda];
+ double a2 = a[offseta + i + (row + 2) * lda];
+ double a3 = a[offseta + i + (row + 3) * lda];
+ c[offsetc + i + col * ldc] += alphab0 * a0
+ + alphab1 * a1
+ + alphab2 * a2
+ + alphab3 * a3;
+ double bicol = b[offsetb + i + col * ldb];
+ sum0 += bicol * a0;
+ sum1 += bicol * a1;
+ sum2 += bicol * a2;
+ sum3 += bicol * a3;
+ }
+ if (beta != 0.0) {
+ c[offsetc + (row + 0) + col * ldc] = alpha * sum0 + beta * c[offsetc + (row + 0) + col * ldc];
+ c[offsetc + (row + 1) + col * ldc] = alpha * sum1 + beta * c[offsetc + (row + 1) + col * ldc];
+ c[offsetc + (row + 2) + col * ldc] = alpha * sum2 + beta * c[offsetc + (row + 2) + col * ldc];
+ c[offsetc + (row + 3) + col * ldc] = alpha * sum3 + beta * c[offsetc + (row + 3) + col * ldc];
+ } else {
+ c[offsetc + (row + 0) + col * ldc] = alpha * sum0;
+ c[offsetc + (row + 1) + col * ldc] = alpha * sum1;
+ c[offsetc + (row + 2) + col * ldc] = alpha * sum2;
+ c[offsetc + (row + 3) + col * ldc] = alpha * sum3;
+ }
+ }
+ }
+ }
+
+ protected void dsymmRU(int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
+ // C := alpha*B*A + beta*C
+ org.netlib.blas.Dsymm.dsymm("R", "U", m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
+ }
+
+ protected void dsymmRL(int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
+ // C := alpha*B*A + beta*C
+ org.netlib.blas.Dsymm.dsymm("R", "L", m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
+ }
+
+ protected void ssymmK(String side, String uplo, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
+ if (alpha == 0.0f) {
+ // C := beta*C
+ int col = 0;
+ for (; col < loopBound(n, 4); col += 4) {
+ int row = 0;
+ for (; row < m; row += 1) {
+ c[offsetc + row + (col + 0) * ldc] = beta * c[offsetc + row + (col + 0) * ldc];
+ c[offsetc + row + (col + 1) * ldc] = beta * c[offsetc + row + (col + 1) * ldc];
+ c[offsetc + row + (col + 2) * ldc] = beta * c[offsetc + row + (col + 2) * ldc];
+ c[offsetc + row + (col + 3) * ldc] = beta * c[offsetc + row + (col + 3) * ldc];
+ }
+ }
+ for (; col < n; col += 1) {
+ int row = 0;
+ for (; row < m; row += 1) {
+ c[offsetc + row + col * ldc] = beta * c[offsetc + row + col * ldc];
+ }
+ }
+ } else if (lsame("L", side) && lsame("U", uplo)) {
+ ssymmLU(m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
+ } else if (lsame("L", side) && lsame("L", uplo)) {
+ ssymmLL(m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
+ } else if (lsame("R", side) && lsame("U", uplo)) {
+ ssymmRU(m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
+ } else if (lsame("R", side) && lsame("L", uplo)) {
+ ssymmRL(m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
+ }
+ }
+
+ protected void ssymmLU(int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
+ // C := alpha*A*B + beta*C
+ int col = 0;
+ for (; col < loopBound(n, 4); col += 4) {
+ int row = 0;
+ for (; row < loopBound(m, 4); row += 4) {
+ float sum00 = 0.0f;
+ float sum10 = 0.0f;
+ float sum20 = 0.0f;
+ float sum30 = 0.0f;
+ float sum01 = 0.0f;
+ float sum11 = 0.0f;
+ float sum21 = 0.0f;
+ float sum31 = 0.0f;
+ float sum02 = 0.0f;
+ float sum12 = 0.0f;
+ float sum22 = 0.0f;
+ float sum32 = 0.0f;
+ float sum03 = 0.0f;
+ float sum13 = 0.0f;
+ float sum23 = 0.0f;
+ float sum33 = 0.0f;
+ float alphab00 = alpha * b[offsetb + (row + 0) + (col + 0) * ldb];
+ float alphab10 = alpha * b[offsetb + (row + 1) + (col + 0) * ldb];
+ float alphab20 = alpha * b[offsetb + (row + 2) + (col + 0) * ldb];
+ float alphab30 = alpha * b[offsetb + (row + 3) + (col + 0) * ldb];
+ float alphab01 = alpha * b[offsetb + (row + 0) + (col + 1) * ldb];
+ float alphab11 = alpha * b[offsetb + (row + 1) + (col + 1) * ldb];
+ float alphab21 = alpha * b[offsetb + (row + 2) + (col + 1) * ldb];
+ float alphab31 = alpha * b[offsetb + (row + 3) + (col + 1) * ldb];
+ float alphab02 = alpha * b[offsetb + (row + 0) + (col + 2) * ldb];
+ float alphab12 = alpha * b[offsetb + (row + 1) + (col + 2) * ldb];
+ float alphab22 = alpha * b[offsetb + (row + 2) + (col + 2) * ldb];
+ float alphab32 = alpha * b[offsetb + (row + 3) + (col + 2) * ldb];
+ float alphab03 = alpha * b[offsetb + (row + 0) + (col + 3) * ldb];
+ float alphab13 = alpha * b[offsetb + (row + 1) + (col + 3) * ldb];
+ float alphab23 = alpha * b[offsetb + (row + 2) + (col + 3) * ldb];
+ float alphab33 = alpha * b[offsetb + (row + 3) + (col + 3) * ldb];
+ int i = 0;
+ for (; i < row; i += 1) {
+ float a0 = a[offseta + i + (row + 0) * lda];
+ float a1 = a[offseta + i + (row + 1) * lda];
+ float a2 = a[offseta + i + (row + 2) * lda];
+ float a3 = a[offseta + i + (row + 3) * lda];
+ c[offsetc + i + (col + 0) * ldc] += alphab00 * a0
+ + alphab10 * a1
+ + alphab20 * a2
+ + alphab30 * a3;
+ c[offsetc + i + (col + 1) * ldc] += alphab01 * a0
+ + alphab11 * a1
+ + alphab21 * a2
+ + alphab31 * a3;
+ c[offsetc + i + (col + 2) * ldc] += alphab02 * a0
+ + alphab12 * a1
+ + alphab22 * a2
+ + alphab32 * a3;
+ c[offsetc + i + (col + 3) * ldc] += alphab03 * a0
+ + alphab13 * a1
+ + alphab23 * a2
+ + alphab33 * a3;
+ float b0 = b[offsetb + i + (col + 0) * ldb];
+ float b1 = b[offsetb + i + (col + 1) * ldb];
+ float b2 = b[offsetb + i + (col + 2) * ldb];
+ float b3 = b[offsetb + i + (col + 3) * ldb];
+ sum00 += a0 * b0;
+ sum10 += a1 * b0;
+ sum20 += a2 * b0;
+ sum30 += a3 * b0;
+ sum01 += a0 * b1;
+ sum11 += a1 * b1;
+ sum21 += a2 * b1;
+ sum31 += a3 * b1;
+ sum02 += a0 * b2;
+ sum12 += a1 * b2;
+ sum22 += a2 * b2;
+ sum32 += a3 * b2;
+ sum03 += a0 * b3;
+ sum13 += a1 * b3;
+ sum23 += a2 * b3;
+ sum33 += a3 * b3;
+ }
+ float a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ float a01 = a[offseta + (i + 0) + (row + 1) * lda];
+ float a02 = a[offseta + (i + 0) + (row + 2) * lda];
+ float a03 = a[offseta + (i + 0) + (row + 3) * lda];
+ float a11 = a[offseta + (i + 1) + (row + 1) * lda];
+ float a12 = a[offseta + (i + 1) + (row + 2) * lda];
+ float a13 = a[offseta + (i + 1) + (row + 3) * lda];
+ float a22 = a[offseta + (i + 2) + (row + 2) * lda];
+ float a23 = a[offseta + (i + 2) + (row + 3) * lda];
+ float a33 = a[offseta + (i + 3) + (row + 3) * lda];
+ float b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
+ float b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
+ float b20 = b[offsetb + (i + 2) + (col + 0) * ldb];
+ float b30 = b[offsetb + (i + 3) + (col + 0) * ldb];
+ float b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
+ float b11 = b[offsetb + (i + 1) + (col + 1) * ldb];
+ float b21 = b[offsetb + (i + 2) + (col + 1) * ldb];
+ float b31 = b[offsetb + (i + 3) + (col + 1) * ldb];
+ float b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
+ float b12 = b[offsetb + (i + 1) + (col + 2) * ldb];
+ float b22 = b[offsetb + (i + 2) + (col + 2) * ldb];
+ float b32 = b[offsetb + (i + 3) + (col + 2) * ldb];
+ float b03 = b[offsetb + (i + 0) + (col + 3) * ldb];
+ float b13 = b[offsetb + (i + 1) + (col + 3) * ldb];
+ float b23 = b[offsetb + (i + 2) + (col + 3) * ldb];
+ float b33 = b[offsetb + (i + 3) + (col + 3) * ldb];
+ sum00 += a00 * b00 + a01 * b10 + a02 * b20 + a03 * b30;
+ sum10 += a01 * b00 + a11 * b10 + a12 * b20 + a13 * b30;
+ sum20 += a02 * b00 + a12 * b10 + a22 * b20 + a23 * b30;
+ sum30 += a03 * b00 + a13 * b10 + a23 * b20 + a33 * b30;
+ sum01 += a00 * b01 + a01 * b11 + a02 * b21 + a03 * b31;
+ sum11 += a01 * b01 + a11 * b11 + a12 * b21 + a13 * b31;
+ sum21 += a02 * b01 + a12 * b11 + a22 * b21 + a23 * b31;
+ sum31 += a03 * b01 + a13 * b11 + a23 * b21 + a33 * b31;
+ sum02 += a00 * b02 + a01 * b12 + a02 * b22 + a03 * b32;
+ sum12 += a01 * b02 + a11 * b12 + a12 * b22 + a13 * b32;
+ sum22 += a02 * b02 + a12 * b12 + a22 * b22 + a23 * b32;
+ sum32 += a03 * b02 + a13 * b12 + a23 * b22 + a33 * b32;
+ sum03 += a00 * b03 + a01 * b13 + a02 * b23 + a03 * b33;
+ sum13 += a01 * b03 + a11 * b13 + a12 * b23 + a13 * b33;
+ sum23 += a02 * b03 + a12 * b13 + a22 * b23 + a23 * b33;
+ sum33 += a03 * b03 + a13 * b13 + a23 * b23 + a33 * b33;
+ if (beta != 0.0f) {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
+ c[offsetc + (row + 3) + (col + 0) * ldc] = alpha * sum30 + beta * c[offsetc + (row + 3) + (col + 0) * ldc];
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
+ c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + beta * c[offsetc + (row + 1) + (col + 1) * ldc];
+ c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + beta * c[offsetc + (row + 2) + (col + 1) * ldc];
+ c[offsetc + (row + 3) + (col + 1) * ldc] = alpha * sum31 + beta * c[offsetc + (row + 3) + (col + 1) * ldc];
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
+ c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + beta * c[offsetc + (row + 1) + (col + 2) * ldc];
+ c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + beta * c[offsetc + (row + 2) + (col + 2) * ldc];
+ c[offsetc + (row + 3) + (col + 2) * ldc] = alpha * sum32 + beta * c[offsetc + (row + 3) + (col + 2) * ldc];
+ c[offsetc + (row + 0) + (col + 3) * ldc] = alpha * sum03 + beta * c[offsetc + (row + 0) + (col + 3) * ldc];
+ c[offsetc + (row + 1) + (col + 3) * ldc] = alpha * sum13 + beta * c[offsetc + (row + 1) + (col + 3) * ldc];
+ c[offsetc + (row + 2) + (col + 3) * ldc] = alpha * sum23 + beta * c[offsetc + (row + 2) + (col + 3) * ldc];
+ c[offsetc + (row + 3) + (col + 3) * ldc] = alpha * sum33 + beta * c[offsetc + (row + 3) + (col + 3) * ldc];
+ } else {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
+ c[offsetc + (row + 3) + (col + 0) * ldc] = alpha * sum30;
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
+ c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11;
+ c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21;
+ c[offsetc + (row + 3) + (col + 1) * ldc] = alpha * sum31;
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
+ c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12;
+ c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22;
+ c[offsetc + (row + 3) + (col + 2) * ldc] = alpha * sum32;
+ c[offsetc + (row + 0) + (col + 3) * ldc] = alpha * sum03;
+ c[offsetc + (row + 1) + (col + 3) * ldc] = alpha * sum13;
+ c[offsetc + (row + 2) + (col + 3) * ldc] = alpha * sum23;
+ c[offsetc + (row + 3) + (col + 3) * ldc] = alpha * sum33;
+ }
+ }
+ for (; row < m; row += 1) {
+ float sum0 = 0.0f;
+ float sum1 = 0.0f;
+ float sum2 = 0.0f;
+ float sum3 = 0.0f;
+ float alphab0 = alpha * b[offsetb + row + (col + 0) * ldb];
+ float alphab1 = alpha * b[offsetb + row + (col + 1) * ldb];
+ float alphab2 = alpha * b[offsetb + row + (col + 2) * ldb];
+ float alphab3 = alpha * b[offsetb + row + (col + 3) * ldb];
+ int i = 0;
+ for (; i < row; i += 1) {
+ float a0 = a[offseta + i + row * lda];
+ c[offsetc + i + (col + 0) * ldc] += alphab0 * a0;
+ c[offsetc + i + (col + 1) * ldc] += alphab1 * a0;
+ c[offsetc + i + (col + 2) * ldc] += alphab2 * a0;
+ c[offsetc + i + (col + 3) * ldc] += alphab3 * a0;
+ sum0 += b[offsetb + i + (col + 0) * ldb] * a0;
+ sum1 += b[offsetb + i + (col + 1) * ldb] * a0;
+ sum2 += b[offsetb + i + (col + 2) * ldb] * a0;
+ sum3 += b[offsetb + i + (col + 3) * ldb] * a0;
+ }
+ float a0 = a[offseta + i + row * lda];
+ sum0 += b[offsetb + i + (col + 0) * ldb] * a0;
+ sum1 += b[offsetb + i + (col + 1) * ldb] * a0;
+ sum2 += b[offsetb + i + (col + 2) * ldb] * a0;
+ sum3 += b[offsetb + i + (col + 3) * ldb] * a0;
+ if (beta != 0.0f) {
+ c[offsetc + row + (col + 0) * ldc] = alpha * sum0 + beta * c[offsetc + row + (col + 0) * ldc];
+ c[offsetc + row + (col + 1) * ldc] = alpha * sum1 + beta * c[offsetc + row + (col + 1) * ldc];
+ c[offsetc + row + (col + 2) * ldc] = alpha * sum2 + beta * c[offsetc + row + (col + 2) * ldc];
+ c[offsetc + row + (col + 3) * ldc] = alpha * sum3 + beta * c[offsetc + row + (col + 3) * ldc];
+ } else {
+ c[offsetc + row + (col + 0) * ldc] = alpha * sum0;
+ c[offsetc + row + (col + 1) * ldc] = alpha * sum1;
+ c[offsetc + row + (col + 2) * ldc] = alpha * sum2;
+ c[offsetc + row + (col + 3) * ldc] = alpha * sum3;
+ }
+ }
+ }
+ for (; col < n; col += 1) {
+ int row = 0;
+ for (; row < loopBound(m, 4); row += 4) {
+ float sum0 = 0.0f;
+ float sum1 = 0.0f;
+ float sum2 = 0.0f;
+ float sum3 = 0.0f;
+ float alphab0 = alpha * b[offsetb + (row + 0) + col * ldb];
+ float alphab1 = alpha * b[offsetb + (row + 1) + col * ldb];
+ float alphab2 = alpha * b[offsetb + (row + 2) + col * ldb];
+ float alphab3 = alpha * b[offsetb + (row + 3) + col * ldb];
+ int i = 0;
+ for (; i < row; i += 1) {
+ float a0 = a[offseta + i + (row + 0) * lda];
+ float a1 = a[offseta + i + (row + 1) * lda];
+ float a2 = a[offseta + i + (row + 2) * lda];
+ float a3 = a[offseta + i + (row + 3) * lda];
+ c[offsetc + i + col * ldc] += alphab0 * a0
+ + alphab1 * a1
+ + alphab2 * a2
+ + alphab3 * a3;
+ float b0 = b[offsetb + i + col * ldb];
+ sum0 += b0 * a0;
+ sum1 += b0 * a1;
+ sum2 += b0 * a2;
+ sum3 += b0 * a3;
+ }
+ float a00 = a[offseta + (i + 0) + (row + 0) * lda];
+ float a01 = a[offseta + (i + 0) + (row + 1) * lda];
+ float a02 = a[offseta + (i + 0) + (row + 2) * lda];
+ float a03 = a[offseta + (i + 0) + (row + 3) * lda];
+ float a11 = a[offseta + (i + 1) + (row + 1) * lda];
+ float a12 = a[offseta + (i + 1) + (row + 2) * lda];
+ float a13 = a[offseta + (i + 1) + (row + 3) * lda];
+ float a22 = a[offseta + (i + 2) + (row + 2) * lda];
+ float a23 = a[offseta + (i + 2) + (row + 3) * lda];
+ float a33 = a[offseta + (i + 3) + (row + 3) * lda];
+ float b0 = b[offsetb + (i + 0) + col * ldb];
+ float b1 = b[offsetb + (i + 1) + col * ldb];
+ float b2 = b[offsetb + (i + 2) + col * ldb];
+ float b3 = b[offsetb + (i + 3) + col * ldb];
+ sum0 += b0 * a00 + b1 * a01 + b2 * a02 + b3 * a03;
+ sum1 += b0 * a01 + b1 * a11 + b2 * a12 + b3 * a13;
+ sum2 += b0 * a02 + b1 * a12 + b2 * a22 + b3 * a23;
+ sum3 += b0 * a03 + b1 * a13 + b2 * a23 + b3 * a33;
+ if (beta != 0.0f) {
+ c[offsetc + (row + 0) + col * ldc] = alpha * sum0 + beta * c[offsetc + (row + 0) + col * ldc];
+ c[offsetc + (row + 1) + col * ldc] = alpha * sum1 + beta * c[offsetc + (row + 1) + col * ldc];
+ c[offsetc + (row + 2) + col * ldc] = alpha * sum2 + beta * c[offsetc + (row + 2) + col * ldc];
+ c[offsetc + (row + 3) + col * ldc] = alpha * sum3 + beta * c[offsetc + (row + 3) + col * ldc];
+ } else {
+ c[offsetc + (row + 0) + col * ldc] = alpha * sum0;
+ c[offsetc + (row + 1) + col * ldc] = alpha * sum1;
+ c[offsetc + (row + 2) + col * ldc] = alpha * sum2;
+ c[offsetc + (row + 3) + col * ldc] = alpha * sum3;
+ }
+ }
+ for (; row < m; row += 1) {
+ float alphab = alpha * b[offsetb + row + col * ldb];
+ float sum = 0.0f;
+ int i = 0;
+ for (; i < row; i += 1) {
+ float aval = a[offseta + i + row * lda];
+ c[offsetc + i + col * ldc] += alphab * aval;
+ sum += b[offsetb + i + col * ldb] * aval;
+ }
+ sum += b[offsetb + i + col * ldb] * a[offseta + i + row * lda];
+ if (beta != 0.0f) {
+ c[offsetc + row + col * ldc] = alpha * sum + beta * c[offsetc + row + col * ldc];
+ } else {
+ c[offsetc + row + col * ldc] = alpha * sum;
+ }
+ }
+ }
+ }
+
+ protected void ssymmLL(int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
+ final int Srow = 4;
+ // C := alpha*A*B + beta*C
+ int col = 0;
+ for (; col < loopBound(n, 4); col += 4) {
+ int row = m - 1;
+ for (; row >= loopBound(m - 1, Srow); row -= 1) {
+ float alphab0 = alpha * b[offsetb + row + (col + 0) * ldb];
+ float alphab1 = alpha * b[offsetb + row + (col + 1) * ldb];
+ float alphab2 = alpha * b[offsetb + row + (col + 2) * ldb];
+ float alphab3 = alpha * b[offsetb + row + (col + 3) * ldb];
+ float sum0 = 0.0f;
+ float sum1 = 0.0f;
+ float sum2 = 0.0f;
+ float sum3 = 0.0f;
+ sum0 += b[offsetb + row + (col + 0) * ldb] * a[offseta + row + row * lda];
+ sum1 += b[offsetb + row + (col + 1) * ldb] * a[offseta + row + row * lda];
+ sum2 += b[offsetb + row + (col + 2) * ldb] * a[offseta + row + row * lda];
+ sum3 += b[offsetb + row + (col + 3) * ldb] * a[offseta + row + row * lda];
+ int i = row + 1;
+ for (; i < m; i += 1) {
+ float airow = a[offseta + i + row * lda];
+ c[offsetc + i + (col + 0) * ldc] += alphab0 * airow;
+ c[offsetc + i + (col + 1) * ldc] += alphab1 * airow;
+ c[offsetc + i + (col + 2) * ldc] += alphab2 * airow;
+ c[offsetc + i + (col + 3) * ldc] += alphab3 * airow;
+ sum0 += b[offsetb + i + (col + 0) * ldb] * airow;
+ sum1 += b[offsetb + i + (col + 1) * ldb] * airow;
+ sum2 += b[offsetb + i + (col + 2) * ldb] * airow;
+ sum3 += b[offsetb + i + (col + 3) * ldb] * airow;
+ }
+ if (beta != 0.0f) {
+ c[offsetc + row + (col + 0) * ldc] = alpha * sum0 + beta * c[offsetc + row + (col + 0) * ldc];
+ c[offsetc + row + (col + 1) * ldc] = alpha * sum1 + beta * c[offsetc + row + (col + 1) * ldc];
+ c[offsetc + row + (col + 2) * ldc] = alpha * sum2 + beta * c[offsetc + row + (col + 2) * ldc];
+ c[offsetc + row + (col + 3) * ldc] = alpha * sum3 + beta * c[offsetc + row + (col + 3) * ldc];
+ } else {
+ c[offsetc + row + (col + 0) * ldc] = alpha * sum0;
+ c[offsetc + row + (col + 1) * ldc] = alpha * sum1;
+ c[offsetc + row + (col + 2) * ldc] = alpha * sum2;
+ c[offsetc + row + (col + 3) * ldc] = alpha * sum3;
+ }
+ }
+ for (row -= Srow - 1; row >= 0; row -= Srow) {
+ float a00 = a[offseta + (row + 0) + (row + 0) * lda];
+ float a10 = a[offseta + (row + 1) + (row + 0) * lda];
+ float a11 = a[offseta + (row + 1) + (row + 1) * lda];
+ float a20 = a[offseta + (row + 2) + (row + 0) * lda];
+ float a21 = a[offseta + (row + 2) + (row + 1) * lda];
+ float a22 = a[offseta + (row + 2) + (row + 2) * lda];
+ float a30 = a[offseta + (row + 3) + (row + 0) * lda];
+ float a31 = a[offseta + (row + 3) + (row + 1) * lda];
+ float a32 = a[offseta + (row + 3) + (row + 2) * lda];
+ float a33 = a[offseta + (row + 3) + (row + 3) * lda];
+ float b00 = b[offsetb + (row + 0) + (col + 0) * ldb];
+ float b10 = b[offsetb + (row + 1) + (col + 0) * ldb];
+ float b20 = b[offsetb + (row + 2) + (col + 0) * ldb];
+ float b30 = b[offsetb + (row + 3) + (col + 0) * ldb];
+ float b01 = b[offsetb + (row + 0) + (col + 1) * ldb];
+ float b11 = b[offsetb + (row + 1) + (col + 1) * ldb];
+ float b21 = b[offsetb + (row + 2) + (col + 1) * ldb];
+ float b31 = b[offsetb + (row + 3) + (col + 1) * ldb];
+ float b02 = b[offsetb + (row + 0) + (col + 2) * ldb];
+ float b12 = b[offsetb + (row + 1) + (col + 2) * ldb];
+ float b22 = b[offsetb + (row + 2) + (col + 2) * ldb];
+ float b32 = b[offsetb + (row + 3) + (col + 2) * ldb];
+ float b03 = b[offsetb + (row + 0) + (col + 3) * ldb];
+ float b13 = b[offsetb + (row + 1) + (col + 3) * ldb];
+ float b23 = b[offsetb + (row + 2) + (col + 3) * ldb];
+ float b33 = b[offsetb + (row + 3) + (col + 3) * ldb];
+ float alphab00 = alpha * b00;
+ float alphab10 = alpha * b10;
+ float alphab20 = alpha * b20;
+ float alphab30 = alpha * b30;
+ float alphab01 = alpha * b01;
+ float alphab11 = alpha * b11;
+ float alphab21 = alpha * b21;
+ float alphab31 = alpha * b31;
+ float alphab02 = alpha * b02;
+ float alphab12 = alpha * b12;
+ float alphab22 = alpha * b22;
+ float alphab32 = alpha * b32;
+ float alphab03 = alpha * b03;
+ float alphab13 = alpha * b13;
+ float alphab23 = alpha * b23;
+ float alphab33 = alpha * b33;
+ float sum00 = 0.0f;
+ float sum10 = 0.0f;
+ float sum20 = 0.0f;
+ float sum30 = 0.0f;
+ float sum01 = 0.0f;
+ float sum11 = 0.0f;
+ float sum21 = 0.0f;
+ float sum31 = 0.0f;
+ float sum02 = 0.0f;
+ float sum12 = 0.0f;
+ float sum22 = 0.0f;
+ float sum32 = 0.0f;
+ float sum03 = 0.0f;
+ float sum13 = 0.0f;
+ float sum23 = 0.0f;
+ float sum33 = 0.0f;
+ sum00 += b00 * a00 + b10 * a10 + b20 * a20 + b30 * a30;
+ sum10 += b00 * a10 + b10 * a11 + b20 * a21 + b30 * a31;
+ sum20 += b00 * a20 + b10 * a21 + b20 * a22 + b30 * a32;
+ sum30 += b00 * a30 + b10 * a31 + b20 * a32 + b30 * a33;
+ sum01 += b01 * a00 + b11 * a10 + b21 * a20 + b31 * a30;
+ sum11 += b01 * a10 + b11 * a11 + b21 * a21 + b31 * a31;
+ sum21 += b01 * a20 + b11 * a21 + b21 * a22 + b31 * a32;
+ sum31 += b01 * a30 + b11 * a31 + b21 * a32 + b31 * a33;
+ sum02 += b02 * a00 + b12 * a10 + b22 * a20 + b32 * a30;
+ sum12 += b02 * a10 + b12 * a11 + b22 * a21 + b32 * a31;
+ sum22 += b02 * a20 + b12 * a21 + b22 * a22 + b32 * a32;
+ sum32 += b02 * a30 + b12 * a31 + b22 * a32 + b32 * a33;
+ sum03 += b03 * a00 + b13 * a10 + b23 * a20 + b33 * a30;
+ sum13 += b03 * a10 + b13 * a11 + b23 * a21 + b33 * a31;
+ sum23 += b03 * a20 + b13 * a21 + b23 * a22 + b33 * a32;
+ sum33 += b03 * a30 + b13 * a31 + b23 * a32 + b33 * a33;
+ int i = row + 4;
+ for (; i < m; i += 1) {
+ float a0 = a[offseta + i + (row + 0) * lda];
+ float a1 = a[offseta + i + (row + 1) * lda];
+ float a2 = a[offseta + i + (row + 2) * lda];
+ float a3 = a[offseta + i + (row + 3) * lda];
+ c[offsetc + i + (col + 0) * ldc] += alphab00 * a0
+ + alphab10 * a1
+ + alphab20 * a2
+ + alphab30 * a3;
+ c[offsetc + i + (col + 1) * ldc] += alphab01 * a0
+ + alphab11 * a1
+ + alphab21 * a2
+ + alphab31 * a3;
+ c[offsetc + i + (col + 2) * ldc] += alphab02 * a0
+ + alphab12 * a1
+ + alphab22 * a2
+ + alphab32 * a3;
+ c[offsetc + i + (col + 3) * ldc] += alphab03 * a0
+ + alphab13 * a1
+ + alphab23 * a2
+ + alphab33 * a3;
+ float b0 = b[offsetb + i + (col + 0) * ldb];
+ float b1 = b[offsetb + i + (col + 1) * ldb];
+ float b2 = b[offsetb + i + (col + 2) * ldb];
+ float b3 = b[offsetb + i + (col + 3) * ldb];
+ sum00 += b0 * a0;
+ sum10 += b0 * a1;
+ sum20 += b0 * a2;
+ sum30 += b0 * a3;
+ sum01 += b1 * a0;
+ sum11 += b1 * a1;
+ sum21 += b1 * a2;
+ sum31 += b1 * a3;
+ sum02 += b2 * a0;
+ sum12 += b2 * a1;
+ sum22 += b2 * a2;
+ sum32 += b2 * a3;
+ sum03 += b3 * a0;
+ sum13 += b3 * a1;
+ sum23 += b3 * a2;
+ sum33 += b3 * a3;
+ }
+ if (beta != 0.0f) {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
+ c[offsetc + (row + 3) + (col + 0) * ldc] = alpha * sum30 + beta * c[offsetc + (row + 3) + (col + 0) * ldc];
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
+ c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + beta * c[offsetc + (row + 1) + (col + 1) * ldc];
+ c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + beta * c[offsetc + (row + 2) + (col + 1) * ldc];
+ c[offsetc + (row + 3) + (col + 1) * ldc] = alpha * sum31 + beta * c[offsetc + (row + 3) + (col + 1) * ldc];
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
+ c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + beta * c[offsetc + (row + 1) + (col + 2) * ldc];
+ c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + beta * c[offsetc + (row + 2) + (col + 2) * ldc];
+ c[offsetc + (row + 3) + (col + 2) * ldc] = alpha * sum32 + beta * c[offsetc + (row + 3) + (col + 2) * ldc];
+ c[offsetc + (row + 0) + (col + 3) * ldc] = alpha * sum03 + beta * c[offsetc + (row + 0) + (col + 3) * ldc];
+ c[offsetc + (row + 1) + (col + 3) * ldc] = alpha * sum13 + beta * c[offsetc + (row + 1) + (col + 3) * ldc];
+ c[offsetc + (row + 2) + (col + 3) * ldc] = alpha * sum23 + beta * c[offsetc + (row + 2) + (col + 3) * ldc];
+ c[offsetc + (row + 3) + (col + 3) * ldc] = alpha * sum33 + beta * c[offsetc + (row + 3) + (col + 3) * ldc];
+ } else {
+ c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
+ c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
+ c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
+ c[offsetc + (row + 3) + (col + 0) * ldc] = alpha * sum30;
+ c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
+ c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11;
+ c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21;
+ c[offsetc + (row + 3) + (col + 1) * ldc] = alpha * sum31;
+ c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
+ c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12;
+ c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22;
+ c[offsetc + (row + 3) + (col + 2) * ldc] = alpha * sum32;
+ c[offsetc + (row + 0) + (col + 3) * ldc] = alpha * sum03;
+ c[offsetc + (row + 1) + (col + 3) * ldc] = alpha * sum13;
+ c[offsetc + (row + 2) + (col + 3) * ldc] = alpha * sum23;
+ c[offsetc + (row + 3) + (col + 3) * ldc] = alpha * sum33;
+ }
+ }
+ }
+ for (; col < n; col += 1) {
+ int row = m - 1;
+ for (; row >= loopBound(m - 1, Srow); row -= 1) {
+ float alphab0 = alpha * b[offsetb + row + col * ldb];
+ float sum0 = 0.0f;
+ sum0 += b[offsetb + row + col * ldb] * a[offseta + row + row * lda];
+ int i = row + 1;
+ for (; i < m; i += 1) {
+ float a0 = a[offseta + i + row * lda];
+ c[offsetc + i + col * ldc] += alphab0 * a0;
+ sum0 += b[offsetb + i + col * ldb] * a0;
+ }
+ if (beta != 0.0f) {
+ c[offsetc + row + col * ldc] = alpha * sum0 + beta * c[offsetc + row + col * ldc];
+ } else {
+ c[offsetc + row + col * ldc] = alpha * sum0;
+ }
+ }
+ for (row -= Srow - 1; row >= 0; row -= Srow) {
+ float alphab0 = alpha * b[offsetb + (row + 0) + col * ldb];
+ float alphab1 = alpha * b[offsetb + (row + 1) + col * ldb];
+ float alphab2 = alpha * b[offsetb + (row + 2) + col * ldb];
+ float alphab3 = alpha * b[offsetb + (row + 3) + col * ldb];
+ float a00 = a[offseta + (row + 0) + (row + 0) * lda];
+ float a10 = a[offseta + (row + 1) + (row + 0) * lda];
+ float a11 = a[offseta + (row + 1) + (row + 1) * lda];
+ float a20 = a[offseta + (row + 2) + (row + 0) * lda];
+ float a21 = a[offseta + (row + 2) + (row + 1) * lda];
+ float a22 = a[offseta + (row + 2) + (row + 2) * lda];
+ float a30 = a[offseta + (row + 3) + (row + 0) * lda];
+ float a31 = a[offseta + (row + 3) + (row + 1) * lda];
+ float a32 = a[offseta + (row + 3) + (row + 2) * lda];
+ float a33 = a[offseta + (row + 3) + (row + 3) * lda];
+ float b0 = b[offsetb + (row + 0) + col * ldb];
+ float b1 = b[offsetb + (row + 1) + col * ldb];
+ float b2 = b[offsetb + (row + 2) + col * ldb];
+ float b3 = b[offsetb + (row + 3) + col * ldb];
+ float sum0 = 0.0f;
+ float sum1 = 0.0f;
+ float sum2 = 0.0f;
+ float sum3 = 0.0f;
+ sum0 += b0 * a00 + b1 * a10 + b2 * a20 + b3 * a30;
+ sum1 += b0 * a10 + b1 * a11 + b2 * a21 + b3 * a31;
+ sum2 += b0 * a20 + b1 * a21 + b2 * a22 + b3 * a32;
+ sum3 += b0 * a30 + b1 * a31 + b2 * a32 + b3 * a33;
+ int i = row + 4;
+ for (; i < m; i += 1) {
+ float a0 = a[offseta + i + (row + 0) * lda];
+ float a1 = a[offseta + i + (row + 1) * lda];
+ float a2 = a[offseta + i + (row + 2) * lda];
+ float a3 = a[offseta + i + (row + 3) * lda];
+ c[offsetc + i + col * ldc] += alphab0 * a0
+ + alphab1 * a1
+ + alphab2 * a2
+ + alphab3 * a3;
+ float bicol = b[offsetb + i + col * ldb];
+ sum0 += bicol * a0;
+ sum1 += bicol * a1;
+ sum2 += bicol * a2;
+ sum3 += bicol * a3;
+ }
+ if (beta != 0.0f) {
+ c[offsetc + (row + 0) + col * ldc] = alpha * sum0 + beta * c[offsetc + (row + 0) + col * ldc];
+ c[offsetc + (row + 1) + col * ldc] = alpha * sum1 + beta * c[offsetc + (row + 1) + col * ldc];
+ c[offsetc + (row + 2) + col * ldc] = alpha * sum2 + beta * c[offsetc + (row + 2) + col * ldc];
+ c[offsetc + (row + 3) + col * ldc] = alpha * sum3 + beta * c[offsetc + (row + 3) + col * ldc];
+ } else {
+ c[offsetc + (row + 0) + col * ldc] = alpha * sum0;
+ c[offsetc + (row + 1) + col * ldc] = alpha * sum1;
+ c[offsetc + (row + 2) + col * ldc] = alpha * sum2;
+ c[offsetc + (row + 3) + col * ldc] = alpha * sum3;
+ }
+ }
+ }
+ }
+
+ protected void ssymmRU(int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
+ // C := alpha*B*A + beta*C
+ org.netlib.blas.Ssymm.ssymm("R", "U", m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
+ }
+
+ protected void ssymmRL(int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
+ // C := alpha*B*A + beta*C
+ org.netlib.blas.Ssymm.ssymm("R", "L", m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
+ }
+
+ protected void dsymvK(String uplo, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
+ if (alpha == 0.0) {
+ for (int i = 0, iy = incy < 0 ? (n - 1) * -incy : 0; i < n; i += 1, iy += incy) {
+ if (beta != 0.0) {
+ y[offsety + iy] = beta * y[offsety + iy];
+ } else {
+ y[offsety + iy] = 0.0;
+ }
+ }
+ } else if (lsame("U", uplo)) {
+ dsymvU(n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
+ } else if (lsame("L", uplo)) {
+ dsymvL(n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
+ }
+ }
+
+ protected void dsymvU(int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
+ int col = 0, ix = incx < 0 ? (n - 1) * -incx : 0, iy = incy < 0 ? (n - 1) * -incy : 0;
+ for (; col < loopBound(n, 4); col += 4, ix += incx * 4, iy += incy * 4) {
+ double alphaxix0 = alpha * x[offsetx + ix + incx * 0];
+ double alphaxix1 = alpha * x[offsetx + ix + incx * 1];
+ double alphaxix2 = alpha * x[offsetx + ix + incx * 2];
+ double alphaxix3 = alpha * x[offsetx + ix + incx * 3];
+ double sumiy0 = 0.0;
+ double sumiy1 = 0.0;
+ double sumiy2 = 0.0;
+ double sumiy3 = 0.0;
+ int row = 0, jx = incx < 0 ? (col - 1) * -incx : 0, jy = incy < 0 ? (col - 1) * -incy : 0;
+ for (; row < col; row += 1, jx += incx, jy += incy) {
+ double a0 = a[offseta + row + (col + 0) * lda];
+ double a1 = a[offseta + row + (col + 1) * lda];
+ double a2 = a[offseta + row + (col + 2) * lda];
+ double a3 = a[offseta + row + (col + 3) * lda];
+ y[offsety + jy] += alphaxix0 * a0 + alphaxix1 * a1 + alphaxix2 * a2 + alphaxix3 * a3;
+ double x0 = x[offsetx + jx];
+ sumiy0 += x0 * a0;
+ sumiy1 += x0 * a1;
+ sumiy2 += x0 * a2;
+ sumiy3 += x0 * a3;
+ }
+ double a00 = a[offseta + (row + 0) + (col + 0) * lda];
+ double a01 = a[offseta + (row + 0) + (col + 1) * lda];
+ double a02 = a[offseta + (row + 0) + (col + 2) * lda];
+ double a03 = a[offseta + (row + 0) + (col + 3) * lda];
+ double a11 = a[offseta + (row + 1) + (col + 1) * lda];
+ double a12 = a[offseta + (row + 1) + (col + 2) * lda];
+ double a13 = a[offseta + (row + 1) + (col + 3) * lda];
+ double a22 = a[offseta + (row + 2) + (col + 2) * lda];
+ double a23 = a[offseta + (row + 2) + (col + 3) * lda];
+ double a33 = a[offseta + (row + 3) + (col + 3) * lda];
+ double xjx0 = x[offsetx + jx + incx * 0];
+ double xjx1 = x[offsetx + jx + incx * 1];
+ double xjx2 = x[offsetx + jx + incx * 2];
+ double xjx3 = x[offsetx + jx + incx * 3];
+ sumiy0 += xjx0 * a00
+ + xjx1 * a01
+ + xjx2 * a02
+ + xjx3 * a03;
+ sumiy1 += xjx0 * a01
+ + xjx1 * a11
+ + xjx2 * a12
+ + xjx3 * a13;
+ sumiy2 += xjx0 * a02
+ + xjx1 * a12
+ + xjx2 * a22
+ + xjx3 * a23;
+ sumiy3 += xjx0 * a03
+ + xjx1 * a13
+ + xjx2 * a23
+ + xjx3 * a33;
+ if (beta != 0.0) {
+ y[offsety + iy + incy * 0] = alpha * sumiy0 + beta * y[offsety + iy + incy * 0];
+ y[offsety + iy + incy * 1] = alpha * sumiy1 + beta * y[offsety + iy + incy * 1];
+ y[offsety + iy + incy * 2] = alpha * sumiy2 + beta * y[offsety + iy + incy * 2];
+ y[offsety + iy + incy * 3] = alpha * sumiy3 + beta * y[offsety + iy + incy * 3];
+ } else {
+ y[offsety + iy + incy * 0] = alpha * sumiy0;
+ y[offsety + iy + incy * 1] = alpha * sumiy1;
+ y[offsety + iy + incy * 2] = alpha * sumiy2;
+ y[offsety + iy + incy * 3] = alpha * sumiy3;
+ }
+ }
+ for (; col < n; col += 1, ix += incx, iy += incy) {
+ double alphaxix = alpha * x[offsetx + ix];
+ double sumiy = 0.0;
+ int row = 0, jx = incx < 0 ? (col - 1) * -incx : 0, jy = incy < 0 ? (col - 1) * -incy : 0;
+ for (; row < col; row += 1, jx += incx, jy += incy) {
+ double a0 = a[offseta + row + col * lda];
+ y[offsety + jy] += alphaxix * a0;
+ sumiy += x[offsetx + jx] * a0;
+ }
+ sumiy += x[offsetx + jx] * a[offseta + row + col * lda];
+ if (beta != 0.0) {
+ y[offsety + iy] = alpha * sumiy + beta * y[offsety + iy];
+ } else {
+ y[offsety + iy] = alpha * sumiy;
+ }
+ }
+ }
+
+ protected void dsymvL(int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
+ // y = beta * y
+ if (beta != 1.0) {
+ for (int i = 0, iy = incy < 0 ? (n - 1) * -incy : 0; i < n; i += 1, iy += incy) {
+ if (beta != 0.0) {
+ y[offsety + iy] = beta * y[offsety + iy];
+ } else {
+ y[offsety + iy] = 0.0;
+ }
+ }
+ }
+ // y += alpha * A * x
+ int col = 0, ix = incx < 0 ? (n - 1) * -incx : 0, iy = incy < 0 ? (n - 1) * -incy : 0;
+ for (; col < loopBound(n, 4); col += 4, ix += incx * 4, iy += incy * 4) {
+ double alphaxix0 = alpha * x[offsetx + ix + incx * 0];
+ double alphaxix1 = alpha * x[offsetx + ix + incx * 1];
+ double alphaxix2 = alpha * x[offsetx + ix + incx * 2];
+ double alphaxix3 = alpha * x[offsetx + ix + incx * 3];
+ double sumiy0 = 0.0;
+ double sumiy1 = 0.0;
+ double sumiy2 = 0.0;
+ double sumiy3 = 0.0;
+ double a00 = a[offseta + /*row=*/(col + 0) + (col + 0) * lda];
+ double a10 = a[offseta + /*row=*/(col + 1) + (col + 0) * lda];
+ double a11 = a[offseta + /*row=*/(col + 1) + (col + 1) * lda];
+ double a20 = a[offseta + /*row=*/(col + 2) + (col + 0) * lda];
+ double a21 = a[offseta + /*row=*/(col + 2) + (col + 1) * lda];
+ double a22 = a[offseta + /*row=*/(col + 2) + (col + 2) * lda];
+ double a30 = a[offseta + /*row=*/(col + 3) + (col + 0) * lda];
+ double a31 = a[offseta + /*row=*/(col + 3) + (col + 1) * lda];
+ double a32 = a[offseta + /*row=*/(col + 3) + (col + 2) * lda];
+ double a33 = a[offseta + /*row=*/(col + 3) + (col + 3) * lda];
+ double x0 = x[offsetx + (incx < 0 ? (n - (col + 0) - 1) * -incx : (col + 0) * incx)];
+ double x1 = x[offsetx + (incx < 0 ? (n - (col + 1) - 1) * -incx : (col + 1) * incx)];
+ double x2 = x[offsetx + (incx < 0 ? (n - (col + 2) - 1) * -incx : (col + 2) * incx)];
+ double x3 = x[offsetx + (incx < 0 ? (n - (col + 3) - 1) * -incx : (col + 3) * incx)];
+ sumiy0 += x0 * a00
+ + x1 * a10
+ + x2 * a20
+ + x3 * a30;
+ sumiy1 += x0 * a10
+ + x1 * a11
+ + x2 * a21
+ + x3 * a31;
+ sumiy2 += x0 * a20
+ + x1 * a21
+ + x2 * a22
+ + x3 * a32;
+ sumiy3 += x0 * a30
+ + x1 * a31
+ + x2 * a32
+ + x3 * a33;
+ int row = col + 4, jx = incx < 0 ? (n - (col + 4) - 1) * -incx : (col + 4) * incx, jy = incy < 0 ? (n - (col + 4) - 1) * -incy : (col + 4) * incy;
+ for (; row < n; row += 1, jx += incx, jy += incy) {
+ double a0 = a[offseta + row + (col + 0) * lda];
+ double a1 = a[offseta + row + (col + 1) * lda];
+ double a2 = a[offseta + row + (col + 2) * lda];
+ double a3 = a[offseta + row + (col + 3) * lda];
+ y[offsety + jy] += alphaxix0 * a0
+ + alphaxix1 * a1
+ + alphaxix2 * a2
+ + alphaxix3 * a3;
+ double xjx = x[offsetx + jx];
+ sumiy0 += xjx * a0;
+ sumiy1 += xjx * a1;
+ sumiy2 += xjx * a2;
+ sumiy3 += xjx * a3;
+ }
+ y[offsety + iy + incy * 0] += alpha * sumiy0;
+ y[offsety + iy + incy * 1] += alpha * sumiy1;
+ y[offsety + iy + incy * 2] += alpha * sumiy2;
+ y[offsety + iy + incy * 3] += alpha * sumiy3;
+ }
+ for (; col < n; col += 1, ix += incx, iy += incy) {
+ double alphaxix = alpha * x[offsetx + ix];
+ double sumiy = 0.0;
+ sumiy += x[offsetx + (incx < 0 ? (n - col - 1) * -incx : col * incx)] * a[offseta + /*row=*/col + col * lda];
+ int row = col + 1, jx = incx < 0 ? (n - (col + 1) - 1) * -incx : (col + 1) * incx, jy = incy < 0 ? (n - (col + 1) - 1) * -incy : (col + 1) * incy;
+ for (; row < n; row += 1, jx += incx, jy += incy) {
+ y[offsety + jy] += alphaxix * a[offseta + row + col * lda];
+ sumiy += x[offsetx + jx] * a[offseta + row + col * lda];
+ }
+ y[offsety + iy] += alpha * sumiy;
+ }
+ }
+
+ protected void ssymvK(String uplo, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
+ if (alpha == 0.0f) {
+ for (int i = 0, iy = incy < 0 ? (n - 1) * -incy : 0; i < n; i += 1, iy += incy) {
+ if (beta != 0.0f) {
+ y[offsety + iy] = beta * y[offsety + iy];
+ } else {
+ y[offsety + iy] = 0.0f;
+ }
+ }
+ } else if (lsame("U", uplo)) {
+ ssymvU(n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
+ } else if (lsame("L", uplo)) {
+ ssymvL(n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
+ }
+ }
+
+ protected void ssymvU(int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
+ int col = 0, ix = incx < 0 ? (n - 1) * -incx : 0, iy = incy < 0 ? (n - 1) * -incy : 0;
+ for (; col < loopBound(n, 4); col += 4, ix += incx * 4, iy += incy * 4) {
+ float alphaxix0 = alpha * x[offsetx + ix + incx * 0];
+ float alphaxix1 = alpha * x[offsetx + ix + incx * 1];
+ float alphaxix2 = alpha * x[offsetx + ix + incx * 2];
+ float alphaxix3 = alpha * x[offsetx + ix + incx * 3];
+ float sumiy0 = 0.0f;
+ float sumiy1 = 0.0f;
+ float sumiy2 = 0.0f;
+ float sumiy3 = 0.0f;
+ int row = 0, jx = incx < 0 ? (col - 1) * -incx : 0, jy = incy < 0 ? (col - 1) * -incy : 0;
+ for (; row < col; row += 1, jx += incx, jy += incy) {
+ y[offsety + jy] += alphaxix0 * a[offseta + row + (col + 0) * lda]
+ + alphaxix1 * a[offseta + row + (col + 1) * lda]
+ + alphaxix2 * a[offseta + row + (col + 2) * lda]
+ + alphaxix3 * a[offseta + row + (col + 3) * lda];
+ float xjx = x[offsetx + jx];
+ sumiy0 += xjx * a[offseta + row + (col + 0) * lda];
+ sumiy1 += xjx * a[offseta + row + (col + 1) * lda];
+ sumiy2 += xjx * a[offseta + row + (col + 2) * lda];
+ sumiy3 += xjx * a[offseta + row + (col + 3) * lda];
+ }
+ float a00 = a[offseta + (row + 0) + (col + 0) * lda];
+ float a01 = a[offseta + (row + 0) + (col + 1) * lda];
+ float a02 = a[offseta + (row + 0) + (col + 2) * lda];
+ float a03 = a[offseta + (row + 0) + (col + 3) * lda];
+ float a11 = a[offseta + (row + 1) + (col + 1) * lda];
+ float a12 = a[offseta + (row + 1) + (col + 2) * lda];
+ float a13 = a[offseta + (row + 1) + (col + 3) * lda];
+ float a22 = a[offseta + (row + 2) + (col + 2) * lda];
+ float a23 = a[offseta + (row + 2) + (col + 3) * lda];
+ float a33 = a[offseta + (row + 3) + (col + 3) * lda];
+ float xjx0 = x[offsetx + jx + incx * 0];
+ float xjx1 = x[offsetx + jx + incx * 1];
+ float xjx2 = x[offsetx + jx + incx * 2];
+ float xjx3 = x[offsetx + jx + incx * 3];
+ sumiy0 += xjx0 * a00
+ + xjx1 * a01
+ + xjx2 * a02
+ + xjx3 * a03;
+ sumiy1 += xjx0 * a01
+ + xjx1 * a11
+ + xjx2 * a12
+ + xjx3 * a13;
+ sumiy2 += xjx0 * a02
+ + xjx1 * a12
+ + xjx2 * a22
+ + xjx3 * a23;
+ sumiy3 += xjx0 * a03
+ + xjx1 * a13
+ + xjx2 * a23
+ + xjx3 * a33;
+ if (beta != 0.0f) {
+ y[offsety + iy + incy * 0] = alpha * sumiy0 + beta * y[offsety + iy + incy * 0];
+ y[offsety + iy + incy * 1] = alpha * sumiy1 + beta * y[offsety + iy + incy * 1];
+ y[offsety + iy + incy * 2] = alpha * sumiy2 + beta * y[offsety + iy + incy * 2];
+ y[offsety + iy + incy * 3] = alpha * sumiy3 + beta * y[offsety + iy + incy * 3];
+ } else {
+ y[offsety + iy + incy * 0] = alpha * sumiy0;
+ y[offsety + iy + incy * 1] = alpha * sumiy1;
+ y[offsety + iy + incy * 2] = alpha * sumiy2;
+ y[offsety + iy + incy * 3] = alpha * sumiy3;
+ }
+ }
+ for (; col < n; col += 1, ix += incx, iy += incy) {
+ float alphaxix = alpha * x[offsetx + ix];
+ float sumiy = 0.0f;
+ int row = 0, jx = incx < 0 ? (col - 1) * -incx : 0, jy = incy < 0 ? (col - 1) * -incy : 0;
+ for (; row < col; row += 1, jx += incx, jy += incy) {
+ y[offsety + jy] += alphaxix * a[offseta + row + col * lda];
+ sumiy += x[offsetx + jx] * a[offseta + row + col * lda];
+ }
+ sumiy += x[offsetx + jx] * a[offseta + row + col * lda];
+ if (beta != 0.0f) {
+ y[offsety + iy] = alpha * sumiy + beta * y[offsety + iy];
+ } else {
+ y[offsety + iy] = alpha * sumiy;
+ }
+ }
+ }
+
+ protected void ssymvL(int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
+ // y = beta * y
+ if (beta != 1.0f) {
+ for (int i = 0, iy = incy < 0 ? (n - 1) * -incy : 0; i < n; i += 1, iy += incy) {
+ if (beta != 0.0f) {
+ y[offsety + iy] = beta * y[offsety + iy];
+ } else {
+ y[offsety + iy] = 0.0f;
+ }
+ }
+ }
+ // y += alpha * A * x
+ int col = 0, ix = incx < 0 ? (n - 1) * -incx : 0, iy = incy < 0 ? (n - 1) * -incy : 0;
+ for (; col < loopBound(n, 4); col += 4, ix += incx * 4, iy += incy * 4) {
+ float alphaxix0 = alpha * x[offsetx + ix + incx * 0];
+ float alphaxix1 = alpha * x[offsetx + ix + incx * 1];
+ float alphaxix2 = alpha * x[offsetx + ix + incx * 2];
+ float alphaxix3 = alpha * x[offsetx + ix + incx * 3];
+ float sumiy0 = 0.0f;
+ float sumiy1 = 0.0f;
+ float sumiy2 = 0.0f;
+ float sumiy3 = 0.0f;
+ float a00 = a[offseta + /*row=*/(col + 0) + (col + 0) * lda];
+ float a10 = a[offseta + /*row=*/(col + 1) + (col + 0) * lda];
+ float a11 = a[offseta + /*row=*/(col + 1) + (col + 1) * lda];
+ float a20 = a[offseta + /*row=*/(col + 2) + (col + 0) * lda];
+ float a21 = a[offseta + /*row=*/(col + 2) + (col + 1) * lda];
+ float a22 = a[offseta + /*row=*/(col + 2) + (col + 2) * lda];
+ float a30 = a[offseta + /*row=*/(col + 3) + (col + 0) * lda];
+ float a31 = a[offseta + /*row=*/(col + 3) + (col + 1) * lda];
+ float a32 = a[offseta + /*row=*/(col + 3) + (col + 2) * lda];
+ float a33 = a[offseta + /*row=*/(col + 3) + (col + 3) * lda];
+ float x0 = x[offsetx + (incx < 0 ? (n - (col + 0) - 1) * -incx : (col + 0) * incx)];
+ float x1 = x[offsetx + (incx < 0 ? (n - (col + 1) - 1) * -incx : (col + 1) * incx)];
+ float x2 = x[offsetx + (incx < 0 ? (n - (col + 2) - 1) * -incx : (col + 2) * incx)];
+ float x3 = x[offsetx + (incx < 0 ? (n - (col + 3) - 1) * -incx : (col + 3) * incx)];
+ sumiy0 += x0 * a00
+ + x1 * a10
+ + x2 * a20
+ + x3 * a30;
+ sumiy1 += x0 * a10
+ + x1 * a11
+ + x2 * a21
+ + x3 * a31;
+ sumiy2 += x0 * a20
+ + x1 * a21
+ + x2 * a22
+ + x3 * a32;
+ sumiy3 += x0 * a30
+ + x1 * a31
+ + x2 * a32
+ + x3 * a33;
+ int row = col + 4, jx = incx < 0 ? (n - (col + 4) - 1) * -incx : (col + 4) * incx, jy = incy < 0 ? (n - (col + 4) - 1) * -incy : (col + 4) * incy;
+ for (; row < n; row += 1, jx += incx, jy += incy) {
+ float a0 = a[offseta + row + (col + 0) * lda];
+ float a1 = a[offseta + row + (col + 1) * lda];
+ float a2 = a[offseta + row + (col + 2) * lda];
+ float a3 = a[offseta + row + (col + 3) * lda];
+ y[offsety + jy] += alphaxix0 * a0
+ + alphaxix1 * a1
+ + alphaxix2 * a2
+ + alphaxix3 * a3;
+ float xjx = x[offsetx + jx];
+ sumiy0 += xjx * a0;
+ sumiy1 += xjx * a1;
+ sumiy2 += xjx * a2;
+ sumiy3 += xjx * a3;
+ }
+ y[offsety + iy + incy * 0] += alpha * sumiy0;
+ y[offsety + iy + incy * 1] += alpha * sumiy1;
+ y[offsety + iy + incy * 2] += alpha * sumiy2;
+ y[offsety + iy + incy * 3] += alpha * sumiy3;
+ }
+ for (; col < n; col += 1, ix += incx, iy += incy) {
+ float alphaxix = alpha * x[offsetx + ix];
+ float sumiy = 0.0f;
+ sumiy += x[offsetx + (incx < 0 ? (n - col - 1) * -incx : col * incx)] * a[offseta + /*row=*/col + col * lda];
+ int row = col + 1, jx = incx < 0 ? (n - (col + 1) - 1) * -incx : (col + 1) * incx, jy = incy < 0 ? (n - (col + 1) - 1) * -incy : (col + 1) * incy;
+ for (; row < n; row += 1, jx += incx, jy += incy) {
+ y[offsety + jy] += alphaxix * a[offseta + row + col * lda];
+ sumiy += x[offsetx + jx] * a[offseta + row + col * lda];
+ }
+ y[offsety + iy] += alpha * sumiy;
+ }
+ }
+
+ protected void dsyrK(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] a, int offseta, int lda) {
+ org.netlib.blas.Dsyr.dsyr(uplo, n, alpha, x, offsetx, incx, a, offseta, lda);
+ }
+
+ protected void ssyrK(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] a, int offseta, int lda) {
+ org.netlib.blas.Ssyr.ssyr(uplo, n, alpha, x, offsetx, incx, a, offseta, lda);
+ }
+
+ protected void dsyr2K(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta, int lda) {
+ org.netlib.blas.Dsyr2.dsyr2(uplo, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta, lda);
+ }
+
+ protected void ssyr2K(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta, int lda) {
+ org.netlib.blas.Ssyr2.ssyr2(uplo, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta, lda);
+ }
+
+ protected void dsyr2kK(String uplo, String trans, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
+ org.netlib.blas.Dsyr2k.dsyr2k(uplo, trans, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
+ }
+
+ protected void ssyr2kK(String uplo, String trans, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
+ org.netlib.blas.Ssyr2k.ssyr2k(uplo, trans, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
+ }
+
+ protected void dsyrkK(String uplo, String trans, int n, int k, double alpha, double[] a, int offseta, int lda, double beta, double[] c, int offsetc, int ldc) {
+ org.netlib.blas.Dsyrk.dsyrk(uplo, trans, n, k, alpha, a, offseta, lda, beta, c, offsetc, ldc);
+ }
+
+ protected void ssyrkK(String uplo, String trans, int n, int k, float alpha, float[] a, int offseta, int lda, float beta, float[] c, int offsetc, int ldc) {
+ org.netlib.blas.Ssyrk.ssyrk(uplo, trans, n, k, alpha, a, offseta, lda, beta, c, offsetc, ldc);
+ }
+
+ protected void dtbmvK(String uplo, String trans, String diag, int n, int k, double[] a, int offseta, int lda, double[] x, int offsetx, int incx) {
+ org.netlib.blas.Dtbmv.dtbmv(uplo, trans, diag, n, k, a, offseta, lda, x, offsetx, incx);
+ }
+
+ protected void stbmvK(String uplo, String trans, String diag, int n, int k, float[] a, int offseta, int lda, float[] x, int offsetx, int incx) {
+ org.netlib.blas.Stbmv.stbmv(uplo, trans, diag, n, k, a, offseta, lda, x, offsetx, incx);
+ }
+
+ protected void dtbsvK(String uplo, String trans, String diag, int n, int k, double[] a, int offseta, int lda, double[] x, int offsetx, int incx) {
+ org.netlib.blas.Dtbsv.dtbsv(uplo, trans, diag, n, k, a, offseta, lda, x, offsetx, incx);
+ }
+
+ protected void stbsvK(String uplo, String trans, String diag, int n, int k, float[] a, int offseta, int lda, float[] x, int offsetx, int incx) {
+ org.netlib.blas.Stbsv.stbsv(uplo, trans, diag, n, k, a, offseta, lda, x, offsetx, incx);
+ }
+
+ protected void dtpmvK(String uplo, String trans, String diag, int n, double[] a, int offseta, double[] x, int offsetx, int incx) {
+ org.netlib.blas.Dtpmv.dtpmv(uplo, trans, diag, n, a, offseta, x, offsetx, incx);
+ }
+
+ protected void stpmvK(String uplo, String trans, String diag, int n, float[] a, int offseta, float[] x, int offsetx, int incx) {
+ org.netlib.blas.Stpmv.stpmv(uplo, trans, diag, n, a, offseta, x, offsetx, incx);
+ }
+
+ protected void dtpsvK(String uplo, String trans, String diag, int n, double[] a, int offseta, double[] x, int offsetx, int incx) {
+ org.netlib.blas.Dtpsv.dtpsv(uplo, trans, diag, n, a, offseta, x, offsetx, incx);
+ }
+
+ protected void stpsvK(String uplo, String trans, String diag, int n, float[] a, int offseta, float[] x, int offsetx, int incx) {
+ org.netlib.blas.Stpsv.stpsv(uplo, trans, diag, n, a, offseta, x, offsetx, incx);
+ }
+
+ protected void dtrmmK(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb) {
+ org.netlib.blas.Dtrmm.dtrmm(side, uplo, transa, diag, m, n, alpha, a, offseta, lda, b, offsetb, ldb);
+ }
+
+ protected void strmmK(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb) {
+ org.netlib.blas.Strmm.strmm(side, uplo, transa, diag, m, n, alpha, a, offseta, lda, b, offsetb, ldb);
+ }
+
+ protected void dtrmvK(String uplo, String trans, String diag, int n, double[] a, int offseta, int lda, double[] x, int offsetx, int incx) {
+ org.netlib.blas.Dtrmv.dtrmv(uplo, trans, diag, n, a, offseta, lda, x, offsetx, incx);
+ }
+
+ protected void strmvK(String uplo, String trans, String diag, int n, float[] a, int offseta, int lda, float[] x, int offsetx, int incx) {
+ org.netlib.blas.Strmv.strmv(uplo, trans, diag, n, a, offseta, lda, x, offsetx, incx);
+ }
+
+ protected void dtrsmK(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb) {
+ org.netlib.blas.Dtrsm.dtrsm(side, uplo, transa, diag, m, n, alpha, a, offseta, lda, b, offsetb, ldb);
+ }
+
+ protected void strsmK(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb) {
+ org.netlib.blas.Strsm.strsm(side, uplo, transa, diag, m, n, alpha, a, offseta, lda, b, offsetb, ldb);
+ }
+
+ protected void dtrsvK(String uplo, String trans, String diag, int n, double[] a, int offseta, int lda, double[] x, int offsetx, int incx) {
+ org.netlib.blas.Dtrsv.dtrsv(uplo, trans, diag, n, a, offseta, lda, x, offsetx, incx);
+ }
+
+ protected void strsvK(String uplo, String trans, String diag, int n, float[] a, int offseta, int lda, float[] x, int offsetx, int incx) {
+ org.netlib.blas.Strsv.strsv(uplo, trans, diag, n, a, offseta, lda, x, offsetx, incx);
+ }
+
+ protected int idamaxK(int n, double[] x, int offsetx, int incx) {
+ return org.netlib.blas.Idamax.idamax(n, x, offsetx, incx);
+ }
+
+ protected int isamaxK(int n, float[] x, int offsetx, int incx) {
+ return org.netlib.blas.Isamax.isamax(n, x, offsetx, incx);
+ }
+}
diff --git a/ml-core/src/main/scala/org/apache/spark/ml/tree/Node.scala b/ml-core/src/main/scala/org/apache/spark/ml/tree/Node.scala
new file mode 100644
index 0000000..cd08785
--- /dev/null
+++ b/ml-core/src/main/scala/org/apache/spark/ml/tree/Node.scala
@@ -0,0 +1,480 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree
+
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.tree.impl.BinnedFeature
+import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
+import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict}
+
+/**
+ * Decision tree node interface.
+ */
+sealed abstract class Node extends Serializable {
+
+ // TODO: Add aggregate stats (once available). This will happen after we move the DecisionTree
+ // code into the new API and deprecate the old API. SPARK-3727
+
+ /** Prediction a leaf node makes, or which an internal node would make if it were a leaf node */
+ def prediction: Double
+
+ /** Impurity measure at this node (for training data) */
+ def impurity: Double
+
+ /**
+ * Statistics aggregated from training data at this node, used to compute prediction, impurity,
+ * and probabilities.
+ * For classification, the array of class counts must be normalized to a probability distribution.
+ */
+ private[ml] def impurityStats: ImpurityCalculator
+
+ /** Recursive prediction helper method */
+ private[ml] def predictImpl(features: Vector): LeafNode
+
+ /** Recursive prediction helper method */
+ private[ml] def predictBinned(binned: Array[Int], splits: Array[Array[Split]]): LeafNode
+
+ /**
+ * Get the number of nodes in tree below this node, including leaf nodes.
+ * E.g., if this is a leaf, returns 0. If both children are leaves, returns 2.
+ */
+ private[tree] def numDescendants: Int
+
+ /**
+ * Recursive print function.
+ * @param indentFactor The number of spaces to add to each level of indentation.
+ */
+ private[tree] def subtreeToString(indentFactor: Int = 0): String
+
+ /**
+ * Get depth of tree from this node.
+ * E.g.: Depth 0 means this is a leaf node. Depth 1 means 1 internal and 2 leaf nodes.
+ */
+ private[tree] def subtreeDepth: Int
+
+ /**
+ * Create a copy of this node in the old Node format, recursively creating child nodes as needed.
+ * @param id Node ID using old format IDs
+ */
+ private[ml] def toOld(id: Int): OldNode
+
+ /**
+ * Trace down the tree, and return the largest feature index used in any split.
+ * @return Max feature index used in a split, or -1 if there are no splits (single leaf node).
+ */
+ private[ml] def maxSplitFeatureIndex(): Int
+
+ /** Returns a deep copy of the subtree rooted at this node. */
+ private[tree] def deepCopy(): Node
+}
+
+private[ml] object Node {
+
+ /**
+ * Create a new Node from the old Node format, recursively creating child nodes as needed.
+ */
+ def fromOld(oldNode: OldNode, categoricalFeatures: Map[Int, Int]): Node = {
+ if (oldNode.isLeaf) {
+ // TODO: Once the implementation has been moved to this API, then include sufficient
+ // statistics here.
+ new LeafNode(prediction = oldNode.predict.predict,
+ impurity = oldNode.impurity, impurityStats = null)
+ } else {
+ val gain = if (oldNode.stats.nonEmpty) {
+ oldNode.stats.get.gain
+ } else {
+ 0.0
+ }
+ new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity,
+ gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures),
+ rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures),
+ split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null)
+ }
+ }
+}
+
+/**
+ * Decision tree leaf node.
+ * @param prediction Prediction this node makes
+ * @param impurity Impurity measure at this node (for training data)
+ */
+class LeafNode private[ml] (
+ override val prediction: Double,
+ override val impurity: Double,
+ override private[ml] val impurityStats: ImpurityCalculator) extends Node {
+
+ override def toString: String =
+ s"LeafNode(prediction = $prediction, impurity = $impurity)"
+
+ override private[ml] def predictImpl(features: Vector): LeafNode = this
+
+ override private[ml] def predictBinned(
+ binned: Array[Int],
+ splits: Array[Array[Split]]): LeafNode = this
+
+ override private[tree] def numDescendants: Int = 0
+
+ override private[tree] def subtreeToString(indentFactor: Int = 0): String = {
+ val prefix: String = " " * indentFactor
+ prefix + s"Predict: $prediction\n"
+ }
+
+ override private[tree] def subtreeDepth: Int = 0
+
+ override private[ml] def toOld(id: Int): OldNode = {
+ new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)),
+ impurity, isLeaf = true, None, None, None, None)
+ }
+
+ override private[ml] def maxSplitFeatureIndex(): Int = -1
+
+ override private[tree] def deepCopy(): Node = {
+ new LeafNode(prediction, impurity, impurityStats)
+ }
+}
+
+/**
+ * Internal Decision Tree node.
+ * @param prediction Prediction this node would make if it were a leaf node
+ * @param impurity Impurity measure at this node (for training data)
+ * @param gain Information gain value. Values less than 0 indicate missing values;
+ * this quirk will be removed with future updates.
+ * @param leftChild Left-hand child node
+ * @param rightChild Right-hand child node
+ * @param split Information about the test used to split to the left or right child.
+ */
+class InternalNode private[ml] (
+ override val prediction: Double,
+ override val impurity: Double,
+ val gain: Double,
+ val leftChild: Node,
+ val rightChild: Node,
+ val split: Split,
+ override private[ml] val impurityStats: ImpurityCalculator) extends Node {
+
+ // Note to developers: The constructor argument impurityStats should be reconsidered before we
+ // make the constructor public. We may be able to improve the representation.
+
+ override def toString: String = {
+ s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)"
+ }
+
+ override private[ml] def predictImpl(features: Vector): LeafNode = {
+ var node: Node = this
+ while (node.isInstanceOf[InternalNode]) {
+ val n = node.asInstanceOf[InternalNode]
+ if (n.split.shouldGoLeft(features)) {
+ node = n.leftChild
+ } else {
+ node = n.rightChild
+ }
+ }
+ node.asInstanceOf[LeafNode]
+ }
+
+ override private[ml] def predictBinned(
+ binned: Array[Int],
+ splits: Array[Array[Split]]): LeafNode = {
+ var node: Node = this
+ while (node.isInstanceOf[InternalNode]) {
+ val n = node.asInstanceOf[InternalNode]
+ val i = n.split.featureIndex
+ if (n.split.shouldGoLeft(binned(i), splits(i))) {
+ node = n.leftChild
+ } else {
+ node = n.rightChild
+ }
+ }
+ node.asInstanceOf[LeafNode]
+ }
+
+ override private[tree] def numDescendants: Int = {
+ 2 + leftChild.numDescendants + rightChild.numDescendants
+ }
+
+ override private[tree] def subtreeToString(indentFactor: Int = 0): String = {
+ val prefix: String = " " * indentFactor
+ prefix + s"If (${InternalNode.splitToString(split, left = true)})\n" +
+ leftChild.subtreeToString(indentFactor + 1) +
+ prefix + s"Else (${InternalNode.splitToString(split, left = false)})\n" +
+ rightChild.subtreeToString(indentFactor + 1)
+ }
+
+ override private[tree] def subtreeDepth: Int = {
+ 1 + math.max(leftChild.subtreeDepth, rightChild.subtreeDepth)
+ }
+
+ override private[ml] def toOld(id: Int): OldNode = {
+ assert(id.toLong * 2 < Int.MaxValue, "Decision Tree could not be converted from new to old API"
+ + " since the old API does not support deep trees.")
+ new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)), impurity,
+ isLeaf = false, Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))),
+ Some(rightChild.toOld(OldNode.rightChildIndex(id))),
+ Some(new OldInformationGainStats(gain, impurity, leftChild.impurity, rightChild.impurity,
+ new OldPredict(leftChild.prediction, prob = 0.0),
+ new OldPredict(rightChild.prediction, prob = 0.0))))
+ }
+
+ override private[ml] def maxSplitFeatureIndex(): Int = {
+ math.max(split.featureIndex,
+ math.max(leftChild.maxSplitFeatureIndex(), rightChild.maxSplitFeatureIndex()))
+ }
+
+ override private[tree] def deepCopy(): Node = {
+ new InternalNode(prediction, impurity, gain, leftChild.deepCopy(), rightChild.deepCopy(),
+ split, impurityStats)
+ }
+}
+
+private object InternalNode {
+
+ /**
+ * Helper method for [[Node.subtreeToString()]].
+ * @param split Split to print
+ * @param left Indicates whether this is the part of the split going to the left,
+ * or that going to the right.
+ */
+ private def splitToString(split: Split, left: Boolean): String = {
+ val featureStr = s"feature ${split.featureIndex}"
+ split match {
+ case contSplit: ContinuousSplit =>
+ if (left) {
+ s"$featureStr <= ${contSplit.threshold}"
+ } else {
+ s"$featureStr > ${contSplit.threshold}"
+ }
+ case catSplit: CategoricalSplit =>
+ val categoriesStr = catSplit.leftCategories.mkString("{", ",", "}")
+ if (left) {
+ s"$featureStr in $categoriesStr"
+ } else {
+ s"$featureStr not in $categoriesStr"
+ }
+ }
+ }
+}
+
+/**
+ * Version of a node used in learning. This uses vars so that we can modify nodes as we split the
+ * tree by adding children, etc.
+ *
+ * For now, we use node IDs. These will be kept internal since we hope to remove node IDs
+ * in the future, or at least change the indexing (so that we can support much deeper trees).
+ *
+ * This node can either be:
+ * - a leaf node, with leftChild, rightChild, split set to null, or
+ * - an internal node, with all values set
+ *
+ * @param id We currently use the same indexing as the old implementation in
+ * [[org.apache.spark.mllib.tree.model.Node]], but this will change later.
+ * @param isLeaf Indicates whether this node will definitely be a leaf in the learned tree,
+ * so that we do not need to consider splitting it further.
+ * @param stats Impurity statistics for this node.
+ */
+private[tree] class LearningNode(
+ var id: Int,
+ var leftChild: Option[LearningNode],
+ var rightChild: Option[LearningNode],
+ var split: Option[Split],
+ var isLeaf: Boolean,
+ var stats: ImpurityStats) extends Serializable {
+
+ def toNode: Node = toNode(prune = true)
+
+ /**
+ * Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children.
+ */
+ def toNode(prune: Boolean = true): Node = {
+
+ if (!leftChild.isEmpty || !rightChild.isEmpty) {
+ assert(leftChild.nonEmpty && rightChild.nonEmpty && split.nonEmpty && stats != null,
+ "Unknown error during Decision Tree learning. Could not convert LearningNode to Node.")
+ (leftChild.get.toNode(prune), rightChild.get.toNode(prune)) match {
+ case (l: LeafNode, r: LeafNode) if prune && l.prediction == r.prediction =>
+ new LeafNode(l.prediction, stats.impurity, stats.impurityCalculator)
+ case (l, r) =>
+ new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain,
+ l, r, split.get, stats.impurityCalculator)
+ }
+ } else {
+ if (stats.valid) {
+ new LeafNode(stats.impurityCalculator.predict, stats.impurity,
+ stats.impurityCalculator)
+ } else {
+ // Here we want to keep same behavior with the old mllib.DecisionTreeModel
+ new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator)
+ }
+ }
+ }
+
+ /**
+ * Get the node index corresponding to this data point.
+ * This function mimics prediction, passing an example from the root node down to a leaf
+ * or unsplit node; that node's index is returned.
+ *
+ * @param binnedFeatures Binned feature vector for data point.
+ * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
+ * @return Leaf index if the data point reaches a leaf.
+ * Otherwise, last node reachable in tree matching this example.
+ * Note: This is the global node index, i.e., the index used in the tree.
+ * This index is different from the index used during training a particular
+ * group of nodes on one call to
+ * [[org.apache.spark.ml.tree.impl.RandomForest.findBestSplits()]].
+ */
+ def predictImpl(binnedFeatures: Array[Int], splits: Array[Array[Split]]): Int = {
+ var node = this
+ while (!node.isLeaf && node.split.nonEmpty) {
+ val split = node.split.get
+ val featureIndex = split.featureIndex
+ val splitLeft = split.shouldGoLeft(binnedFeatures(featureIndex), splits(featureIndex))
+ if (node.leftChild.isEmpty) {
+ // Not yet split. Return next layer of nodes to train
+ if (splitLeft) {
+ return LearningNode.leftChildIndex(node.id)
+ } else {
+ return LearningNode.rightChildIndex(node.id)
+ }
+ } else {
+ if (splitLeft) {
+ node = node.leftChild.get
+ } else {
+ node = node.rightChild.get
+ }
+ }
+ }
+ node.id
+ }
+
+ /**
+ * Get the node index corresponding to this data point.
+ * This function mimics prediction, passing an example from the root node down to a leaf
+ * or unsplit node; that node's index is returned.
+ *
+ * @param binnedFeatures Binned feature vector for data point.
+ * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
+ * @return Leaf index if the data point reaches a leaf.
+ * Otherwise, last node reachable in tree matching this example.
+ * Note: This is the global node index, i.e., the index used in the tree.
+ * This index is different from the index used during training a particular
+ * group of nodes on one call to
+ * [[org.apache.spark.ml.tree.impl.RandomForest.findBestSplits()]].
+ */
+ def predictImpl(binnedFeatures: BinnedFeature, splits: Array[Array[Split]]): Int = {
+ var node = this
+ while (!node.isLeaf && node.split.nonEmpty) {
+ val split = node.split.get
+ val featureIndex = split.featureIndex
+ val splitLeft = split.shouldGoLeft(binnedFeatures.get(featureIndex), splits(featureIndex))
+ if (node.leftChild.isEmpty) {
+ // Not yet split. Return next layer of nodes to train
+ if (splitLeft) {
+ return LearningNode.leftChildIndex(node.id)
+ } else {
+ return LearningNode.rightChildIndex(node.id)
+ }
+ } else {
+ if (splitLeft) {
+ node = node.leftChild.get
+ } else {
+ node = node.rightChild.get
+ }
+ }
+ }
+ node.id
+ }
+
+}
+
+private[tree] object LearningNode {
+
+ /** Create a node with some of its fields set. */
+ def apply(
+ id: Int,
+ isLeaf: Boolean,
+ stats: ImpurityStats): LearningNode = {
+ new LearningNode(id, None, None, None, false, stats)
+ }
+
+ /** Create an empty node with the given node index. Values must be set later on. */
+ def emptyNode(nodeIndex: Int): LearningNode = {
+ new LearningNode(nodeIndex, None, None, None, false, null)
+ }
+
+ // The below indexing methods were copied from spark.mllib.tree.model.Node
+
+ /**
+ * Return the index of the left child of this node.
+ */
+ def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1
+
+ /**
+ * Return the index of the right child of this node.
+ */
+ def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1
+
+ /**
+ * Get the parent index of the given node, or 0 if it is the root.
+ */
+ def parentIndex(nodeIndex: Int): Int = nodeIndex >> 1
+
+ /**
+ * Return the level of a tree which the given node is in.
+ */
+ def indexToLevel(nodeIndex: Int): Int = if (nodeIndex == 0) {
+ throw new IllegalArgumentException(s"0 is not a valid node index.")
+ } else {
+ java.lang.Integer.numberOfTrailingZeros(java.lang.Integer.highestOneBit(nodeIndex))
+ }
+
+ /**
+ * Returns true if this is a left child.
+ * Note: Returns false for the root.
+ */
+ def isLeftChild(nodeIndex: Int): Boolean = nodeIndex > 1 && nodeIndex % 2 == 0
+
+ /**
+ * Return the maximum number of nodes which can be in the given level of the tree.
+ * @param level Level of tree (0 = root).
+ */
+ def maxNodesInLevel(level: Int): Int = 1 << level
+
+ /**
+ * Return the index of the first node in the given level.
+ * @param level Level of tree (0 = root).
+ */
+ def startIndexInLevel(level: Int): Int = 1 << level
+
+ /**
+ * Traces down from a root node to get the node with the given node index.
+ * This assumes the node exists.
+ */
+ def getNode(nodeIndex: Int, rootNode: LearningNode): LearningNode = {
+ var tmpNode: LearningNode = rootNode
+ var levelsToGo = indexToLevel(nodeIndex)
+ while (levelsToGo > 0) {
+ if ((nodeIndex & (1 << levelsToGo - 1)) == 0) {
+ tmpNode = tmpNode.leftChild.get
+ } else {
+ tmpNode = tmpNode.rightChild.get
+ }
+ levelsToGo -= 1
+ }
+ tmpNode
+ }
+
+}
diff --git a/ml-core/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala b/ml-core/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala
new file mode 100644
index 0000000..0819fe9
--- /dev/null
+++ b/ml-core/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import org.apache.commons.math3.distribution.PoissonDistribution
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils
+import org.apache.spark.util.random.XORShiftRandom
+
+/**
+ * Internal representation of a datapoint which belongs to several subsamples of the same dataset,
+ * particularly for bagging (e.g., for random forests).
+ *
+ * This holds one instance, as well as an array of weights which represent the (weighted)
+ * number of times which this instance appears in each subsamplingRate.
+ * E.g., (datum, [1, 0, 4]) indicates that there are 3 subsamples of the dataset and that
+ * this datum has 1 copy, 0 copies, and 4 copies in the 3 subsamples, respectively.
+ *
+ * @param datum Data instance
+ * @param subsampleCounts Number of samples of this instance in each subsampled dataset.
+ * @param sampleWeight The weight of this instance.
+ * @param sampleId ID of sample
+ */
+private[spark] class BaggedPoint[Datum](
+ val datum: Datum,
+ val subsampleCounts: Array[Int],
+ val sampleWeight: Double = 1.0,
+ var sampleId: Short = 0) extends Serializable
+
+private[spark] object BaggedPoint {
+
+ /**
+ * Convert an input dataset into its BaggedPoint representation,
+ * choosing subsamplingRate counts for each instance.
+ * Each subsamplingRate has the same number of instances as the original dataset,
+ * and is created by subsampling without replacement.
+ * @param input Input dataset.
+ * @param subsamplingRate Fraction of the training data used for learning decision tree.
+ * @param numSubsamples Number of subsamples of this RDD to take.
+ * @param withReplacement Sampling with/without replacement.
+ * @param extractSampleWeight A function to get the sample weight of each datum.
+ * @param seed Random seed.
+ * @return BaggedPoint dataset representation.
+ */
+ def convertToBaggedRDD[Datum] (
+ input: RDD[Datum],
+ subsamplingRate: Double,
+ numSubsamples: Int,
+ withReplacement: Boolean,
+ extractSampleWeight: (Datum => Double) = (_: Datum) => 1.0,
+ seed: Long = Utils.random.nextLong(),
+ oneFeaturePerTree: Boolean = false): RDD[BaggedPoint[Datum]] = {
+ if (oneFeaturePerTree) {
+ convertToBaggedRDDWithoutSampling(input, 1, extractSampleWeight)
+ } else {
+ // TODO: implement weighted bootstrapping
+ if (withReplacement) {
+ convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples,
+ extractSampleWeight, seed)
+ } else if (subsamplingRate == 1.0) {
+ convertToBaggedRDDWithoutSampling(input, numSubsamples, extractSampleWeight)
+ } else {
+ convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples,
+ extractSampleWeight, seed)
+ }
+ }
+ }
+
+ private def convertToBaggedRDDSamplingWithoutReplacement[Datum] (
+ input: RDD[Datum],
+ subsamplingRate: Double,
+ numSubsamples: Int,
+ extractSampleWeight: (Datum => Double),
+ seed: Long): RDD[BaggedPoint[Datum]] = {
+ input.mapPartitionsWithIndex { (partitionIndex, instances) =>
+ // Use random seed = seed + partitionIndex + 1 to make generation reproducible.
+ val rng = new XORShiftRandom
+ rng.setSeed(seed + partitionIndex + 1)
+ instances.map { instance =>
+ val subsampleCounts = new Array[Int](numSubsamples)
+ var subsampleIndex = 0
+ while (subsampleIndex < numSubsamples) {
+ if (rng.nextDouble() < subsamplingRate) {
+ subsampleCounts(subsampleIndex) = 1
+ }
+ subsampleIndex += 1
+ }
+ new BaggedPoint(instance, subsampleCounts, extractSampleWeight(instance))
+ }
+ }
+ }
+
+ private def convertToBaggedRDDSamplingWithReplacement[Datum] (
+ input: RDD[Datum],
+ subsample: Double,
+ numSubsamples: Int,
+ extractSampleWeight: (Datum => Double),
+ seed: Long): RDD[BaggedPoint[Datum]] = {
+ input.mapPartitionsWithIndex { (partitionIndex, instances) =>
+ // Use random seed = seed + partitionIndex + 1 to make generation reproducible.
+ val poisson = new PoissonDistribution(subsample)
+ poisson.reseedRandomGenerator(seed + partitionIndex + 1)
+ instances.map { instance =>
+ val subsampleCounts = new Array[Int](numSubsamples)
+ var subsampleIndex = 0
+ while (subsampleIndex < numSubsamples) {
+ subsampleCounts(subsampleIndex) = poisson.sample()
+ subsampleIndex += 1
+ }
+ new BaggedPoint(instance, subsampleCounts, extractSampleWeight(instance))
+ }
+ }
+ }
+
+ private def convertToBaggedRDDWithoutSampling[Datum] (
+ input: RDD[Datum],
+ numSubsamples: Int,
+ extractSampleWeight: (Datum => Double)): RDD[BaggedPoint[Datum]] = {
+ input.mapPartitions { instances =>
+ val subsampleCounts = Array.fill(numSubsamples)(1)
+ instances.map { instance =>
+ new BaggedPoint(instance, subsampleCounts, extractSampleWeight(instance))
+ }
+ }
+ }
+}
diff --git a/ml-core/src/main/scala/org/apache/spark/ml/tree/impl/DTFeatureStatsAggregator.scala b/ml-core/src/main/scala/org/apache/spark/ml/tree/impl/DTFeatureStatsAggregator.scala
new file mode 100644
index 0000000..da27e06
--- /dev/null
+++ b/ml-core/src/main/scala/org/apache/spark/ml/tree/impl/DTFeatureStatsAggregator.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import org.apache.spark.mllib.tree.impurity._
+
+
+/**
+ * DecisionTree statistics aggregator for a feature for a node.
+ * This class is abstract to support learning with and without feature subsampling.
+ */
+private[spark] class DTFeatureStatsAggregator(
+ val metadata: DecisionTreeMetadata,
+ val _featureIndex: Int) extends Serializable {
+
+ /**
+ * [[ImpurityAggregator]] instance specifying the impurity type.
+ */
+
+ val impurityAggregator = new VarianceAggregator()
+
+ val featureIndex: Int = _featureIndex
+
+ /**
+ * Number of elements (Double values) used for the sufficient statistics of each bin.
+ */
+ private val statsSize: Int = impurityAggregator.statsSize
+
+ /**
+ * Number of bins for the feature.
+ */
+ private val numBins: Int = {
+ metadata.numBins(featureIndex)
+ }
+
+ /**
+ * Total number of elements stored in this aggregator
+ */
+ private val allStatsSize: Int = numBins * statsSize
+
+ /**
+ * Flat array of elements.
+ */
+ private val allStats: Array[Double] = new Array[Double](allStatsSize)
+
+ /**
+ * Array of parent node sufficient stats.
+ */
+ private val parentStats: Array[Double] = new Array[Double](statsSize)
+
+ /**
+ * Get an [[ImpurityCalculator]] for a given (node, feature, bin).
+ */
+ def getImpurityCalculator(featureOffset: Int, binIndex: Int): ImpurityCalculator = {
+ impurityAggregator.getCalculator(allStats, binIndex * statsSize)
+ }
+
+ /**
+ * Get an [[ImpurityCalculator]] for the parent node.
+ */
+ def getParentImpurityCalculator(): ImpurityCalculator = {
+ impurityAggregator.getCalculator(parentStats, 0)
+ }
+
+ /**
+ * Update the stats for a given bin for ordered features, using the given label.
+ */
+ def updateX(featureIndex: Int,
+ binIndex: Int,
+ label: Double,
+ sampleCount: Int,
+ weight: Double): Unit = {
+ val i = binIndex * statsSize
+ impurityAggregator.update(allStats, i, label, sampleCount, weight)
+ }
+
+ /**
+ * Pre-compute feature offset for use with [[featureUpdate]].
+ * For ordered features only.
+ */
+ def getFeatureOffset(featureIndex: Int): Int = 0
+
+ /**
+ * For a given feature, merge the stats for two bins.
+ *
+ * @param featureOffset This is a pre-computed feature offset
+ * from [[getFeatureOffset]].
+ * @param binIndex The other bin is merged into this bin.
+ * @param otherBinIndex This bin is not modified. X
+ */
+ def mergeForFeature(featureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = {
+ impurityAggregator.merge(allStats, binIndex * statsSize, otherBinIndex * statsSize)
+ }
+}
diff --git a/ml-core/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesCore.scala b/ml-core/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesCore.scala
new file mode 100644
index 0000000..d602797
--- /dev/null
+++ b/ml-core/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesCore.scala
@@ -0,0 +1,255 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import it.unimi.dsi.fastutil.objects.ObjectArrayList
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.tree.{CategoricalSplit, ContinuousSplit, LearningNode, Split}
+import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
+import org.apache.spark.mllib.tree.model.ImpurityStats
+
+object GradientBoostedTreesCore extends Logging{
+ private[tree] class NodeIndexInfo(
+ val nodeIndexInGroup: Int,
+ val featureSubset: Option[Array[Int]],
+ val featureSubsetHashSetX: Option[scala.collection.mutable.HashSet[Int]] = None)
+ extends Serializable
+
+ /**
+ * Calculate the impurity statistics for a given (feature, split) based upon left/right
+ * aggregates.
+ *
+ * @param stats the recycle impurity statistics for this feature's all splits,
+ * only 'impurity' and 'impurityCalculator' are valid between each iteration
+ * @param leftImpurityCalculator left node aggregates for this (feature, split)
+ * @param rightImpurityCalculator right node aggregate for this (feature, split)
+ * @param metadata learning and dataset metadata for DecisionTree
+ * @return Impurity statistics for this (feature, split)
+ */
+ private def calculateImpurityStats(
+ stats: ImpurityStats,
+ leftImpurityCalculator: ImpurityCalculator,
+ rightImpurityCalculator: ImpurityCalculator,
+ metadata: DecisionTreeMetadata): ImpurityStats = {
+
+ val parentImpurityCalculator: ImpurityCalculator = if (stats == null) {
+ leftImpurityCalculator.copy.add(rightImpurityCalculator)
+ } else {
+ stats.impurityCalculator
+ }
+
+ val impurity: Double = if (stats == null) {
+ parentImpurityCalculator.calculate()
+ } else {
+ stats.impurity
+ }
+
+ val leftRawCount = leftImpurityCalculator.rawCount
+ val rightRawCount = rightImpurityCalculator.rawCount
+ val leftCount = leftImpurityCalculator.count
+ val rightCount = rightImpurityCalculator.count
+
+ val totalCount = leftCount + rightCount
+
+ val violatesMinInstancesPerNode = (leftRawCount < metadata.minInstancesPerNode) ||
+ (rightRawCount < metadata.minInstancesPerNode)
+ val violatesMinWeightPerNode = (leftCount < metadata.minWeightPerNode) ||
+ (rightCount < metadata.minWeightPerNode)
+ // If left child or right child doesn't satisfy minimum weight per node or minimum
+ // instances per node, then this split is invalid, return invalid information gain stats.
+ if (violatesMinInstancesPerNode || violatesMinWeightPerNode) {
+ return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
+ }
+
+ val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
+ val rightImpurity = rightImpurityCalculator.calculate()
+
+ val leftWeight = leftCount / totalCount.toDouble
+ val rightWeight = rightCount / totalCount.toDouble
+
+ val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
+
+ // if information gain doesn't satisfy minimum information gain,
+ // then this split is invalid, return invalid information gain stats.
+ if (gain < metadata.minInfoGain) {
+ return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
+ }
+
+ new ImpurityStats(gain, impurity, parentImpurityCalculator,
+ leftImpurityCalculator, rightImpurityCalculator)
+ }
+
+ /**
+ * Find the best split for a node.
+ *
+ * @param binAggregates Bin statistics.
+ * @return tuple for best split: (Split, information gain, prediction at node)
+ */
+ private[tree] def binsToBestSplitX(
+ binAggregates: DTFeatureStatsAggregator,
+ splits: ObjectArrayList[Split],
+ featureIndex: Int,
+ node: LearningNode): (Split, ImpurityStats) = {
+
+ // Calculate InformationGain and ImpurityStats if current node is top node
+ val level = LearningNode.indexToLevel(node.id)
+ var gainAndImpurityStats: ImpurityStats = if (level == 0) {
+ null
+ } else {
+ node.stats
+ }
+
+ if (binAggregates.metadata.numSplits(featureIndex) != 0) {
+ val featureIndexIdx = featureIndex
+ val numSplits = binAggregates.metadata.numSplits(featureIndex)
+ if (binAggregates.metadata.isContinuous(featureIndex)) {
+ // Cumulative sum (scanLeft) of bin statistics.
+ // Afterwards, binAggregates for a bin is the sum of aggregates for
+ // that bin + all preceding bins.
+ var splitIndex = 0
+ while (splitIndex < numSplits) {
+ binAggregates.mergeForFeature(0, splitIndex + 1, splitIndex)
+ splitIndex += 1
+ }
+ // Find best split.
+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
+ Range(0, numSplits).map { case splitIdx =>
+ val leftChildStats = binAggregates.getImpurityCalculator(0, splitIdx)
+ val rightChildStats =
+ binAggregates.getImpurityCalculator(0, numSplits)
+ rightChildStats.subtract(leftChildStats)
+ gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
+ leftChildStats, rightChildStats, binAggregates.metadata)
+ (splitIdx, gainAndImpurityStats)
+ }.maxBy(_._2.gain)
+ (splits.get(bestFeatureSplitIndex), bestFeatureGainStats)
+ } else if (binAggregates.metadata.isUnordered(featureIndex)) {
+ // unreachable for GBDT
+ // Unordered categorical feature
+ // val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx)
+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
+ Range(0, numSplits).map { splitIndex =>
+ val leftChildStats = binAggregates.getImpurityCalculator(0, splitIndex)
+ val rightChildStats = binAggregates.getImpurityCalculator(0, numSplits)
+ .subtract(leftChildStats)
+ gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
+ leftChildStats, rightChildStats, binAggregates.metadata)
+ (splitIndex, gainAndImpurityStats)
+ }.maxBy(_._2.gain)
+ (splits.get(bestFeatureSplitIndex), bestFeatureGainStats)
+ } else {
+ // Ordered categorical feature, reachable for GBDT if not continuous
+ val numCategories = binAggregates.metadata.numBins(featureIndex)
+
+ /* Each bin is one category (feature value).
+ * The bins are ordered based on centroidForCategories, and this ordering determines which
+ * splits are considered. (With K categories, we consider K - 1 possible splits.)
+ *
+ * centroidForCategories is a list: (category, centroid)
+ */
+ val centroidForCategories = Range(0, numCategories).map { case featureValue =>
+ val categoryStats =
+ binAggregates.getImpurityCalculator(0, featureValue)
+ val centroid = if (categoryStats.count != 0) {
+ if (binAggregates.metadata.isMulticlass) {
+ // unreachable for GBDT
+ // multiclass classification
+ // For categorical variables in multiclass classification,
+ // the bins are ordered by the impurity of their corresponding labels.
+ categoryStats.calculate()
+ } else if (binAggregates.metadata.isClassification) {
+ // unreachable for GBDT
+ // binary classification
+ // For categorical variables in binary classification,
+ // the bins are ordered by the count of class 1.
+ categoryStats.stats(1)
+ } else {
+ // regression
+ // For categorical variables in regression and binary classification,
+ // the bins are ordered by the prediction.
+ categoryStats.predict
+ }
+ } else {
+ Double.MaxValue
+ }
+ (featureValue, centroid)
+ }
+
+ logDebug(s"Centroids for categorical variable: ${centroidForCategories.mkString(",")}")
+
+ // bins sorted by centroids
+ val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
+
+ logDebug("Sorted centroids for categorical variable = " +
+ categoriesSortedByCentroid.mkString(","))
+
+ // Cumulative sum (scanLeft) of bin statistics.
+ // Afterwards, binAggregates for a bin is the sum of aggregates for
+ // that bin + all preceding bins.
+ var splitIndex = 0
+ while (splitIndex < numSplits) {
+ val currentCategory = categoriesSortedByCentroid(splitIndex)._1
+ val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
+ binAggregates.mergeForFeature(0, nextCategory, currentCategory)
+ splitIndex += 1
+ }
+ // lastCategory = index of bin with total aggregates for this (node, feature)
+ val lastCategory = categoriesSortedByCentroid.last._1
+ // Find best split.
+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
+ Range(0, numSplits).map { splitIndex =>
+ val featureValue = categoriesSortedByCentroid(splitIndex)._1
+ val leftChildStats =
+ binAggregates.getImpurityCalculator(0, featureValue)
+ val rightChildStats =
+ binAggregates.getImpurityCalculator(0, lastCategory)
+ rightChildStats.subtract(leftChildStats)
+ gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
+ leftChildStats, rightChildStats, binAggregates.metadata)
+ (splitIndex, gainAndImpurityStats)
+ }.maxBy(_._2.gain)
+ val categoriesForSplit =
+ categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
+ val bestFeatureSplit =
+ new CategoricalSplit(featureIndex, categoriesForSplit.toArray, numCategories)
+ (bestFeatureSplit, bestFeatureGainStats)
+ }
+ } else {
+ // If no valid splits for features, then this split is invalid,
+ // return invalid information gain stats. Take any split and continue.
+ // Splits is empty, so arbitrarily choose to split on any threshold
+ // val parentImpurityCalculator = binAggregates.getParentImpurityCalculator()
+ // No split, no need to merge
+ val featureIndexIdx = featureIndex
+ val numSplits = binAggregates.metadata.numSplits(featureIndex)
+ val parentImpurityCalculator = binAggregates.getImpurityCalculator(0, numSplits)
+ if (binAggregates.metadata.isContinuous(featureIndex)) {
+ (new ContinuousSplit(featureIndex, 0),
+ ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator))
+ } else {
+ // Seems like unreachable for GBDT (as well as RF)
+ val numCategories = binAggregates.metadata.featureArity(featureIndex)
+ (new CategoricalSplit(featureIndex, Array(), numCategories),
+ ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator))
+ }
+ }
+
+ // For each (feature, split), calculate the gain, and select the best (feature, split).
+ }
+}
diff --git a/ml-core/src/main/scala/org/apache/spark/ml/tree/impl/TreePointX.scala b/ml-core/src/main/scala/org/apache/spark/ml/tree/impl/TreePointX.scala
index 0b8392a..ec818d9 100644
--- a/ml-core/src/main/scala/org/apache/spark/ml/tree/impl/TreePointX.scala
+++ b/ml-core/src/main/scala/org/apache/spark/ml/tree/impl/TreePointX.scala
@@ -19,12 +19,11 @@ package org.apache.spark.ml.tree.impl
import it.unimi.dsi.fastutil.ints.Int2CharOpenHashMap
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.tree.{ContinuousSplit, Split}
import org.apache.spark.ml.tree.impl.BinnedFeaturesDataType.BinnedFeaturesDataType
import org.apache.spark.rdd.RDD
-
/**
* Enum for selecting the data type of binned features of training samples
*/
@@ -48,10 +47,12 @@ object BinnedFeaturesDataType extends Enumeration {
* @param label Label from LabeledPoint
* @param binnedFeatures Binned feature values.
* Same length as LabeledPoint.features, but values are bin indices.
+ * @param weight Sample weight for this TreePointX.
*/
-private[spark] class TreePointX(val label: Double, val binnedFeatures: BinnedFeature)
- extends Serializable {
-}
+private[spark] class TreePointX(
+ val label: Double,
+ val binnedFeatures: BinnedFeature,
+ val weight: Double) extends Serializable
private[spark] object TreePointX {
@@ -64,7 +65,7 @@ private[spark] object TreePointX {
* @return TreePointX dataset representation
*/
def convertToTreeRDD(
- input: RDD[LabeledPoint],
+ input: RDD[Instance],
splits: Array[Array[Split]],
metadata: DecisionTreeMetadata): RDD[TreePointX] = {
convertToTreeRDD(input, splits, metadata, BinnedFeaturesDataType.array)
@@ -79,7 +80,7 @@ private[spark] object TreePointX {
* @return TreePointX dataset representation
*/
def convertToTreeRDD(
- input: RDD[LabeledPoint],
+ input: RDD[Instance],
splits: Array[Array[Split]],
metadata: DecisionTreeMetadata,
binnedFeaturesType: BinnedFeaturesDataType): RDD[TreePointX] = {
@@ -94,7 +95,7 @@ private[spark] object TreePointX {
if (arity == 0) {
splits(idx).map(_.asInstanceOf[ContinuousSplit].threshold)
} else {
- Array.empty[Double]
+ Array.emptyDoubleArray
}
}
val useArrayType = (binnedFeaturesType == BinnedFeaturesDataType.array)
@@ -107,6 +108,7 @@ private[spark] object TreePointX {
TreePointX.labeledPointToTreePointByFastHashMap(x, thresholds, featureArity)
}
}
+
}
/**
@@ -117,19 +119,19 @@ private[spark] object TreePointX {
* for categorical features.
*/
private[spark] def labeledPointToTreePointByArray(
- labeledPoint: LabeledPoint,
+ instance: Instance,
thresholds: Array[Array[Double]],
featureArity: Array[Int]): TreePointX = {
- val numFeatures = labeledPoint.features.size
+ val numFeatures = instance.features.size
val arr = new Array[Char](numFeatures)
var featureIndex = 0
while (featureIndex < numFeatures) {
arr(featureIndex) =
- findBin(featureIndex, labeledPoint, featureArity(featureIndex), thresholds(featureIndex))
+ findBin(featureIndex, instance, featureArity(featureIndex), thresholds(featureIndex))
.toChar
featureIndex += 1
}
- new TreePointX(labeledPoint.label, new BinnedFeatureArray(arr))
+ new TreePointX(instance.label, new BinnedFeatureArray(arr), instance.weight)
}
/**
@@ -140,15 +142,15 @@ private[spark] object TreePointX {
* for categorical features.
*/
private[spark] def labeledPointToTreePointByFastHashMap(
- labeledPoint: LabeledPoint,
+ instance: Instance,
thresholds: Array[Array[Double]],
featureArity: Array[Int]): TreePointX = {
- val numFeatures = labeledPoint.features.size
+ val numFeatures = instance.features.size
val binFeaturesMap = new Int2CharOpenHashMap()
var featureIndex = 0
while (featureIndex < numFeatures) {
val binFeature =
- findBin(featureIndex, labeledPoint, featureArity(featureIndex), thresholds(featureIndex))
+ findBin(featureIndex, instance, featureArity(featureIndex), thresholds(featureIndex))
.toChar
if (binFeature != '\u0000') {
binFeaturesMap.put(featureIndex, binFeature)
@@ -156,7 +158,7 @@ private[spark] object TreePointX {
featureIndex += 1
}
val binFeatures = new BinnedFeatureFastHashMap(binFeaturesMap)
- new TreePointX(labeledPoint.label, binFeatures)
+ new TreePointX(instance.label, binFeatures, instance.weight)
}
/**
@@ -169,10 +171,10 @@ private[spark] object TreePointX {
*/
private def findBin(
featureIndex: Int,
- labeledPoint: LabeledPoint,
+ instance: Instance,
featureArity: Int,
thresholds: Array[Double]): Int = {
- val featureValue = labeledPoint.features(featureIndex)
+ val featureValue = instance.features(featureIndex)
if (featureArity == 0) {
val idx = java.util.Arrays.binarySearch(thresholds, featureValue)
@@ -188,7 +190,7 @@ private[spark] object TreePointX {
s"DecisionTree given invalid data:" +
s" Feature $featureIndex is categorical with values in {0,...,${featureArity - 1}," +
s" but a data point gives it value $featureValue.\n" +
- s" Bad data point: ${labeledPoint.toString}")
+ s" Bad data point: $instance")
}
featureValue.toInt
}
@@ -220,4 +222,3 @@ private[spark] class BinnedFeatureFastHashMap(val featureMap: Int2CharOpenHashMa
extends Serializable with BinnedFeature {
override def get(index: Int): Char = featureMap.get(index)
}
-
diff --git a/ml-core/src/main/scala/org/apache/spark/ml/tree/impl/TreePointY.scala b/ml-core/src/main/scala/org/apache/spark/ml/tree/impl/TreePointY.scala
index 4272edb..4e0f7a4 100644
--- a/ml-core/src/main/scala/org/apache/spark/ml/tree/impl/TreePointY.scala
+++ b/ml-core/src/main/scala/org/apache/spark/ml/tree/impl/TreePointY.scala
@@ -1,189 +1,137 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ml.tree.impl
-
-import it.unimi.dsi.fastutil.ints.Int2CharOpenHashMap
-
-import org.apache.spark.ml.feature.LabeledPoint
-import org.apache.spark.ml.tree.{ContinuousSplit, Split}
-import org.apache.spark.ml.tree.impl.BinnedFeaturesDataType.BinnedFeaturesDataType
-import org.apache.spark.rdd.RDD
-
-/**
- * Internal representation of LabeledPoint for DecisionTree.
- * This bins feature values based on a subsampled of data as follows:
- * (a) Continuous features are binned into ranges.
- * (b) Unordered categorical features are binned based on subsets of feature values.
- * "Unordered categorical features" are categorical features with low arity used in
- * multiclass classification.
- * (c) Ordered categorical features are binned based on feature values.
- * "Ordered categorical features" are categorical features with high arity,
- * or any categorical feature used in regression or binary classification.
- *
- * @param label Label from LabeledPoint
- * @param binnedFeatures Binned feature values.
- * Same length as LabeledPoint.features, but values are bin indices.
- */
-private[spark] class TreePointY(val label: Double, val binnedFeatures: BinnedFeature,
- val uniqueID: Char = '\u0000')
- extends Serializable {
-}
-
-private[spark] object TreePointY {
-
- /**
- * Convert an input dataset into its TreePointY representation,
- * binning feature values in preparation for DecisionTree training.
- * @param input Input dataset.
- * @param splits Splits for features, of size (numFeatures, numSplits).
- * @param metadata Learning and dataset metadata
- * @return TreePointY dataset representation
- */
- def convertToTreeRDD(
- input: RDD[LabeledPoint],
- splits: Array[Array[Split]],
- metadata: DecisionTreeMetadata): RDD[TreePointY] = {
- convertToTreeRDD(input, splits, metadata, BinnedFeaturesDataType.array)
- }
-
- /**
- * Convert an input dataset into its TreePointY representation,
- * binning feature values in preparation for DecisionTree training.
- * @param input Input dataset.
- * @param splits Splits for features, of size (numFeatures, numSplits).
- * @param metadata Learning and dataset metadata
- * @return TreePointY dataset representation
- */
- def convertToTreeRDD(
- input: RDD[LabeledPoint],
- splits: Array[Array[Split]],
- metadata: DecisionTreeMetadata,
- binnedFeaturesType: BinnedFeaturesDataType): RDD[TreePointY] = {
- // Construct arrays for featureArity for efficiency in the inner loop.
- val featureArity: Array[Int] = new Array[Int](metadata.numFeatures)
- var featureIndex = 0
- while (featureIndex < metadata.numFeatures) {
- featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0)
- featureIndex += 1
- }
- val thresholds: Array[Array[Double]] = featureArity.zipWithIndex.map { case (arity, idx) =>
- if (arity == 0) {
- splits(idx).map(_.asInstanceOf[ContinuousSplit].threshold)
- } else {
- Array.empty[Double]
- }
- }
- val useArrayType = (binnedFeaturesType == BinnedFeaturesDataType.array)
- if (useArrayType) {
- input.zipWithUniqueId.map { case(x, id) =>
- TreePointY.labeledPointToTreePointByArray(x, thresholds, featureArity, id)
- }
- } else {
- input.zipWithUniqueId.map { case(x, id) =>
- TreePointY.labeledPointToTreePointByFastHashMap(x, thresholds, featureArity, id)
- }
- }
- }
-
- /**
- * Convert one LabeledPoint into its TreePointY representation.
- * @param thresholds For each feature, split thresholds for continuous features,
- * empty for categorical features.
- * @param featureArity Array indexed by feature, with value 0 for continuous and numCategories
- * for categorical features.
- */
- private[spark] def labeledPointToTreePointByArray(
- labeledPoint: LabeledPoint,
- thresholds: Array[Array[Double]],
- featureArity: Array[Int],
- id: Long = 0): TreePointY = {
- val numFeatures = labeledPoint.features.size
- val arr = new Array[Char](numFeatures)
- var featureIndex = 0
- while (featureIndex < numFeatures) {
- arr(featureIndex) =
- findBin(featureIndex, labeledPoint, featureArity(featureIndex), thresholds(featureIndex))
- .toChar
- featureIndex += 1
- }
- new TreePointY(labeledPoint.label, new BinnedFeatureArray(arr), (id % Char.MaxValue).toChar)
- }
-
- /**
- * Convert one LabeledPoint into its TreePointY representation.
- * @param thresholds For each feature, split thresholds for continuous features,
- * empty for categorical features.
- * @param featureArity Array indexed by feature, with value 0 for continuous and numCategories
- * for categorical features.
- */
- private[spark] def labeledPointToTreePointByFastHashMap(
- labeledPoint: LabeledPoint,
- thresholds: Array[Array[Double]],
- featureArity: Array[Int],
- id: Long = 0): TreePointY = {
- val numFeatures = labeledPoint.features.size
- val binFeaturesMap = new Int2CharOpenHashMap()
- var featureIndex = 0
- while (featureIndex < numFeatures) {
- val binFeature =
- findBin(featureIndex, labeledPoint, featureArity(featureIndex), thresholds(featureIndex))
- .toChar
- if (binFeature != '\u0000') {
- binFeaturesMap.put(featureIndex, binFeature)
- }
- featureIndex += 1
- }
- val binFeatures = new BinnedFeatureFastHashMap(binFeaturesMap)
- new TreePointY(labeledPoint.label, binFeatures, (id % Char.MaxValue).toChar)
- }
-
- /**
- * Find discretized value for one (labeledPoint, feature).
- *
- * NOTE: We cannot use Bucketizer since it handles split thresholds differently than the old
- * (mllib) tree API. We want to maintain the same behavior as the old tree API.
- *
- * @param featureArity 0 for continuous features; number of categories for categorical features.
- */
- private def findBin(
- featureIndex: Int,
- labeledPoint: LabeledPoint,
- featureArity: Int,
- thresholds: Array[Double]): Int = {
- val featureValue = labeledPoint.features(featureIndex)
-
- if (featureArity == 0) {
- val idx = java.util.Arrays.binarySearch(thresholds, featureValue)
- if (idx >= 0) {
- idx
- } else {
- -idx - 1
- }
- } else {
- // Categorical feature bins are indexed by feature values.
- if (featureValue < 0 || featureValue >= featureArity) {
- throw new IllegalArgumentException(
- s"DecisionTree given invalid data:" +
- s" Feature $featureIndex is categorical with values in {0,...,${featureArity - 1}," +
- s" but a data point gives it value $featureValue.\n" +
- " Bad data point: " + labeledPoint.toString)
- }
- featureValue.toInt
- }
- }
-}
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.tree.{ContinuousSplit, Split}
+import org.apache.spark.rdd.RDD
+
+
+/**
+ * Internal representation of LabeledPoint for DecisionTree.
+ * This bins feature values based on a subsampled of data as follows:
+ * (a) Continuous features are binned into ranges.
+ * (b) Unordered categorical features are binned based on subsets of feature values.
+ * "Unordered categorical features" are categorical features with low arity used in
+ * multiclass classification.
+ * (c) Ordered categorical features are binned based on feature values.
+ * "Ordered categorical features" are categorical features with high arity,
+ * or any categorical feature used in regression or binary classification.
+ *
+ * @param label Label from LabeledPoint
+ * @param binnedFeatures Binned feature values.
+ * Same length as LabeledPoint.features, but values are bin indices.
+ * @param weight Sample weight for this TreePoint.
+ */
+private[spark] class TreePointY(
+ val label: Double,
+ val binnedFeatures: Array[Int],
+ val weight: Double,
+ val uniqueID: Char = '\u0000') extends Serializable
+
+private[spark] object TreePointY {
+
+ /**
+ * Convert an input dataset into its TreePoint representation,
+ * binning feature values in preparation for DecisionTree training.
+ * @param input Input dataset.
+ * @param splits Splits for features, of size (numFeatures, numSplits).
+ * @param metadata Learning and dataset metadata
+ * @return TreePoint dataset representation
+ */
+ def convertToTreeRDD(
+ input: RDD[Instance],
+ splits: Array[Array[Split]],
+ metadata: DecisionTreeMetadata): RDD[TreePointY] = {
+ // Construct arrays for featureArity for efficiency in the inner loop.
+ val featureArity: Array[Int] = new Array[Int](metadata.numFeatures)
+ var featureIndex = 0
+ while (featureIndex < metadata.numFeatures) {
+ featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0)
+ featureIndex += 1
+ }
+ val thresholds: Array[Array[Double]] = featureArity.zipWithIndex.map { case (arity, idx) =>
+ if (arity == 0) {
+ splits(idx).map(_.asInstanceOf[ContinuousSplit].threshold)
+ } else {
+ Array.emptyDoubleArray
+ }
+ }
+ input.zipWithUniqueId.map { case(x, id) =>
+ TreePointY.labeledPointToTreePoint(x, thresholds, featureArity, id)
+ }
+ }
+
+ /**
+ * Convert one LabeledPoint into its TreePoint representation.
+ * @param thresholds For each feature, split thresholds for continuous features,
+ * empty for categorical features.
+ * @param featureArity Array indexed by feature, with value 0 for continuous and numCategories
+ * for categorical features.
+ */
+ private def labeledPointToTreePoint(
+ instance: Instance,
+ thresholds: Array[Array[Double]],
+ featureArity: Array[Int],
+ id: Long = 0): TreePointY = {
+ val numFeatures = instance.features.size
+ val arr = new Array[Int](numFeatures)
+ var featureIndex = 0
+ while (featureIndex < numFeatures) {
+ arr(featureIndex) =
+ findBin(featureIndex, instance, featureArity(featureIndex), thresholds(featureIndex))
+ featureIndex += 1
+ }
+ new TreePointY(instance.label, arr, instance.weight, (id % Char.MaxValue).toChar)
+ }
+
+ /**
+ * Find discretized value for one (labeledPoint, feature).
+ *
+ * NOTE: We cannot use Bucketizer since it handles split thresholds differently than the old
+ * (mllib) tree API. We want to maintain the same behavior as the old tree API.
+ *
+ * @param featureArity 0 for continuous features; number of categories for categorical features.
+ */
+ private def findBin(
+ featureIndex: Int,
+ instance: Instance,
+ featureArity: Int,
+ thresholds: Array[Double]): Int = {
+ val featureValue = instance.features(featureIndex)
+
+ if (featureArity == 0) {
+ val idx = java.util.Arrays.binarySearch(thresholds, featureValue)
+ if (idx >= 0) {
+ idx
+ } else {
+ -idx - 1
+ }
+ } else {
+ // Categorical feature bins are indexed by feature values.
+ if (featureValue < 0 || featureValue >= featureArity) {
+ throw new IllegalArgumentException(
+ s"DecisionTree given invalid data:" +
+ s" Feature $featureIndex is categorical with values in {0,...,${featureArity - 1}," +
+ s" but a data point gives it value $featureValue.\n" +
+ s" Bad data point: $instance")
+ }
+ featureValue.toInt
+ }
+ }
+}
diff --git a/ml-core/src/main/scala/org/apache/spark/mllib/feature/VocabWord.scala b/ml-core/src/main/scala/org/apache/spark/mllib/feature/VocabWord.scala
new file mode 100644
index 0000000..4d73e95
--- /dev/null
+++ b/ml-core/src/main/scala/org/apache/spark/mllib/feature/VocabWord.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.mllib.feature
+
+/**
+ * Entry in vocabulary
+ */
+private case class VocabWord(
+ var word: String,
+ var cn: Int,
+ var point: Array[Int],
+ var code: Array[Int],
+ var codeLen: Int
+)
diff --git a/ml-kernel-client-core/pom.xml b/ml-kernel-client-core/pom.xml
index e0e7f74..a173de6 100644
--- a/ml-kernel-client-core/pom.xml
+++ b/ml-kernel-client-core/pom.xml
@@ -2,12 +2,12 @@
org.apache.spark
boostkit-ml
- 2.1.0
+ 2.2.0
4.0.0
boostkit-ml-kernel-client-core_2.12
- 2.1.0
+ 2.2.0
${project.artifactId}
Spark ml core
diff --git a/ml-kernel-client/pom.xml b/ml-kernel-client/pom.xml
index 7434007..a56ccec 100644
--- a/ml-kernel-client/pom.xml
+++ b/ml-kernel-client/pom.xml
@@ -2,12 +2,12 @@
org.apache.spark
boostkit-ml
- 2.1.0
+ 2.2.0
4.0.0
boostkit-ml-kernel-client_2.12
- 2.1.0
+ 2.2.0
${project.artifactId}
Spark ml core
diff --git a/ml-kernel-client/src/main/scala/org/apache/spark/ml/feature/DecisionTreeBucketizer.scala b/ml-kernel-client/src/main/scala/org/apache/spark/ml/feature/DecisionTreeBucketizer.scala
new file mode 100644
index 0000000..c152970
--- /dev/null
+++ b/ml-kernel-client/src/main/scala/org/apache/spark/ml/feature/DecisionTreeBucketizer.scala
@@ -0,0 +1,261 @@
+// scalastyle:off header.matches
+/*
+* Copyright (C) 2022. Huawei Technologies Co., Ltd.
+* This program is distributed in the hope that it will be useful,
+* but WITHOUT ANY WARRANTY; without even the implied warranty of
+* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+* */
+/*
+ * This file to You under the Apache License, Version 2.0;
+ * you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.Since
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.classification.DecisionTreeClassificationModel
+import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, Params}
+import org.apache.spark.ml.param.shared.HasWeightCol
+import org.apache.spark.ml.tree._
+import org.apache.spark.ml.util._
+import org.apache.spark.mllib.tree.configuration.{Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.types.StructType
+
+
+private[ml] trait DecisionTreeBucketizerParams extends Params
+ with DecisionTreeClassifierParams with HasWeightCol {
+
+ /**
+ * Param for bucketedFeatures column name.
+ * @group param
+ */
+ final val bucketedFeaturesCol: Param[String] =
+ new Param[String](this, "bucketedFeaturesCol", "bucketedFeatures column name")
+
+ final val prune: BooleanParam =
+ new BooleanParam(this, "prune", "if true, the algorithm will prune decision trees")
+
+ setDefault(bucketedFeaturesCol, "bucketedFeatures")
+ setDefault(prune, true)
+
+ /** @group getParam */
+ final def getBucketedFeaturesCol: String = $(bucketedFeaturesCol)
+
+ /** @group getParam */
+ final def getPrune: Boolean = $(prune)
+}
+
+/**
+ * Decision tree bucketing algorithm for data discretization.
+ */
+@Since("1.4.0")
+class DecisionTreeBucketizer @Since("1.4.0") (
+ @Since("1.4.0") override val uid: String)
+ extends Estimator[DecisionTreeBucketModel]
+ with DecisionTreeBucketizerParams with DecisionTreeClassifierParams with DefaultParamsWritable {
+
+ @Since("1.4.0")
+ def this() = this(Identifiable.randomUID("dtb"))
+
+ def setLabelCol(value: String): this.type = null
+
+ def setFeaturesCol(value: String): this.type = null
+
+ def setBucketedFeaturesCol(value: String): this.type = null
+
+ def setPrune(value: Boolean): this.type = null
+
+ // Override parameter setters from parent trait for Java API compatibility.
+ /** @group setParam */
+ @Since("1.4.0")
+ def setMaxDepth(value: Int): this.type = null
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setMaxBins(value: Int): this.type = null
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setMinInstancesPerNode(value: Int): this.type = null
+
+ /** @group setParam */
+ @Since("3.0.0")
+ def setMinWeightFractionPerNode(value: Double): this.type = null
+
+ @Since("1.4.0")
+ def setMinInfoGain(value: Double): this.type = null
+
+ /** @group expertSetParam */
+ @Since("1.4.0")
+ def setMaxMemoryInMB(value: Int): this.type = null
+
+ /** @group expertSetParam */
+ @Since("1.4.0")
+ def setCacheNodeIds(value: Boolean): this.type = null
+
+ /**
+ * Specifies how often to checkpoint the cached node IDs.
+ * E.g. 10 means that the cache will get checkpointed every 10 iterations.
+ * This is only used if cacheNodeIds is true and if the checkpoint directory is set in
+ * [[org.apache.spark.SparkContext]].
+ * Must be at least 1.
+ * (default = 10)
+ * @group setParam
+ */
+ @Since("1.4.0")
+ def setCheckpointInterval(value: Int): this.type = null
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setImpurity(value: String): this.type = null
+
+ /** @group setParam */
+ @Since("1.6.0")
+ def setSeed(value: Long): this.type = null
+
+ /** @group setParam */
+ @Since("3.0.0")
+ def setWeightCol(value: String): this.type = set(weightCol, value)
+
+ /**
+ * Get the number of classes. This looks in column metadata first, and if that is missing,
+ * then this assumes classes are indexed 0,1,...,numClasses-1 and computes numClasses
+ * by finding the maximum label value.
+ *
+ * Label validation (ensuring all labels are integers >= 0) needs to be handled elsewhere,
+ * such as in `extractLabeledPoints()`.
+ *
+ * @param dataset Dataset which contains a column [[labelCol]]
+ * @param maxNumClasses Maximum number of classes allowed when inferred from data. If numClasses
+ * is specified in the metadata, then maxNumClasses is ignored.
+ * @return number of classes
+ * @throws IllegalArgumentException if metadata does not specify numClasses, and the
+ * actual numClasses exceeds maxNumClasses
+ */
+ private[ml] def getNumClasses(dataset: Dataset[_], maxNumClasses: Int = 100): Int = 0
+
+ override def transformSchema(schema: StructType): StructType = null
+
+ override def fit(dataset: Dataset[_]): DecisionTreeBucketModel = null
+
+ private[ml] def train(dataset: Dataset[_]): DecisionTreeBucketModel = null
+
+ /** (private[ml]) Train decision trees on an RDD */
+ private[ml] def train(data: RDD[LabeledPoint],
+ oldStrategy: OldStrategy): DecisionTreeBucketModel = null
+
+ /** (private[ml]) Create a Strategy instance to use with the old API. */
+ private[ml] def getOldStrategy(
+ categoricalFeatures: Map[Int, Int],
+ numClasses: Int): OldStrategy = null
+
+ @Since("1.4.1")
+ override def copy(extra: ParamMap): DecisionTreeBucketizer = null
+}
+
+@Since("1.4.0")
+object DecisionTreeBucketizer extends DefaultParamsReadable[DecisionTreeBucketizer] {
+ /** Accessor for supported impurities: entropy, gini */
+ @Since("1.4.0")
+ final val supportedImpurities: Array[String] = null
+
+ @Since("2.0.0")
+ override def load(path: String): DecisionTreeBucketizer = null
+}
+
+/**
+ * Decision tree bucket model for data discretization.
+ * @param _trees Decision trees of all features.
+ */
+@Since("1.4.0")
+class DecisionTreeBucketModel private[ml] (
+ @Since("1.5.0") override val uid: String,
+ private val _trees: Array[DecisionTreeClassificationModel],
+ val numFeatures: Int,
+ val numClasses: Int)
+ extends Model[DecisionTreeBucketModel]
+ with DecisionTreeBucketizerParams with DecisionTreeClassifierParams
+ with TreeEnsembleModel[DecisionTreeClassificationModel]
+ with MLWritable with Serializable {
+
+ require(_trees.nonEmpty, "DecisionTreeBucketModel requires at least 1 tree.")
+
+ /**
+ * Construct a decision tree bucket model, with all trees weighted equally.
+ *
+ * @param trees Component trees
+ */
+ private[ml] def this(
+ trees: Array[DecisionTreeClassificationModel],
+ numFeatures: Int,
+ numClasses: Int) =
+ this(Identifiable.randomUID("dtb"), trees, numFeatures, numClasses)
+
+ def getNumTrees: Int = 0
+
+ @Since("1.4.0")
+ override def trees: Array[DecisionTreeClassificationModel] = null
+
+ // Note: We may add support for weights (based on tree performance) later on.
+ private lazy val _treeWeights: Array[Double] = null
+
+ @Since("1.4.0")
+ override def treeWeights: Array[Double] = null
+
+ override def transformSchema(schema: StructType): StructType = null
+
+ override def transform(dataset: Dataset[_]): DataFrame = null
+
+ @Since("1.4.0")
+ override def copy(extra: ParamMap): DecisionTreeBucketModel = null
+
+ @Since("1.4.0")
+ override def toString: String = null
+
+ /** (private[ml]) Convert to a model in the old API */
+ def toOld: OldRandomForestModel = null
+
+ @Since("2.0.0")
+ override def write: MLWriter = null
+}
+
+@Since("2.0.0")
+object DecisionTreeBucketModel extends MLReadable[DecisionTreeBucketModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[DecisionTreeBucketModel] = null
+
+ @Since("2.0.0")
+ override def load(path: String): DecisionTreeBucketModel = null
+
+ private[DecisionTreeBucketModel]
+ class DecisionTreeBucketModelWriter(instance: DecisionTreeBucketModel)
+ extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {}
+ }
+
+ private class DecisionTreeBucketModelReader
+ extends MLReader[DecisionTreeBucketModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[DecisionTreeBucketModel].getName
+ private val treeClassName = classOf[DecisionTreeClassificationModel].getName
+
+ override def load(path: String): DecisionTreeBucketModel = null
+ }
+
+ /** Convert a model from the old API */
+ private[ml] def fromOld(
+ oldModel: OldRandomForestModel,
+ parent: DecisionTreeBucketizer,
+ categoricalFeatures: Map[Int, Int],
+ numClasses: Int,
+ numFeatures: Int = -1): DecisionTreeBucketModel = null
+}
diff --git a/ml-kernel-client/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesUtil.scala b/ml-kernel-client/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesUtil.scala
new file mode 100644
index 0000000..1422e2c
--- /dev/null
+++ b/ml-kernel-client/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesUtil.scala
@@ -0,0 +1,76 @@
+// scalastyle:off
+/*
+* Copyright (C) 2022. Huawei Technologies Co., Ltd.
+* This program is distributed in the hope that it will be useful,
+* but WITHOUT ANY WARRANTY; without even the implied warranty of
+* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+* */
+/*
+ * This file to You under the Apache License, Version 2.0;
+ * you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import it.unimi.dsi.fastutil.doubles.DoubleArrayList
+import it.unimi.dsi.fastutil.ints.{Int2ObjectOpenHashMap, IntArrayList}
+import it.unimi.dsi.fastutil.objects.ObjectArrayList
+
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
+import org.apache.spark.ml.tree.{LearningNode, Split}
+import org.apache.spark.ml.tree.impl.GradientBoostedTreesCore.NodeIndexInfo
+import org.apache.spark.mllib.tree.configuration.{Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.model.ImpurityStats
+import org.apache.spark.rdd.RDD
+
+object GradientBoostedTreesUtil extends Logging {
+
+ def dataProcessX(
+ input: RDD[Instance],
+ splits: Array[Array[Split]],
+ treeStrategy: OldStrategy,
+ metadata: DecisionTreeMetadata,
+ timer: TimeTracker,
+ seed: Long): (RDD[TreePoint], RDD[(Int, (IntArrayList, ObjectArrayList[Split]))],
+ Broadcast[DoubleArrayList], Broadcast[DoubleArrayList],
+ Broadcast[Int2ObjectOpenHashMap[IntArrayList]]) = {
+ null
+ }
+
+ def nodeIdCacheXConstruction(
+ nodes: Array[LearningNode],
+ rawPartInfoBc: Broadcast[Int2ObjectOpenHashMap[IntArrayList]])
+ : Int2ObjectOpenHashMap[Int2ObjectOpenHashMap[IntArrayList]] = {
+ null
+ }
+
+ def chooseBestSplits(
+ input: RDD[(Int, (IntArrayList, ObjectArrayList[Split]))],
+ nodeIndexInfo: Map[Int, Map[Int, NodeIndexInfo]],
+ metadata: DecisionTreeMetadata,
+ nodeIdCacheBc: Broadcast[Int2ObjectOpenHashMap[Int2ObjectOpenHashMap[IntArrayList]]],
+ labelArrayBc: Broadcast[DoubleArrayList],
+ nodes: Array[LearningNode],
+ sampleWeightArrayBc: Broadcast[DoubleArrayList],
+ useWeight: (Boolean, Double))
+ : scala.collection.Map[Int, (Split, ImpurityStats)] = {
+ null
+ }
+
+ def updateNodeIdCache(
+ nodeIdCache: Int2ObjectOpenHashMap[Int2ObjectOpenHashMap[IntArrayList]],
+ nodeIdCacheBc: Broadcast[Int2ObjectOpenHashMap[Int2ObjectOpenHashMap[IntArrayList]]],
+ input: RDD[TreePoint],
+ nodesForGroup: Map[Int, Array[LearningNode]],
+ nodeIndexInfo: Map[Int, Map[Int, NodeIndexInfo]],
+ splits: Array[Array[Split]],
+ rawPartInfoBc: Broadcast[Int2ObjectOpenHashMap[IntArrayList]],
+ metadata: DecisionTreeMetadata,
+ timer: TimeTracker): Unit = {
+ }
+
+}
diff --git a/ml-kernel-client/src/main/scala/org/apache/spark/ml/tree/impl/RFUtils.scala b/ml-kernel-client/src/main/scala/org/apache/spark/ml/tree/impl/RFUtils.scala
index ed245f9..d9f186c 100644
--- a/ml-kernel-client/src/main/scala/org/apache/spark/ml/tree/impl/RFUtils.scala
+++ b/ml-kernel-client/src/main/scala/org/apache/spark/ml/tree/impl/RFUtils.scala
@@ -1,6 +1,6 @@
// scalastyle:off header.matches
/*
-* Copyright (C) 2021. Huawei Technologies Co., Ltd.
+* Copyright (C) 2022. Huawei Technologies Co., Ltd.
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
@@ -61,6 +61,11 @@ object RFUtils extends Logging {
true
}
+ def isValidSample(baggedPoint: BaggedPoint[TreePointX],
+ groupInfo: GroupInfo, treeIndex: Int, id: Short): Boolean = {
+ true
+ }
+
def isValidNodeInfo(nodeInfo: NodeIndexInfo, agg: Array[DTStatsAggregator]): Boolean = {
true
}
diff --git a/ml-kernel-client/src/main/scala/org/apache/spark/mllib/feature/Word2VecSGHS.scala b/ml-kernel-client/src/main/scala/org/apache/spark/mllib/feature/Word2VecSGHS.scala
new file mode 100644
index 0000000..e69cc05
--- /dev/null
+++ b/ml-kernel-client/src/main/scala/org/apache/spark/mllib/feature/Word2VecSGHS.scala
@@ -0,0 +1,36 @@
+// scalastyle:off
+/*
+* Copyright (C) 2022. Huawei Technologies Co., Ltd.
+* This program is distributed in the hope that it will be useful,
+* but WITHOUT ANY WARRANTY; without even the implied warranty of
+* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+* */
+package org.apache.spark.mllib.feature
+
+import scala.collection.mutable
+
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.RDD
+
+class Word2VecSGHS(
+ val minCount: Int,
+ val window: Int,
+ val vectorSize: Int,
+ val vocabSize: Int,
+ val trainWordsCount: Long,
+ val learningRate: Double,
+ val numIterations: Int,
+ val seed: Long,
+ val maxSentenceLength: Int,
+ val regularization: Float,
+ val repetition: Int) extends Serializable with Logging {
+
+ def fit[S <: Iterable[String]](
+ dataset: RDD[S],
+ bcExpTable: Broadcast[Array[Float]],
+ bcVocab: Broadcast[Array[VocabWord]],
+ bcVocabHash: Broadcast[mutable.HashMap[String, Int]]): Array[Float] = {
+ null
+ }
+}
diff --git a/pom.xml b/pom.xml
index 79ee0f4..1acb85e 100644
--- a/pom.xml
+++ b/pom.xml
@@ -2,7 +2,7 @@
4.0.0
org.apache.spark
boostkit-ml
- 2.1.0
+ 2.2.0
${project.artifactId}
Spark ml algo
2020
--
Gitee