diff --git a/README.md b/README.md index 1f62365522f158e23b31cf9a2335997ef9597490..01491cc384d3e8cd0261f09df282e31c60a52fa2 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ Introduction The machine learning algorithm library running on Kunpeng processors is an acceleration library that provides a rich set of high-level tools for machine learning algorithms. It is based on the original APIs of Apache [Spark 2.4.6](https://github.com/apache/spark/tree/v2.4.6), [breeze 0.13.1](https://github.com/scalanlp/breeze/tree/releases/v0.13.1) and [xgboost 1.1.0](https://github.com/dmlc/xgboost/tree/release_1.0.0). The acceleration library for greatly improves the computing power in big data scenarios. -The library provides 21 machine learning algorithms: support vector machine (SVM), random forest classifier (RFC), gradient boosting decision tree (GBDT), decision tree (DT), K-means clustering, linear regression, logistic regression algorithm, principal component analysis (PCA), principal component analysis for Sparse Matrix(SPCA), singular value decomposition (SVD), latent dirichlet allocation (LDA), prefix-projected pattern prowth (Prefix-Span), alternating least squares (ALS), K-nearest neighbors (KNN), Covariance, Density-based spatial clustering of applicaitons with noise (DBSCAN), Pearson, Spearman, XGboost, Inverse Document Frequency(IDF), and SimRank. You can find the latest documentation on the project web page. This README file contains only basic setup instructions. 
+The library provides 23 machine learning algorithms: support vector machine (SVM), random forest classifier (RFC), gradient boosting decision tree (GBDT), decision tree (DT), K-means clustering, linear regression, logistic regression algorithm, principal component analysis (PCA), principal component analysis for Sparse Matrix(SPCA), singular value decomposition (SVD), latent dirichlet allocation (LDA), prefix-projected pattern growth (Prefix-Span), alternating least squares (ALS), K-nearest neighbors (KNN), Covariance, Density-based spatial clustering of applications with noise (DBSCAN), Pearson, Spearman, XGboost, Inverse Document Frequency(IDF), SimRank, Decision Tree Bucket(DTB) and Word2Vec. You can find the latest documentation on the project web page. This README file contains only basic setup instructions. You can find the latest documentation, including a programming guide, on the project web page. This README file only contains basic setup instructions. @@ -25,13 +25,13 @@ Building And Packageing mvn clean package -(3) Obtain "boostkit-ml-core_2.11-2.1.0-spark2.4.6.jar" under the "Spark-ml-algo-lib/ml-core/target" directory. +(3) Obtain "boostkit-ml-core_2.11-2.2.0-spark2.4.6.jar" under the "Spark-ml-algo-lib/ml-core/target" directory. - Obtain "boostkit-ml-acc_2.11-2.1.0-spark2.4.6.jar" under the "Spark-ml-algo-lib/ml-accelerator/target" directory. + Obtain "boostkit-ml-acc_2.11-2.2.0-spark2.4.6.jar" under the "Spark-ml-algo-lib/ml-accelerator/target" directory. - Obtain "boostkit-xgboost4j_2.11-2.1.0.jar" under the "Spark-ml-algo-lib/ml-xgboost/jvm-packages/boostkit-xgboost4j/target" directory. + Obtain "boostkit-xgboost4j_2.11-2.2.0.jar" under the "Spark-ml-algo-lib/ml-xgboost/jvm-packages/boostkit-xgboost4j/target" directory. - Obtain "boostkit-xgboost4j-spark2.4.6_2.11-2.1.0.jar" under the "Spark-ml-algo-lib/ml-xgboost/jvm-packages/boostkit-xgboost4j-spark/target" directory. 
+ Obtain "boostkit-xgboost4j-spark2.4.6_2.11-2.2.0.jar" under the "Spark-ml-algo-lib/ml-xgboost/jvm-packages/boostkit-xgboost4j-spark/target" directory. Contribution Guidelines diff --git a/ml-accelerator/pom.xml b/ml-accelerator/pom.xml index 5c9f082df282f805299eca861af0c99b2a197dbc..9b4611974be3ee18b5473154bcba32aab2c2d363 100644 --- a/ml-accelerator/pom.xml +++ b/ml-accelerator/pom.xml @@ -2,12 +2,12 @@ org.apache.spark boostkit-ml - 2.1.0 + 2.2.0 4.0.0 boostkit-ml-acc_2.11 - 2.1.0 + 2.2.0 ${project.artifactId} Spark ml algo accelerator diff --git a/ml-accelerator/src/main/scala/org/apache/spark/ml/feature/DecisionTreeBucketizer.scala b/ml-accelerator/src/main/scala/org/apache/spark/ml/feature/DecisionTreeBucketizer.scala new file mode 100644 index 0000000000000000000000000000000000000000..b5af3aa6412af027329580c03ba673cdbd78d6c1 --- /dev/null +++ b/ml-accelerator/src/main/scala/org/apache/spark/ml/feature/DecisionTreeBucketizer.scala @@ -0,0 +1,441 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature + +import org.json4s.{DefaultFormats, JObject} +import org.json4s.JsonDSL._ + +import org.apache.spark.SparkException +import org.apache.spark.annotation.Since +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.classification.DecisionTreeClassificationModel +import org.apache.spark.ml.linalg.{Vector, VectorUDT} +import org.apache.spark.ml.param.{Param, ParamMap, Params} +import org.apache.spark.ml.param.shared.HasWeightCol +import org.apache.spark.ml.tree._ +import org.apache.spark.ml.tree.impl.DecisionTreeBucket +import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.DefaultParamsReader.Metadata +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} +import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{DoubleType, StructType} + + +private[ml] trait DecisionTreeBucketizerParams extends Params + with DecisionTreeClassifierParams with HasWeightCol { + + /** + * Param for bucketedFeatures column name. + * @group param + */ + final val bucketedFeaturesCol: Param[String] = + new Param[String](this, "bucketedFeaturesCol", "bucketedFeatures column name") + + setDefault(bucketedFeaturesCol, "bucketedFeatures") + + /** @group getParam */ + final def getBucketedFeaturesCol: String = $(bucketedFeaturesCol) + + /** + * Validates and transforms the input schema with the provided param map. 
+ * copy from [[org.apache.spark.ml.PredictorParams]].validateAndTransformSchema + * + * @param schema input schema + * @param fitting whether this is in fitting + * @return output schema + */ + protected def validateAndTransformSchema( + schema: StructType, + fitting: Boolean): StructType = { + SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + if (fitting) { + SchemaUtils.checkNumericType(schema, $(labelCol)) + + this match { + case p: HasWeightCol => + if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) { + SchemaUtils.checkNumericType(schema, $(p.weightCol)) + } + case _ => + } + } + SchemaUtils.appendColumn(schema, $(bucketedFeaturesCol), new VectorUDT) + } + +} + +/** + * Decision tree bucketing algorithm for data discretization. + */ +@Since("1.4.0") +class DecisionTreeBucketizer @Since("1.4.0") ( + @Since("1.4.0") override val uid: String) + extends Estimator[DecisionTreeBucketModel] + with DecisionTreeBucketizerParams with DecisionTreeClassifierParams with DefaultParamsWritable { + + @Since("1.4.0") + def this() = this(Identifiable.randomUID("dtb")) + + def setLabelCol(value: String): this.type = set(labelCol, value) + + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + def setBucketedFeaturesCol(value: String): this.type = set(bucketedFeaturesCol, value) + + // Override parameter setters from parent trait for Java API compatibility. 
+ /** @group setParam */ + @Since("1.4.0") + override def setMaxDepth(value: Int): this.type = set(maxDepth, value) + + /** @group setParam */ + @Since("1.4.0") + override def setMaxBins(value: Int): this.type = set(maxBins, value) + + /** @group setParam */ + @Since("1.4.0") + override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + + /** @group setParam */ + @Since("1.4.0") + override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + + /** @group expertSetParam */ + @Since("1.4.0") + override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + + /** @group expertSetParam */ + @Since("1.4.0") + override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + + /** + * Specifies how often to checkpoint the cached node IDs. + * E.g. 10 means that the cache will get checkpointed every 10 iterations. + * This is only used if cacheNodeIds is true and if the checkpoint directory is set in + * [[org.apache.spark.SparkContext]]. + * Must be at least 1. + * (default = 10) + * @group setParam + */ + @Since("1.4.0") + override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + + /** @group setParam */ + @Since("1.4.0") + override def setImpurity(value: String): this.type = set(impurity, value) + + /** @group setParam */ + @Since("1.6.0") + override def setSeed(value: Long): this.type = set(seed, value) + + /** @group setParam */ + @Since("3.0.0") + def setWeightCol(value: String): this.type = set(weightCol, value) + + /** + * Extract [[labelCol]] and [[featuresCol]] from the given dataset, + * and put it in an RDD with strong types. + * + * @param dataset DataFrame with columns for labels ([[org.apache.spark.sql.types.NumericType]]) + * and features (`Vector`). + * @param numClasses Number of classes label can take. Labels must be integers in the range + * [0, numClasses). 
+ * @note Throws `SparkException` if any label is a non-integer or is negative + */ + private[ml] def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = { + require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" + + s" $numClasses, but requires numClasses > 0.") + dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { + case Row(label: Double, features: Vector) => + require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" + + s" dataset with invalid label $label. Labels must be integers in range" + + s" [0, $numClasses).") + LabeledPoint(label, features) + } + } + + /** + * Get the number of classes. This looks in column metadata first, and if that is missing, + * then this assumes classes are indexed 0,1,...,numClasses-1 and computes numClasses + * by finding the maximum label value. + * + * Label validation (ensuring all labels are integers >= 0) needs to be handled elsewhere, + * such as in `extractLabeledPoints()`. + * + * @param dataset Dataset which contains a column [[labelCol]] + * @param maxNumClasses Maximum number of classes allowed when inferred from data. If numClasses + * is specified in the metadata, then maxNumClasses is ignored. + * @return number of classes + * @throws IllegalArgumentException if metadata does not specify numClasses, and the + * actual numClasses exceeds maxNumClasses + */ + private[ml] def getNumClasses(dataset: Dataset[_], maxNumClasses: Int = 100): Int = { + MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match { + case Some(n: Int) => n + case None => + // Get number of classes from dataset itself. 
+ val maxLabelRow: Array[Row] = dataset.select(max($(labelCol))).take(1) + if (maxLabelRow.isEmpty || maxLabelRow.head.isNullAt(0)) { + throw new SparkException("ML algorithm was given empty dataset.") + } + val maxDoubleLabel: Double = maxLabelRow.head.getDouble(0) + require((maxDoubleLabel + 1).isValidInt, s"Classifier found max label value =" + + s" $maxDoubleLabel but requires integers in range [0, ... ${Int.MaxValue})") + val numClasses = maxDoubleLabel.toInt + 1 + require(numClasses <= maxNumClasses, s"Classifier inferred $numClasses from label values" + + s" in column ${$(labelCol)}, but this exceeded the max numClasses ($maxNumClasses) allowed" + + s" to be inferred from values. To avoid this error for labels with > $maxNumClasses" + + s" classes, specify numClasses explicitly in the metadata; this can be done by applying" + + s" StringIndexer to the label column.") + logInfo(this.getClass.getCanonicalName + s" inferred $numClasses classes for" + + s" labelCol=${$(labelCol)} since numClasses was not specified in the column metadata.") + numClasses + } + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema, true) + } + + override def fit(dataset: Dataset[_]): DecisionTreeBucketModel = { + // This handles a few items such as schema validation. + // Developers only need to implement train(). + transformSchema(dataset.schema, logging = true) + + // Cast LabelCol to DoubleType and keep the metadata. + val labelMeta = dataset.schema($(labelCol)).metadata + val labelCasted = dataset.withColumn($(labelCol), col($(labelCol)).cast(DoubleType), labelMeta) + + // Cast WeightCol to DoubleType and keep the metadata. 
+ val casted = this match { + case p: HasWeightCol => + if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) { + val weightMeta = dataset.schema($(p.weightCol)).metadata + labelCasted.withColumn($(p.weightCol), col($(p.weightCol)).cast(DoubleType), weightMeta) + } else { + labelCasted + } + case _ => labelCasted + } + + copyValues(train(casted).setParent(this)) + } + + private[ml] def train(dataset: Dataset[_]): DecisionTreeBucketModel = { + val categoricalFeatures: Map[Int, Int] = + MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) + val numClasses: Int = getNumClasses(dataset) + + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses) + val strategy = getOldStrategy(categoricalFeatures, numClasses) + + val instr = Instrumentation.create(this, oldDataset) + instr.logParams(params: _*) + + val trees = DecisionTreeBucket.run(oldDataset, strategy, getSeed, Some(instr)) + .map(_.asInstanceOf[DecisionTreeClassificationModel]) + + val numFeatures = oldDataset.first().features.size + val m = new DecisionTreeBucketModel(uid, trees, numFeatures, numClasses) + instr.logSuccess(m) + m + } + + /** (private[ml]) Train decision trees on an RDD */ + private[ml] def train(data: RDD[LabeledPoint], + oldStrategy: OldStrategy): DecisionTreeBucketModel = { + val instr = Instrumentation.create(this, data) + instr.logParams(params: _*) + + val trees = DecisionTreeBucket.run(data, oldStrategy, getSeed, Some(instr)) + .map(_.asInstanceOf[DecisionTreeClassificationModel]) + + val numFeatures = data.first().features.size + val m = new DecisionTreeBucketModel(uid, trees, numFeatures, oldStrategy.numClasses) + instr.logSuccess(m) + m + } + + /** (private[ml]) Create a Strategy instance to use with the old API. 
*/ + private[ml] def getOldStrategy( + categoricalFeatures: Map[Int, Int], + numClasses: Int): OldStrategy = { + super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity, + subsamplingRate = 1.0) + } + + @Since("1.4.1") + override def copy(extra: ParamMap): DecisionTreeBucketizer = defaultCopy(extra) +} + +@Since("1.4.0") +object DecisionTreeBucketizer extends DefaultParamsReadable[DecisionTreeBucketizer] { + /** Accessor for supported impurities: entropy, gini */ + @Since("1.4.0") + final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities + + @Since("2.0.0") + override def load(path: String): DecisionTreeBucketizer = super.load(path) +} + +/** + * Decision tree bucket model for data discretization. + * @param _trees Decision trees of all features. + */ +@Since("1.4.0") +class DecisionTreeBucketModel private[ml] ( + @Since("1.5.0") override val uid: String, + private val _trees: Array[DecisionTreeClassificationModel], + val numFeatures: Int, + val numClasses: Int) + extends Model[DecisionTreeBucketModel] + with DecisionTreeBucketizerParams with DecisionTreeClassifierParams + with TreeEnsembleModel[DecisionTreeClassificationModel] + with MLWritable with Serializable { + + require(_trees.nonEmpty, "DecisionTreeBucketModel requires at least 1 tree.") + + /** + * Construct a decision tree bucket model, with all trees weighted equally. + * + * @param trees Component trees + */ + private[ml] def this( + trees: Array[DecisionTreeClassificationModel], + numFeatures: Int, + numClasses: Int) = + this(Identifiable.randomUID("dtb"), trees, numFeatures, numClasses) + + def getNumTrees: Int = _trees.length + + @Since("1.4.0") + override def trees: Array[DecisionTreeClassificationModel] = _trees + + // Note: We may add support for weights (based on tree performance) later on. 
+ private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0) + + @Since("1.4.0") + override def treeWeights: Array[Double] = _treeWeights + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema, false) + } + + override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) + + // val outputData = super.transform(dataset) + val leafUDF = udf { features: Vector => predictLeaf(features) } + dataset.withColumn($(bucketedFeaturesCol), leafUDF(col($(featuresCol)))) + } + + @Since("1.4.0") + override def copy(extra: ParamMap): DecisionTreeBucketModel = { + copyValues(new DecisionTreeBucketModel(uid, _trees, numFeatures, numClasses), extra) + .setParent(parent) + } + + @Since("1.4.0") + override def toString: String = { + s"DecisionTreeBucketModel (uid=$uid) with $getNumTrees trees" + } + + /** (private[ml]) Convert to a model in the old API */ + def toOld: OldRandomForestModel = { + new OldRandomForestModel(OldAlgo.Classification, _trees.map(_.toOld)) + } + + @Since("2.0.0") + override def write: MLWriter = + new DecisionTreeBucketModel.DecisionTreeBucketModelWriter(this) +} + +@Since("2.0.0") +object DecisionTreeBucketModel extends MLReadable[DecisionTreeBucketModel] { + + @Since("2.0.0") + override def read: MLReader[DecisionTreeBucketModel] = + new DecisionTreeBucketModelReader + + @Since("2.0.0") + override def load(path: String): DecisionTreeBucketModel = super.load(path) + + private[DecisionTreeBucketModel] + class DecisionTreeBucketModelWriter(instance: DecisionTreeBucketModel) + extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + // Note: numTrees is not currently used, but could be nice to store for fast querying. 
+ val extraMetadata: JObject = Map( + "numFeatures" -> instance.numFeatures, + "numClasses" -> instance.numClasses, + "numTrees" -> instance.getNumTrees) + EnsembleModelReadWrite.saveImpl(instance, path, sparkSession, extraMetadata) + } + } + + private class DecisionTreeBucketModelReader + extends MLReader[DecisionTreeBucketModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[DecisionTreeBucketModel].getName + private val treeClassName = classOf[DecisionTreeClassificationModel].getName + + override def load(path: String): DecisionTreeBucketModel = { + implicit val format = DefaultFormats + val (metadata: Metadata, treesData: Array[(Metadata, Node)], _) = + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) + val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] + val numClasses = (metadata.metadata \ "numClasses").extract[Int] + val numTrees = (metadata.metadata \ "numTrees").extract[Int] + + val trees: Array[DecisionTreeClassificationModel] = treesData.map { + case (treeMetadata, root) => + val tree = + new DecisionTreeClassificationModel(treeMetadata.uid, root, numFeatures, numClasses) + DefaultParamsReader.getAndSetParams(tree, treeMetadata) + tree + } + require(numTrees == trees.length, s"DecisionTreeBucketModel.load expected $numTrees" + + s" trees based on metadata but found ${trees.length} trees.") + + val model = new DecisionTreeBucketModel(metadata.uid, trees, numFeatures, numClasses) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + /** Convert a model from the old API */ + private[ml] def fromOld( + oldModel: OldRandomForestModel, + parent: DecisionTreeBucketizer, + categoricalFeatures: Map[Int, Int], + numClasses: Int, + numFeatures: Int = -1): DecisionTreeBucketModel = { + require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" + + s" with algo=${oldModel.algo} (old API) to DecisionTreeBucketModel (new API).") + val 
newTrees = oldModel.trees.map { tree => + // parent for each tree is null since there is no good way to set this. + DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures) + } + val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtb") + new DecisionTreeBucketModel(uid, newTrees, numFeatures, numClasses) + } +} diff --git a/ml-accelerator/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/ml-accelerator/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala new file mode 100644 index 0000000000000000000000000000000000000000..ed31c30df90c691718e1aca1c0b8e0696b08efa0 --- /dev/null +++ b/ml-accelerator/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -0,0 +1,454 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature + +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.Since +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors, VectorUDT} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.feature +import org.apache.spark.mllib.linalg.VectorImplicits._ +import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ +import org.apache.spark.util.{Utils, VersionUtils} + +/** + * Params for [[Word2Vec]] and [[Word2VecModel]]. + */ +private[feature] trait Word2VecBase extends Params + with HasInputCol with HasOutputCol with HasMaxIter with HasStepSize with HasSeed { + + /** + * The dimension of the code that you want to transform from words. + * Default: 100 + * @group param + */ + final val vectorSize = new IntParam( + this, "vectorSize", "the dimension of codes after transforming from words (> 0)", + ParamValidators.gt(0)) + setDefault(vectorSize -> 100) + + /** @group getParam */ + def getVectorSize: Int = $(vectorSize) + + /** + * The window size (context words from [-window, window]). + * Default: 5 + * @group expertParam + */ + final val windowSize = new IntParam( + this, "windowSize", "the window size (context words from [-window, window]) (> 0)", + ParamValidators.gt(0)) + setDefault(windowSize -> 5) + + /** @group expertGetParam */ + def getWindowSize: Int = $(windowSize) + + /** + * Number of partitions for sentences of words. 
+ * Default: 1 + * @group param + */ + final val numPartitions = new IntParam( + this, "numPartitions", "number of partitions for sentences of words (> 0)", + ParamValidators.gt(0)) + setDefault(numPartitions -> 1) + + /** @group getParam */ + def getNumPartitions: Int = $(numPartitions) + + /** + * The minimum number of times a token must appear to be included in the word2vec model's + * vocabulary. + * Default: 5 + * @group param + */ + final val minCount = new IntParam(this, "minCount", "the minimum number of times a token must " + + "appear to be included in the word2vec model's vocabulary (>= 0)", ParamValidators.gtEq(0)) + setDefault(minCount -> 5) + + /** @group getParam */ + def getMinCount: Int = $(minCount) + + /** + * Sets the maximum length (in words) of each sentence in the input data. + * Any sentence longer than this threshold will be divided into chunks of + * up to `maxSentenceLength` size. + * Default: 1000 + * @group param + */ + final val maxSentenceLength = new IntParam(this, "maxSentenceLength", "Maximum length " + + "(in words) of each sentence in the input data. Any sentence longer than this threshold will " + + "be divided into chunks up to the size (> 0)", ParamValidators.gt(0)) + setDefault(maxSentenceLength -> 1000) + + /** @group getParam */ + def getMaxSentenceLength: Int = $(maxSentenceLength) + + /** + * Sets the regularization coefficient. + * Default: 0.05f + * @group param + */ + final val regularization = new FloatParam(this, "regularization", "Regularization coefficient") + + /** @group getParam */ + def getRegularization: Float = $(regularization) + setDefault(regularization -> 0.05f) + + /** + * Sets the number of repetitions of data. 
+ * Default: 2 + * @group param + */ + final val repetition = new IntParam(this, "repetition", "The number of repetitions of data") + + /** @group getParam */ + def getRepetition: Int = $(repetition) + setDefault(repetition -> 2) + + + setDefault(stepSize -> 0.025) + setDefault(maxIter -> 1) + + /** + * Validate and transform the input schema. + */ + protected def validateAndTransformSchema(schema: StructType): StructType = { + val typeCandidates = List(new ArrayType(StringType, true), new ArrayType(StringType, false)) + SchemaUtils.checkColumnTypes(schema, $(inputCol), typeCandidates) + SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) + } +} + +/** + * Word2Vec trains a model of `Map(String, Vector)`, i.e. transforms a word into a code for further + * natural language processing or machine learning process. + */ +@Since("1.4.0") +final class Word2Vec @Since("1.4.0") ( + @Since("1.4.0") override val uid: String) + extends Estimator[Word2VecModel] with Word2VecBase with DefaultParamsWritable { + + @Since("1.4.0") + def this() = this(Identifiable.randomUID("w2v")) + + /** @group setParam */ + @Since("1.4.0") + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + @Since("1.4.0") + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + @Since("1.4.0") + def setVectorSize(value: Int): this.type = set(vectorSize, value) + + /** @group expertSetParam */ + @Since("1.6.0") + def setWindowSize(value: Int): this.type = set(windowSize, value) + + /** @group setParam */ + @Since("1.4.0") + def setStepSize(value: Double): this.type = set(stepSize, value) + + /** @group setParam */ + @Since("1.4.0") + def setNumPartitions(value: Int): this.type = set(numPartitions, value) + + /** @group setParam */ + @Since("1.4.0") + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** @group setParam */ + @Since("1.4.0") + def setSeed(value: Long): this.type = set(seed, value) + + /** 
@group setParam */ + @Since("1.4.0") + def setMinCount(value: Int): this.type = set(minCount, value) + + /** @group setParam */ + @Since("2.0.0") + def setMaxSentenceLength(value: Int): this.type = set(maxSentenceLength, value) + + /** @group setParam */ + def setRegularization(value: Float): this.type = set(regularization, value) + + /** @group setParam */ + def setRepetition(value: Int): this.type = set(repetition, value) + + @Since("2.0.0") + override def fit(dataset: Dataset[_]): Word2VecModel = { + transformSchema(dataset.schema, logging = true) + val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0)) + val wordVectors = new feature.Word2Vec() + .setLearningRate($(stepSize)) + .setMinCount($(minCount)) + .setNumIterations($(maxIter)) + .setNumPartitions($(numPartitions)) + .setSeed($(seed)) + .setVectorSize($(vectorSize)) + .setWindowSize($(windowSize)) + .setMaxSentenceLength($(maxSentenceLength)) + .setRegularization($(regularization)) + .setRepetition($(repetition)) + .fit(input) + copyValues(new Word2VecModel(uid, wordVectors).setParent(this)) + } + + @Since("1.4.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + @Since("1.4.1") + override def copy(extra: ParamMap): Word2Vec = defaultCopy(extra) +} + +@Since("1.6.0") +object Word2Vec extends DefaultParamsReadable[Word2Vec] { + + @Since("1.6.0") + override def load(path: String): Word2Vec = super.load(path) +} + +/** + * Model fitted by [[Word2Vec]]. + */ +@Since("1.4.0") +class Word2VecModel private[ml] ( + @Since("1.4.0") override val uid: String, + @transient private val wordVectors: feature.Word2VecModel) + extends Model[Word2VecModel] with Word2VecBase with MLWritable { + + import Word2VecModel._ + + /** + * Returns a dataframe with two fields, "word" and "vector", with "word" being a String + * and "vector" being the DenseVector that it is mapped to. 
+ */ + @Since("1.5.0") + @transient lazy val getVectors: DataFrame = { + val spark = SparkSession.builder().getOrCreate() + val wordVec = wordVectors.getVectors.mapValues(vec => Vectors.dense(vec.map(_.toDouble))) + spark.createDataFrame(wordVec.toSeq).toDF("word", "vector") + } + + /** + * Find "num" number of words closest in similarity to the given word, not + * including the word itself. + * @return a dataframe with columns "word" and "similarity" of the word and the cosine + * similarities between the synonyms and the given word. + */ + @Since("1.5.0") + def findSynonyms(word: String, num: Int): DataFrame = { + val spark = SparkSession.builder().getOrCreate() + spark.createDataFrame(findSynonymsArray(word, num)).toDF("word", "similarity") + } + + /** + * Find "num" number of words whose vector representation is most similar to the supplied vector. + * If the supplied vector is the vector representation of a word in the model's vocabulary, + * that word will be in the results. + * @return a dataframe with columns "word" and "similarity" of the word and the cosine + * similarities between the synonyms and the given word vector. + */ + @Since("2.0.0") + def findSynonyms(vec: Vector, num: Int): DataFrame = { + val spark = SparkSession.builder().getOrCreate() + spark.createDataFrame(findSynonymsArray(vec, num)).toDF("word", "similarity") + } + + /** + * Find "num" number of words whose vector representation is most similar to the supplied vector. + * If the supplied vector is the vector representation of a word in the model's vocabulary, + * that word will be in the results. + * @return an array of the words and the cosine similarities between the synonyms given + * word vector. + */ + @Since("2.2.0") + def findSynonymsArray(vec: Vector, num: Int): Array[(String, Double)] = { + wordVectors.findSynonyms(vec, num) + } + + /** + * Find "num" number of words closest in similarity to the given word, not + * including the word itself. 
+ * @return an array of the words and the cosine similarities between the synonyms given + * word vector. + */ + @Since("2.2.0") + def findSynonymsArray(word: String, num: Int): Array[(String, Double)] = { + wordVectors.findSynonyms(word, num) + } + + /** @group setParam */ + @Since("1.4.0") + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + @Since("1.4.0") + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** + * Transform a sentence column to a vector column to represent the whole sentence. The transform + * is performed by averaging all word vectors it contains. + */ + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) + val vectors = wordVectors.getVectors + .mapValues(vv => Vectors.dense(vv.map(_.toDouble))) + .map(identity) // mapValues doesn't return a serializable map (SI-7005) + val bVectors = dataset.sparkSession.sparkContext.broadcast(vectors) + val d = $(vectorSize) + val word2Vec = udf { sentence: Seq[String] => + if (sentence.isEmpty) { + Vectors.sparse(d, Array.empty[Int], Array.empty[Double]) + } else { + val sum = Vectors.zeros(d) + sentence.foreach { word => + bVectors.value.get(word).foreach { v => + BLAS.axpy(1.0, v, sum) + } + } + BLAS.scal(1.0 / sentence.size, sum) + sum + } + } + dataset.withColumn($(outputCol), word2Vec(col($(inputCol)))) + } + + @Since("1.4.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + @Since("1.4.1") + override def copy(extra: ParamMap): Word2VecModel = { + val copied = new Word2VecModel(uid, wordVectors) + copyValues(copied, extra).setParent(parent) + } + + @Since("1.6.0") + override def write: MLWriter = new Word2VecModelWriter(this) +} + +@Since("1.6.0") +object Word2VecModel extends MLReadable[Word2VecModel] { + + private case class Data(word: String, vector: Array[Float]) + + private[Word2VecModel] + class 
Word2VecModelWriter(instance: Word2VecModel) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + + val wordVectors = instance.wordVectors.getVectors + val dataPath = new Path(path, "data").toString + val bufferSizeInBytes = Utils.byteStringAsBytes( + sc.conf.get("spark.kryoserializer.buffer.max", "64m")) + val numPartitions = Word2VecModelWriter.calculateNumberOfPartitions( + bufferSizeInBytes, instance.wordVectors.wordIndex.size, instance.getVectorSize) + val spark = sparkSession + import spark.implicits._ + spark.createDataset[(String, Array[Float])](wordVectors.toSeq) + .repartition(numPartitions) + .map { case (word, vector) => Data(word, vector) } + .toDF() + .write + .parquet(dataPath) + } + } + + private[feature] + object Word2VecModelWriter { + /** + * Calculate the number of partitions to use in saving the model. + * [SPARK-11994] - We want to partition the model in partitions smaller than + * spark.kryoserializer.buffer.max + * @param bufferSizeInBytes Set to spark.kryoserializer.buffer.max + * @param numWords Vocab size + * @param vectorSize Vector length for each word + */ + def calculateNumberOfPartitions( + bufferSizeInBytes: Long, + numWords: Int, + vectorSize: Int): Int = { + val floatSize = 4L // Use Long to help avoid overflow + val averageWordSize = 15 + // Calculate the approximate size of the model. + // Assuming an average word size of 15 bytes, the formula is: + // (floatSize * vectorSize + 15) * numWords + val approximateSizeInBytes = (floatSize * vectorSize + averageWordSize) * numWords + val numPartitions = (approximateSizeInBytes / bufferSizeInBytes) + 1 + require(numPartitions < 10e8, s"Word2VecModel calculated that it needs $numPartitions " + + s"partitions to save this model, which is too large. 
Try increasing " + + s"spark.kryoserializer.buffer.max so that Word2VecModel can use fewer partitions.") + numPartitions.toInt + } + } + + private class Word2VecModelReader extends MLReader[Word2VecModel] { + + private val className = classOf[Word2VecModel].getName + + override def load(path: String): Word2VecModel = { + val spark = sparkSession + import spark.implicits._ + + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion) + + val dataPath = new Path(path, "data").toString + + val oldModel = if (major < 2 || (major == 2 && minor < 2)) { + val data = spark.read.parquet(dataPath) + .select("wordIndex", "wordVectors") + .head() + val wordIndex = data.getAs[Map[String, Int]](0) + val wordVectors = data.getAs[Seq[Float]](1).toArray + new feature.Word2VecModel(wordIndex, wordVectors) + } else { + val wordVectorsMap = spark.read.parquet(dataPath).as[Data] + .collect() + .map(wordVector => (wordVector.word, wordVector.vector)) + .toMap + new feature.Word2VecModel(wordVectorsMap) + } + + val model = new Word2VecModel(metadata.uid, oldModel) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[Word2VecModel] = new Word2VecModelReader + + @Since("1.6.0") + override def load(path: String): Word2VecModel = super.load(path) +} diff --git a/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeBucket.scala b/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeBucket.scala new file mode 100644 index 0000000000000000000000000000000000000000..042e4b866aa2bea1b49fa0347700b6b6a6a4c696 --- /dev/null +++ b/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeBucket.scala @@ -0,0 +1,1225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import java.io.IOException + +import scala.collection.mutable +import scala.util.Random + +import org.apache.spark.internal.Logging +import org.apache.spark.ml.classification.DecisionTreeClassificationModel +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.regression.DecisionTreeRegressionModel +import org.apache.spark.ml.tree._ +import org.apache.spark.ml.tree.impl.RandomForest.NodeIndexInfo +import org.apache.spark.ml.util.Instrumentation +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} +import org.apache.spark.mllib.tree.impurity.ImpurityCalculator +import org.apache.spark.mllib.tree.model.ImpurityStats +import org.apache.spark.rdd.RDD +import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom} + + +/** + * ALGORITHM + * + * This is a sketch of the algorithm to help new developers. + * + * The algorithm partitions data by instances (rows). + * On each iteration, the algorithm splits a set of nodes. In order to choose the best split + * for a given node, sufficient statistics are collected from the distributed data. + * For each node, the statistics are collected to some worker node, and that worker selects + * the best split. 
+ * + * This setup requires discretization of continuous features. This binning is done in the + * findSplits() method during initialization, after which each continuous feature becomes + * an ordered discretized feature with at most maxBins possible values. + * + * The main loop in the algorithm operates on a queue of nodes (nodeStack). These nodes + * lie at the periphery of the tree being trained. If multiple trees are being trained at once, + * then this queue contains nodes from all of them. Each iteration works roughly as follows: + * On the master node: + * - Some number of nodes are pulled off of the queue (based on the amount of memory + * required for their sufficient statistics). + * - For random forests, if featureSubsetStrategy is not "all," then a subset of candidate + * features are chosen for each node. See method selectNodesToSplit(). + * On worker nodes, via method findBestSplits(): + * - The worker makes one pass over its subset of instances. + * - For each (tree, node, feature, split) tuple, the worker collects statistics about + * splitting. Note that the set of (tree, node) pairs is limited to the nodes selected + * from the queue for this iteration. The set of features considered can also be limited + * based on featureSubsetStrategy. + * - For each node, the statistics for that node are aggregated to a particular worker + * via reduceByKey(). The designated worker chooses the best (feature, split) pair, + * or chooses to stop splitting if the stopping criteria are met. + * On the master node: + * - The master collects all decisions about splitting nodes and updates the model. + * - The updated model is passed to the workers on the next iteration. + * This process continues until the node queue is empty. + * + * Most of the methods in this implementation support the statistics aggregation, which is + * the heaviest part of the computation. 
In general, this implementation is bound by either + * the cost of statistics computation on workers or by communicating the sufficient statistics. + */ +private[spark] object DecisionTreeBucket extends Logging { + + /** + * Train a random forest. + * + * @param input Training data: RDD of `LabeledPoint` + * @return an unweighted set of trees + */ + def run( + input: RDD[LabeledPoint], + strategy: OldStrategy, + seed: Long, + instr: Option[Instrumentation[_]], + parentUID: Option[String] = None): Array[DecisionTreeModel] = { + val exParams = RFUtils.parseExtraParams(input, strategy) + runX(input, strategy, seed, instr, exParams, parentUID) + } + + /** + * Train a random forest. + * + * @param input Training data: RDD of `LabeledPoint` + * @return an unweighted set of trees + */ + def runX( + input: RDD[LabeledPoint], + strategy: OldStrategy, + seed: Long, + instr: Option[Instrumentation[_]], + extraParams: RFExtraParams, + parentUID: Option[String] = None): Array[DecisionTreeModel] = { + + DecisionTreeBucketInfo.timerResult = "" + val timer = new TimeTracker() + + timer.start("total") + + timer.start("init") + + val binnedFeaturesType = BinnedFeaturesDataType.withName(extraParams.featuresDataType) + val retaggedInput = input.retag(classOf[LabeledPoint]) + val featureSubsetStrategy = "1" + var numTrees: Int = 0 + // featureSubsetStrategy: The number of features to consider for splits at each tree node. + // featureSubsetStrategy: default value is "auto" for random forest. + // impurity: default value is "gini" for random forest. 
+ val metadata = + DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy) + numTrees = metadata.numTrees + logWarning(s"decisionTreeMetadata details: ${metadata.numFeatures}," + + s" ${metadata.numExamples}, ${metadata.numClasses: Int}, ${metadata.maxBins: Int}," + + s" ${metadata.featureArity}, ${metadata.unorderedFeatures.mkString("[", ";", "]")}," + + s" ${metadata.impurity}, ${metadata.quantileStrategy}, ${metadata.maxDepth: Int}," + + s" ${metadata.minInstancesPerNode: Int}, ${metadata.minInfoGain: Double}," + + s" ${metadata.numTrees: Int}, ${metadata.numFeaturesPerNode: Int}, ${binnedFeaturesType}") + instr match { + case Some(instrumentation) => + instrumentation.logNumFeatures(metadata.numFeatures) + instrumentation.logNumClasses(metadata.numClasses) + case None => + logInfo(s"numFeatures: ${metadata.numFeatures}") + logInfo(s"numClasses: ${metadata.numClasses}") + } + + // Find the splits and the corresponding bins (interval between the splits) using a sample + // of the input data. + timer.start("findSplits") + val splits = findSplits(retaggedInput, metadata, seed) + val baseSplits = + splits.map(v => v.zipWithIndex.map{case (split, binIdx) => Split.toBase(split, binIdx)}) + timer.stop("findSplits") + logDebug("numBins: feature: number of bins") + logDebug(Range(0, metadata.numFeatures).map { featureIndex => + s"\t$featureIndex\t${metadata.numBins(featureIndex)}" + }.mkString("\n")) + + // Bin feature values (TreePoint representation). + // Cache input RDD for speedup during multiple passes. + val treeInput = TreePointX.convertToTreeRDD(retaggedInput, splits, metadata, binnedFeaturesType) + + val withReplacement = numTrees > 1 + + // Default value of subsamplingRate is 1 for random forest. 
+ val baggedInputOri = BaggedPoint.convertToBaggedRDD(treeInput, strategy.subsamplingRate, + numTrees, withReplacement, seed, metadata.oneFeaturePerTree) + + val baggedInput = RFUtils.transformBaggedRDD(baggedInputOri, extraParams) + + // depth of the decision tree + val maxDepth = strategy.maxDepth + require(maxDepth <= 30, + s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.") + + // Max memory usage for aggregates + // TODO: Calculate memory usage more precisely. + val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L + logDebug(s"max memory usage for aggregates = ${maxMemoryUsage} bytes.") + + /* + * The main idea here is to perform group-wise training of the decision tree nodes thus + * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup). + * Each data sample is handled by a particular node (or it reaches a leaf and is not used + * in lower levels). + */ + + // Create an RDD of node Id cache. + // At first, all the rows belong to the root nodes (node Id == 1). + // Default value of useNodeIdCache is false for random forest. + val nodeIdCache = if (strategy.useNodeIdCache) { + Some(NodeIdCache.init( + data = baggedInput, + numTrees = numTrees, + checkpointInterval = strategy.checkpointInterval, + initVal = 1)) + } else { + None + } + + /* + Stack of nodes to train: (treeIndex, node) + The reason this is a stack is that we train many trees at once, but we want to focus on + completing trees, rather than training all simultaneously. If we are splitting nodes from + 1 tree, then the new nodes to split will be put at the top of this stack, so we will continue + training the same tree in the next iteration. This focus allows us to send fewer trees to + workers on each iteration; see topNodesForGroup below. + */ + val nodeStack = new mutable.ArrayStack[(Int, LearningNodeX)] + + val rng = new Random() + rng.setSeed(seed) + + // Allocate and queue root nodes. 
+ val topNodes = Array.fill[LearningNodeX](numTrees)(LearningNodeX.emptyNode(nodeIndex = 1)) + Range(0, numTrees).foreach(treeIndex => nodeStack.push((treeIndex, topNodes(treeIndex)))) + + timer.stop("init") + + while (nodeStack.nonEmpty) { + // Collect some nodes to split, and choose features for each node (if subsampling). + // Each group of nodes may come from one or multiple trees, and at multiple levels. + // nodesForGroup: treeIndex --> learningNodes in tree + // treeToNodeToIndexInfo: treeIndex --> (global) learningNodes index in tree + // --> (node index in group, feature indices). + val (nodesForGroup, treeToNodeToIndexInfo) = + DecisionTreeBucket.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng) + // Sanity check (should never occur): + assert(nodesForGroup.nonEmpty, + s"DecisionTreeBucket selected empty nodesForGroup. Error for unknown reason.") + + // Only send trees to worker if they contain nodes being split this iteration. + // topNodesForGroup: treeIndex --> top node in tree + val topNodesForGroup: Map[Int, LearningNodeX] = + nodesForGroup.keys.map(treeIdx => treeIdx -> topNodes(treeIdx)).toMap + + // Choose node splits, and enqueue new nodes as needed. + timer.start("findBestSplits") + DecisionTreeBucket.findBestSplits(baggedInput, metadata, topNodesForGroup, nodesForGroup, + treeToNodeToIndexInfo, baseSplits, nodeStack, timer, nodeIdCache, Some(extraParams)) + timer.stop("findBestSplits") + } + + baggedInput.unpersist() + + timer.stop("total") + + logInfo("Internal timing for DecisionTree:") + logInfo(s"$timer") + DecisionTreeBucketInfo.timerResult = timer.toString() + + // Delete any remaining checkpoints used for node Id cache. + if (nodeIdCache.nonEmpty) { + try { + nodeIdCache.get.deleteAllCheckpoints() + } catch { + case e: IOException => + logWarning(s"delete all checkpoints failed. 
Error reason: ${e.getMessage}") + } + } + + val numFeatures = metadata.numFeatures + + parentUID match { + case Some(uid) => + if (strategy.algo == OldAlgo.Classification) { + topNodes.map { rootNode => + new DecisionTreeClassificationModel(uid, rootNode.toNode(splits), numFeatures, + strategy.getNumClasses) + } + } else { + topNodes.map { rootNode => + new DecisionTreeRegressionModel(uid, rootNode.toNode(splits), numFeatures) + } + } + case None => + if (strategy.algo == OldAlgo.Classification) { + topNodes.map { rootNode => + new DecisionTreeClassificationModel(rootNode.toNode(splits), numFeatures, + strategy.getNumClasses) + } + } else { + topNodes.map(rootNode => + new DecisionTreeRegressionModel(rootNode.toNode(splits), numFeatures)) + } + } + } + + /** + * Helper for binSeqOp, for data which can contain a mix of ordered and unordered features. + * + * For ordered features, a single bin is updated. + * For unordered features, bins correspond to subsets of categories; either the left or right bin + * for each subset is updated. + * + * @param agg Array storing aggregate calculation, with a set of sufficient statistics for + * each (feature, bin). + * @param treePoint Data point being aggregated. + * @param splits possible splits indexed (numFeatures)(numSplits) + * @param unorderedFeatures Set of indices of unordered features. + * @param instanceWeight Weight (importance) of instance in dataset. + */ + private def mixedBinSeqOp( + agg: DTStatsAggregator, + treePoint: TreePointX, + splits: Array[Array[SplitBase]], + unorderedFeatures: Set[Int], + instanceWeight: Int, + featuresForNode: Option[Array[Int]]): Unit = { + val numFeaturesPerNode = if (featuresForNode.nonEmpty) { + // Use subsampled features + featuresForNode.get.length + } else { + // Use all features + agg.metadata.numFeatures + } + // Iterate over features. 
+ var featureIndexIdx = 0 + while (featureIndexIdx < numFeaturesPerNode) { + val featureIndex = if (featuresForNode.nonEmpty) { + featuresForNode.get.apply(featureIndexIdx) + } else { + featureIndexIdx + } + if (unorderedFeatures.contains(featureIndex)) { + // Unordered feature + val featureValue = treePoint.binnedFeatures.get(featureIndex) + val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx) + // Update the left or right bin for each split. + val numSplits = agg.metadata.numSplits(featureIndex) + val featureSplits = splits(featureIndex) + var splitIndex = 0 + while (splitIndex < numSplits) { + if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) { + agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight) + } + splitIndex += 1 + } + } else { + // Ordered feature + val binIndex = treePoint.binnedFeatures.get(featureIndex) + agg.update(featureIndexIdx, binIndex, treePoint.label, instanceWeight) + } + featureIndexIdx += 1 + } + } + + /** + * Helper for binSeqOp, for regression and for classification with only ordered features. + * + * For each feature, the sufficient statistics of one bin are updated. + * + * @param agg Array storing aggregate calculation, with a set of sufficient statistics for + * each (feature, bin). + * @param treePoint Data point being aggregated. + * @param instanceWeight Weight (importance) of instance in dataset. + */ + private def orderedBinSeqOp( + agg: DTStatsAggregator, + treePoint: TreePointX, + instanceWeight: Int, + featuresForNode: Option[Array[Int]]): Unit = { + val label = treePoint.label + + // Iterate over features. 
+ if (featuresForNode.nonEmpty) { + // Use subsampled features + var featureIndexIdx = 0 + while (featureIndexIdx < featuresForNode.get.length) { + val binIndex = treePoint.binnedFeatures.get(featuresForNode.get.apply(featureIndexIdx)) + agg.update(featureIndexIdx, binIndex, label, instanceWeight) + featureIndexIdx += 1 + } + } else { + // Use all features + val numFeatures = agg.metadata.numFeatures + var featureIndex = 0 + while (featureIndex < numFeatures) { + val binIndex = treePoint.binnedFeatures.get(featureIndex) + agg.update(featureIndex, binIndex, label, instanceWeight) + featureIndex += 1 + } + } + } + + /** + * Given a group of nodes, this finds the best split for each node. + * + * @param input Training data: RDD of [[TreePointX]] + * @param metadata Learning and dataset metadata + * @param topNodesForGroup For each tree in group, tree index -> root node. + * Used for matching instances with nodes. + * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree + * @param treeToNodeToIndexInfo Mapping: treeIndex --> (global) learningNodes index in tree + * --> (node index in group, feature indices) + * feature indices: probably parts of full features. + * Mapping: treeIndex --> nodeIndex --> nodeIndexInfo, + * where nodeIndexInfo stores the index in the group and the + * feature subsets (if using feature subsets). + * @param splits possible splits for all features, indexed (numFeatures)(numSplits) + * @param nodeStack Queue of nodes to split, with values (treeIndex, node). + * Updated with new non-leaf nodes which are created. + * @param nodeIdCache Node Id cache containing an RDD of Array[Int] where + * each value in the array is the data point's node Id + * for a corresponding tree. This is used to prevent the need + * to pass the entire tree to the executors during + * the node stat aggregation phase. 
+ */ + private[tree] def findBestSplits( + input: RDD[BaggedPoint[TreePointX]], + metadata: DecisionTreeMetadata, + topNodesForGroup: Map[Int, LearningNodeX], + nodesForGroup: Map[Int, Array[LearningNodeX]], + treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]], + splits: Array[Array[SplitBase]], + nodeStack: mutable.ArrayStack[(Int, LearningNodeX)], + timer: TimeTracker = new TimeTracker, + nodeIdCache: Option[NodeIdCache] = None, + extraParams: Option[RFExtraParams] = None): Unit = { + + /* + * The high-level descriptions of the best split optimizations are noted here. + * + * *Group-wise training* + * We perform bin calculations for groups of nodes to reduce the number of + * passes over the data. Each iteration requires more computation and storage, + * but saves several iterations over the data. + * + * *Bin-wise computation* + * We use a bin-wise best split computation strategy instead of a straightforward best split + * computation strategy. Instead of analyzing each sample for contribution to the left/right + * child node impurity of every split, we first categorize each feature of a sample into a + * bin. We exploit this structure to calculate aggregates for bins and then use these aggregates + * to calculate information gain for each split. + * + * *Aggregation over partitions* + * Instead of performing a flatMap/reduceByKey operation, we exploit the fact that we know + * the number of splits in advance. Thus, we store the aggregates (at the appropriate + * indices) in a single array for all bins and rely upon the RDD aggregate method to + * drastically reduce the communication overhead. 
+ */ + + val bcVariables = if (extraParams.isEmpty) false else extraParams.get.bcVariables + /** numNodes: Number of nodes in this group */ + val numNodes = nodesForGroup.values.map(_.length).sum + logDebug(s"numNodes = ${numNodes}") + logDebug(s"numFeatures = ${metadata.numFeatures}") + logDebug(s"numClasses = ${metadata.numClasses}") + logDebug(s"isMulticlass = ${metadata.isMulticlass}") + logDebug(s"isMulticlassWithCategoricalFeatures = " + + s"${metadata.isMulticlassWithCategoricalFeatures}") + logDebug(s"using nodeIdCache = ${ nodeIdCache.nonEmpty.toString}") + + val groupInfo = RFUtils.getGroupInfo(numNodes, treeToNodeToIndexInfo, extraParams) + + val splitsBc = if (bcVariables) Some(input.sparkContext.broadcast(splits)) else Option.empty + val splitsOption = if (bcVariables) Option.empty else Some(splits) + + + /** + * Performs a sequential aggregation over a partition for a particular tree and node. + * + * For each feature, the aggregate sufficient statistics are updated for the relevant + * bins. + * + * @param treeIndex Index of the tree that we want to perform aggregation for. + * @param nodeInfo The node info for the tree node. + * @param agg Array storing aggregate calculation, with a set of sufficient statistics + * for each (node, feature, bin). + * @param baggedPoint Data point being aggregated. 
+ */ + def nodeBinSeqOp( + treeIndex: Int, + nodeInfo: NodeIndexInfo, + agg: Array[DTStatsAggregator], + splitsBcv: Array[Array[SplitBase]], + baggedPoint: BaggedPoint[TreePointX]): Unit = { + if (RFUtils.isValidNodeInfo(nodeInfo, agg)) { + val aggNodeIndex = nodeInfo.nodeIndexInGroup + val featuresForNode = nodeInfo.featureSubset + val instanceWeight = 1 + if (metadata.unorderedFeatures.isEmpty) { + orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode) + } else { + mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splitsBcv, + metadata.unorderedFeatures, instanceWeight, featuresForNode) + } + agg(aggNodeIndex).updateParent(baggedPoint.datum.label, instanceWeight) + } + } + + /** + * Performs a sequential aggregation over a partition. + * + * Each data point contributes to one node. For each feature, + * the aggregate sufficient statistics are updated for the relevant bins. + * + * @param agg Array storing aggregate calculation, with a set of sufficient statistics for + * each (node, feature, bin). + * @param baggedPoint Data point being aggregated. + * @return agg + */ + def binSeqOp( + agg: Array[DTStatsAggregator], + baggedPoint: BaggedPoint[TreePointX], + splitsBcv: Array[Array[SplitBase]], + sampleId: Short): Array[DTStatsAggregator] = { + // TODO: treeToNodeToIndexInfo and topNodesForGroup(include sub-nodes) weren't broadcast. + treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => + if (RFUtils.isValidSample(baggedPoint, groupInfo, treeIndex, sampleId)) { + val nodeIndex = + topNodesForGroup(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splitsBcv) + nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), + agg, splitsBcv, baggedPoint) + } + } + agg + } + + /** + * Do the same thing as binSeqOp, but with nodeIdCache. 
+ */ + def binSeqOpWithNodeIdCache( + agg: Array[DTStatsAggregator], + splitsBcv: Array[Array[SplitBase]], + dataPoint: (BaggedPoint[TreePointX], Array[Int])): Array[DTStatsAggregator] = { + treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => + val baggedPoint = dataPoint._1 + val nodeIdCache = dataPoint._2 + val nodeIndex = nodeIdCache(treeIndex) + nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), + agg, splitsBcv, baggedPoint) + } + + agg + } + + /** + * Get node index in group --> features indices map, + * which is a short cut to find feature indices for a node given node index in group. + */ + def getNodeToFeatures( + treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]): Option[Map[Int, Array[Int]]] = { + if (!metadata.subsamplingFeatures) { + None + } else { + val mutableNodeToFeatures = new mutable.HashMap[Int, Array[Int]]() + treeToNodeToIndexInfo.values.foreach { nodeIdToNodeInfo => + nodeIdToNodeInfo.values.foreach { nodeIndexInfo => + assert(nodeIndexInfo.featureSubset.isDefined) + mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) = nodeIndexInfo.featureSubset.get + } + } + Some(mutableNodeToFeatures.toMap) + } + } + + // array of nodes to train indexed by node index in group + val nodes = new Array[LearningNodeX](numNodes) + nodesForGroup.foreach { case (treeIndex, nodesForTree) => + nodesForTree.foreach { node => + nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node + } + } + + // Calculate best splits for all nodes in the group + timer.start("chooseSplits") + + // In each partition, iterate all instances and compute aggregate stats for each node, + // yield a (nodeIndex, nodeAggregateStats) pair for each node. + // After a `reduceByKey` operation, + // stats of a node will be shuffled to a particular partition and be combined together, + // then best splits for nodes are found there. + // Finally, only best Splits for nodes are collected to driver to construct decision tree. 
+ // nodeToFeatures: node index in group -> selected feature indexes + val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo) + val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures) + + /** partitionAggregates RDD: node index in group --> nodeStats */ + val partitionAggregates: RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) { + input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points => + // Construct a nodeStatsAggregators array to hold node aggregate stats, + // each node will have a nodeStatsAggregator + val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex => + val featuresForNode = nodeToFeaturesBc.value.map { nodeToFeatures => + nodeToFeatures(nodeIndex) + } + new DTStatsAggregator(metadata, featuresForNode) + } + + val splitsBcv = if (bcVariables) splitsBc.get.value else splitsOption.get + // iterator all instances in current partition and update aggregate stats + points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, splitsBcv, _)) + + // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs, + // which can be combined with other partition using `reduceByKey` + nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator + } + } else { + input.mapPartitions { points => + val (firstPointOption, nodeStatsAggregators) = + RFUtils.initNodeStatsAgg(numNodes, nodeToFeaturesBc, metadata, points, groupInfo) + if (firstPointOption.isEmpty) { + Iterator.empty + } else { + val firstPoint = firstPointOption.get + val sampleId = firstPoint.sampleId + + val splitsBcv = if (bcVariables) splitsBc.get.value else splitsOption.get + binSeqOp(nodeStatsAggregators, firstPoint, splitsBcv, sampleId) + + + // iterator all instances in current partition and update aggregate stats + points.foreach(binSeqOp(nodeStatsAggregators, _, splitsBcv, sampleId)) + + // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs, + // which can be combined with other partition using 
`reduceByKey` + nodeStatsAggregators.view.zipWithIndex + .filter(v => RFUtils.isValidAgg(v._1)).map(_.swap).iterator + } + } + } + + val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b)).map { + case (nodeIndex, aggStats) => + val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures => + Some(nodeToFeatures(nodeIndex)) + } + + val splitsBcv = if (bcVariables) splitsBc.get.value else splitsOption.get + // find best split for each node + val (split: SplitBase, stats: ImpurityStats) = + binsToBestSplit(aggStats, splitsBcv, featuresForNode, nodes(nodeIndex)) + (nodeIndex, (split, stats)) + }.collectAsMap() + + timer.stop("chooseSplits") + + val nodeIdUpdaters = if (nodeIdCache.nonEmpty) { + Array.fill[mutable.Map[Int, NodeIndexUpdater]]( + metadata.numTrees)(mutable.Map[Int, NodeIndexUpdater]()) + } else { + null + } + // Iterate over all nodes in this group. + nodesForGroup.foreach { case (treeIndex, nodesForTree) => + nodesForTree.foreach { node => + val nodeIndex = node.id + val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex) + val aggNodeIndex = nodeInfo.nodeIndexInGroup + val (split: SplitBase, stats: ImpurityStats) = + nodeToBestSplits(aggNodeIndex) + logDebug(s"best split = ${split}") + + // Extract info for this node. Create children if not leaf. 
+ val isLeaf = + (stats.gain <= 0) || (LearningNodeX.indexToLevel(nodeIndex) == metadata.maxDepth) + node.isLeaf = isLeaf + node.stats = stats + logDebug(s"Node = ${node}") + + if (!isLeaf) { + node.split = Some(split) + val childIsLeaf = (LearningNodeX.indexToLevel(nodeIndex) + 1) == metadata.maxDepth + val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0) + val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0) + node.leftChild = Some(LearningNodeX(LearningNodeX.leftChildIndex(nodeIndex), + leftChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator))) + node.rightChild = Some(LearningNodeX(LearningNodeX.rightChildIndex(nodeIndex), + rightChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.rightImpurityCalculator))) + + if (nodeIdCache.nonEmpty) { + val nodeIndexUpdater = NodeIndexUpdater( + split = split, + nodeIndex = nodeIndex) + nodeIdUpdaters(treeIndex).put(nodeIndex, nodeIndexUpdater) + } + + // enqueue left child and right child if they are not leaves + if (!leftChildIsLeaf) { + nodeStack.push((treeIndex, node.leftChild.get)) + } + if (!rightChildIsLeaf) { + nodeStack.push((treeIndex, node.rightChild.get)) + } + + logDebug(s"leftChildIndex = ${node.leftChild.get.id}" + + s", impurity = ${stats.leftImpurity}") + logDebug(s"rightChildIndex = ${node.rightChild.get.id}" + + s", impurity = ${stats.rightImpurity}") + } + } + } + + if (nodeIdCache.nonEmpty) { + // Update the cache if needed. + nodeIdCache.get.updateNodeIndices(input, nodeIdUpdaters, splits) + } + } + + /** + * Calculate the impurity statistics for a given (feature, split) based upon left/right + * aggregates. 
 *
 * When `stats` is null the parent calculator is rebuilt as left + right and its impurity
 * is computed; otherwise both values are taken from `stats` unchanged.
 *
 * @param stats the recycled impurity statistics shared by all splits of this feature;
 *              only 'impurity' and 'impurityCalculator' are valid between each iteration.
 *              May be null, in which case the parent values are computed here.
 * @param leftImpurityCalculator left node aggregates for this (feature, split)
 * @param rightImpurityCalculator right node aggregate for this (feature, split)
 * @param metadata learning and dataset metadata for DecisionTree
 * @return Impurity statistics for this (feature, split); the "invalid" statistics built by
 *         `ImpurityStats.getInvalidImpurityStats` when either child has fewer than
 *         `minInstancesPerNode` instances or the gain is below `minInfoGain`
 */
private def calculateImpurityStats(
    stats: ImpurityStats,
    leftImpurityCalculator: ImpurityCalculator,
    rightImpurityCalculator: ImpurityCalculator,
    metadata: DecisionTreeMetadata): ImpurityStats = {

  // Parent aggregates: recomputed only when there is nothing to recycle.
  val parentImpurityCalculator: ImpurityCalculator = if (stats == null) {
    leftImpurityCalculator.copy.add(rightImpurityCalculator)
  } else {
    stats.impurityCalculator
  }

  val impurity: Double = if (stats == null) {
    parentImpurityCalculator.calculate()
  } else {
    stats.impurity
  }

  val leftCount = leftImpurityCalculator.count
  val rightCount = rightImpurityCalculator.count

  val totalCount = leftCount + rightCount

  // If left child or right child doesn't satisfy minimum instances per node,
  // then this split is invalid, return invalid information gain stats.
  if ((leftCount < metadata.minInstancesPerNode) ||
    (rightCount < metadata.minInstancesPerNode)) {
    return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
  }

  val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
  val rightImpurity = rightImpurityCalculator.calculate()

  // Children are weighted by their share of the instances reaching the parent.
  val leftWeight = leftCount / totalCount.toDouble
  val rightWeight = rightCount / totalCount.toDouble

  val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity

  // if information gain doesn't satisfy minimum information gain,
  // then this split is invalid, return invalid information gain stats.
  if (gain < metadata.minInfoGain) {
    return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
  }

  new ImpurityStats(gain, impurity, parentImpurityCalculator,
    leftImpurityCalculator, rightImpurityCalculator)
}

/**
 * Find the best split for a node.
 *
 * Every candidate feature with at least one split is evaluated; the (feature, split) with
 * the largest gain wins. Evaluating a continuous or ordered categorical feature accumulates
 * bin statistics inside `binAggregates` in place (via `mergeForFeature`), so the aggregator
 * is modified by this call.
 *
 * @param binAggregates Bin statistics.
 * @param splits Candidate splits per feature; read for continuous and unordered features
 *               (ordered categorical splits are constructed here).
 * @param featuresForNode Feature subset of this node: position --> feature index.
 *                        If None, position and feature index coincide.
 * @param node Node being split; for nodes below the root its `stats` seed the recycled
 *             impurity statistics.
 * @return tuple for best split: (best split, impurity statistics of that split)
 */
private[tree] def binsToBestSplit(
    binAggregates: DTStatsAggregator,
    splits: Array[Array[SplitBase]],
    featuresForNode: Option[Array[Int]],
    node: LearningNodeX): (SplitBase, ImpurityStats) = {

  // Calculate InformationGain and ImpurityStats if current node is top node
  // (level 0 starts from null so calculateImpurityStats computes the parent values;
  // deeper nodes reuse the stats already stored on the node).
  val level = LearningNodeX.indexToLevel(node.id)
  var gainAndImpurityStats: ImpurityStats = if (level == 0) {
    null
  } else {
    node.stats
  }

  // Pairs of (position within the node's feature subset, actual feature index),
  // restricted to features that have at least one split.
  // NOTE(review): this is a lazy view, so the side-effecting mapping below only runs when
  // the result is traversed -- confirm it is fully traversed exactly once.
  val validFeatureSplits =
    Range(0, binAggregates.metadata.numFeaturesPerNode).view.map { featureIndexIdx =>
      featuresForNode.map(features => (featureIndexIdx, features(featureIndexIdx)))
        .getOrElse((featureIndexIdx, featureIndexIdx))
    }.withFilter { case (_, featureIndex) =>
      binAggregates.metadata.numSplits(featureIndex) != 0
    }

  // For each (feature, split), calculate the gain, and select the best (feature, split).
  val splitsAndImpurityInfo =
    validFeatureSplits.map { case (featureIndexIdx, featureIndex) =>
      val numSplits = binAggregates.metadata.numSplits(featureIndex)
      if (binAggregates.metadata.isContinuous(featureIndex)) {
        // Cumulative sum (scanLeft) of bin statistics.
        // Afterwards, binAggregates for a bin is the sum of aggregates for
        // that bin + all preceding bins.
        val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
        var splitIndex = 0
        while (splitIndex < numSplits) {
          binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
          splitIndex += 1
        }
        // Find best split.
        // Left child = cumulative stats up to the split bin; right child = total
        // (bin `numSplits`) minus left.
        val (bestFeatureSplitIndex, bestFeatureGainStats) =
          Range(0, numSplits).map { case splitIdx =>
            val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
            val rightChildStats =
              binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
            rightChildStats.subtract(leftChildStats)
            gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
              leftChildStats, rightChildStats, binAggregates.metadata)
            (splitIdx, gainAndImpurityStats)
          }.maxBy(_._2.gain)
        (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
      } else if (binAggregates.metadata.isUnordered(featureIndex)) {
        // Unordered categorical feature
        // Each bin holds the left-child stats of one split; right child = parent - left.
        val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx)
        val (bestFeatureSplitIndex, bestFeatureGainStats) =
          Range(0, numSplits).map { splitIndex =>
            val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
            val rightChildStats = binAggregates.getParentImpurityCalculator()
              .subtract(leftChildStats)
            gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
              leftChildStats, rightChildStats, binAggregates.metadata)
            (splitIndex, gainAndImpurityStats)
          }.maxBy(_._2.gain)
        (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
      } else {
        // Ordered categorical feature
        val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
        val numCategories = binAggregates.metadata.numBins(featureIndex)

        /* Each bin is one category (feature value).
         * The bins are ordered based on centroidForCategories, and this ordering determines which
         * splits are considered.  (With K categories, we consider K - 1 possible splits.)
         *
         * centroidForCategories is a list: (category, centroid)
         */
        val centroidForCategories = Range(0, numCategories).map { case featureValue =>
          val categoryStats =
            binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
          // Categories with no instances get Double.MaxValue so they sort last.
          val centroid = if (categoryStats.count != 0) {
            if (binAggregates.metadata.isMulticlass) {
              // multiclass classification
              // For categorical variables in multiclass classification,
              // the bins are ordered by the impurity of their corresponding labels.
              categoryStats.calculate()
            } else if (binAggregates.metadata.isClassification) {
              // binary classification
              // For categorical variables in binary classification,
              // the bins are ordered by the count of class 1.
              categoryStats.stats(1)
            } else {
              // regression
              // For categorical variables in regression,
              // the bins are ordered by the prediction.
              categoryStats.predict
            }
          } else {
            Double.MaxValue
          }
          (featureValue, centroid)
        }

        logDebug(s"Centroids for categorical variable: ${centroidForCategories.mkString(",")}")

        // bins sorted by centroids
        val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)

        logDebug("Sorted centroids for categorical variable = " +
          categoriesSortedByCentroid.mkString(","))

        // Cumulative sum (scanLeft) of bin statistics.
        // Afterwards, binAggregates for a bin is the sum of aggregates for
        // that bin + all preceding bins.
        // (Accumulation follows the centroid order, not the raw category order.)
        var splitIndex = 0
        while (splitIndex < numSplits) {
          val currentCategory = categoriesSortedByCentroid(splitIndex)._1
          val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
          binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
          splitIndex += 1
        }
        // lastCategory = index of bin with total aggregates for this (node, feature)
        val lastCategory = categoriesSortedByCentroid.last._1
        // Find best split.
        val (bestFeatureSplitIndex, bestFeatureGainStats) =
          Range(0, numSplits).map { splitIndex =>
            val featureValue = categoriesSortedByCentroid(splitIndex)._1
            val leftChildStats =
              binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
            val rightChildStats =
              binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
            rightChildStats.subtract(leftChildStats)
            gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
              leftChildStats, rightChildStats, binAggregates.metadata)
            (splitIndex, gainAndImpurityStats)
          }.maxBy(_._2.gain)
        // The split sends the first (bestFeatureSplitIndex + 1) categories, in centroid
        // order, to the left.
        val categoriesForSplit =
          categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
        val bestFeatureSplit =
          new CategoricalSplit(featureIndex, categoriesForSplit.toArray, numCategories)
        (bestFeatureSplit, bestFeatureGainStats)
      }
    }

  val (bestSplit, bestSplitStats) =
    if (splitsAndImpurityInfo.isEmpty) {
      // If no valid splits for features, then this split is invalid,
      // return invalid information gain stats.  Take any split and continue.
      // Splits is empty, so arbitrarily choose to split on any threshold
      val dummyFeatureIndex = featuresForNode.map(_.head).getOrElse(0)
      val parentImpurityCalculator = binAggregates.getParentImpurityCalculator()
      if (binAggregates.metadata.isContinuous(dummyFeatureIndex)) {
        (new ContinuousSplit(dummyFeatureIndex, 0),
          ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator))
      } else {
        val numCategories = binAggregates.metadata.featureArity(dummyFeatureIndex)
        (new CategoricalSplit(dummyFeatureIndex, Array(), numCategories),
          ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator))
      }
    } else {
      splitsAndImpurityInfo.maxBy(_._2.gain)
    }
  (bestSplit, bestSplitStats)
}

/**
 * Returns splits for decision tree calculation.
 * Continuous and categorical features are handled differently.
+ * + * Continuous features: + * For each feature, there are numBins - 1 possible splits representing the possible binary + * decisions at each node in the tree. + * This finds locations (feature values) for splits using a subsample of the data. + * + * Categorical features: + * For each feature, there is 1 bin per split. + * Splits and bins are handled in 2 ways: + * (a) "unordered features" + * For multiclass classification with a low-arity feature + * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits), + * the feature is split based on subsets of categories. + * (b) "ordered features" + * For regression and binary classification, + * and for multiclass classification with a high-arity feature, + * there is one bin per category. + * + * @param input Training data: RDD of [[LabeledPoint]] + * @param metadata Learning and dataset metadata + * @param seed random seed + * @return Splits, an Array of [[Split]] + * of size (numFeatures, numSplits) + */ + protected[tree] def findSplits( + input: RDD[LabeledPoint], + metadata: DecisionTreeMetadata, + seed: Long): Array[Array[Split]] = { + + logDebug(s"isMulticlass = ${metadata.isMulticlass}") + + val numFeatures = metadata.numFeatures + + // Sample the input only if there are continuous features. + val continuousFeatures = Range(0, numFeatures).filter(metadata.isContinuous) + val sampledInput = if (continuousFeatures.nonEmpty) { + // Calculate the number of samples for approximate quantile calculation. 
+ val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000) + val fraction = if (requiredSamples < metadata.numExamples) { + requiredSamples.toDouble / metadata.numExamples + } else { + 1.0 + } + logDebug(s"fraction of data used for calculating quantiles = ${fraction}") + input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()) + } else { + input.sparkContext.emptyRDD[LabeledPoint] + } + + findSplitsBySorting(sampledInput, metadata, continuousFeatures) + } + + private def findSplitsBySorting( + input: RDD[LabeledPoint], + metadata: DecisionTreeMetadata, + continuousFeatures: IndexedSeq[Int]): Array[Array[Split]] = { + + val continuousSplits: scala.collection.Map[Int, Array[Split]] = { + // reduce the parallelism for split computations when there are less + // continuous features than input partitions. this prevents tasks from + // being spun up that will definitely do no work. + val numPartitions = math.min(continuousFeatures.length, input.partitions.length) + + input + .flatMap(point => continuousFeatures.map(idx => (idx, point.features(idx)))) + .groupByKey(numPartitions) + .map { case (idx, samples) => + val thresholds = findSplitsForContinuousFeature(samples, metadata, idx) + val splits: Array[Split] = thresholds.map(thresh => new ContinuousSplit(idx, thresh)) + logDebug(s"featureIndex = $idx, numSplits = ${splits.length}") + (idx, splits) + }.collectAsMap() + } + + val numFeatures = metadata.numFeatures + val splits: Array[Array[Split]] = Array.tabulate(numFeatures) { + case i if metadata.isContinuous(i) => + val split = continuousSplits(i) + metadata.setNumSplits(i, split.length) + split + + case i if metadata.isCategorical(i) && metadata.isUnordered(i) => + // Unordered features + // 2^(maxFeatureValue - 1) - 1 combinations + val featureArity = metadata.featureArity(i) + Array.tabulate[Split](metadata.numSplits(i)) { splitIndex => + val categories = extractMultiClassCategories(splitIndex + 1, featureArity) + new 
CategoricalSplit(i, categories.toArray, featureArity) + } + + case i if metadata.isCategorical(i) => + // Ordered features + // Splits are constructed as needed during training. + Array.empty[Split] + } + splits + } + + /** + * Nested method to extract list of eligible categories given an index. It extracts the + * position of ones in a binary representation of the input. If binary + * representation of an number is 01101 (13), the output list should (3.0, 2.0, + * 0.0). The maxFeatureValue depict the number of rightmost digits that will be tested for ones. + */ + private[tree] def extractMultiClassCategories( + input: Int, + maxFeatureValue: Int): List[Double] = { + var categories = List[Double]() + var j = 0 + var bitShiftedInput = input + while (j < maxFeatureValue) { + if (bitShiftedInput % 2 != 0) { + // updating the list of categories. + categories = j.toDouble :: categories + } + // Right shift by one + bitShiftedInput = bitShiftedInput >> 1 + j += 1 + } + categories + } + + /** + * Find splits for a continuous feature + * NOTE: Returned number of splits is set based on `featureSamples` and + * could be different from the specified `numSplits`. + * The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly. 
+ * + * @param featureSamples feature values of each sample + * @param metadata decision tree metadata + * NOTE: `metadata.numbins` will be changed accordingly + * if there are not enough splits to be found + * @param featureIndex feature index to find splits + * @return array of split thresholds + */ + private[tree] def findSplitsForContinuousFeature( + featureSamples: Iterable[Double], + metadata: DecisionTreeMetadata, + featureIndex: Int): Array[Double] = { + require(metadata.isContinuous(featureIndex), + "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.") + + val splits: Array[Double] = if (featureSamples.isEmpty) { + Array.empty[Double] + } else { + val numSplits = metadata.numSplits(featureIndex) + + // get count for each distinct value + val (valueCountMap, numSamples) = featureSamples.foldLeft((Map.empty[Double, Int], 0)) { + case ((m, cnt), x) => + (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1) + } + // sort distinct values + val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray + + val possibleSplits = valueCounts.length - 1 + if (possibleSplits == 0) { + // constant feature + Array.empty[Double] + } else if (possibleSplits <= numSplits) { + // if possible splits is not enough or just enough, just return all possible splits + (1 to possibleSplits) + .map(index => (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0) + .toArray + } else { + // stride between splits + val stride: Double = numSamples.toDouble / (numSplits + 1) + logDebug(s"stride = ${stride}") + + // iterate `valueCount` to find splits + val splitsBuilder = mutable.ArrayBuilder.make[Double] + var index = 1 + // currentCount: sum of counts of values that have been visited + var currentCount = valueCounts(0)._2 + // targetCount: target value for `currentCount`. + // If `currentCount` is closest value to `targetCount`, + // then current value is a split threshold. + // After finding a split threshold, `targetCount` is added by stride. 
+ var targetCount = stride + while (index < valueCounts.length) { + val previousCount = currentCount + currentCount += valueCounts(index)._2 + val previousGap = math.abs(previousCount - targetCount) + val currentGap = math.abs(currentCount - targetCount) + // If adding count of current value to currentCount + // makes the gap between currentCount and targetCount smaller, + // previous value is a split threshold. + if (previousGap < currentGap) { + splitsBuilder += (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0 + targetCount += stride + } + index += 1 + } + + splitsBuilder.result() + } + } + splits + } + + /** + * Pull nodes off of the queue, and collect a group of nodes to be split on this iteration. + * This tracks the memory usage for aggregates and stops adding nodes when too much memory + * will be needed; this allows an adaptive number of nodes since different nodes may require + * different amounts of memory (if featureSubsetStrategy is not "all"). + * + * @param nodeStack Queue of nodes to split. + * @param maxMemoryUsage Bound on size of aggregate statistics. + * @return (nodesForGroup, treeToNodeToIndexInfo). + * nodesForGroup holds the nodes to split: treeIndex --> nodes in tree. + * + * treeToNodeToIndexInfo holds indices selected features for each node: + * treeIndex --> (global) node index --> (node index in group, feature indices). + * The (global) node index is the index in the tree; the node index in group is the + * index in [0, numNodesInGroup) of the node in this group. + * The feature indices are None if not subsampling features. 
+ */ + private[tree] def selectNodesToSplit( + nodeStack: mutable.ArrayStack[(Int, LearningNodeX)], + maxMemoryUsage: Long, + metadata: DecisionTreeMetadata, + rng: Random): (Map[Int, Array[LearningNodeX]], Map[Int, Map[Int, NodeIndexInfo]]) = { + // Collect some nodes to split: + // nodesForGroup(treeIndex) = nodes to split + val mutableNodesForGroup = new mutable.HashMap[Int, mutable.ArrayBuffer[LearningNodeX]]() + val mutableTreeToNodeToIndexInfo = + new mutable.HashMap[Int, mutable.HashMap[Int, NodeIndexInfo]]() + var memUsage: Long = 0L + var numNodesInGroup = 0 + // If maxMemoryInMB is set very small, we want to still try to split 1 node, + // so we allow one iteration if memUsage == 0. + var groupDone = false + while (nodeStack.nonEmpty && !groupDone) { + val (treeIndex, node) = nodeStack.top + // Choose subset of features for node (if subsampling). + val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { + if (metadata.oneFeaturePerTree) { + Some(Array(treeIndex)) + } else { + Some(SamplingUtils.reservoirSampleAndCount(Range(0, + metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong())._1) + } + } else { + None + } + // Check if enough memory remains to add this node to the group. + val nodeMemUsage = DecisionTreeBucket.aggregateSizeForNode(metadata, featureSubset) * 8L + if (memUsage + nodeMemUsage <= maxMemoryUsage || memUsage == 0) { + nodeStack.pop() + mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[LearningNodeX]()) += + node + mutableTreeToNodeToIndexInfo + .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id) + = new NodeIndexInfo(numNodesInGroup, featureSubset) + numNodesInGroup += 1 + memUsage += nodeMemUsage + } else { + groupDone = true + } + } + if (memUsage > maxMemoryUsage) { + // If maxMemoryUsage is 0, we should still allow splitting 1 node. 
+ logWarning(s"Tree learning is using approximately $memUsage bytes per iteration, which" + + s" exceeds requested limit maxMemoryUsage=$maxMemoryUsage. This allows splitting" + + s" $numNodesInGroup nodes in this iteration.") + } + logWarning(f"[group] actualMemUsage: ${memUsage/(1024d*1024d)}%.2f MB," + + f" maxMemoryUsage: ${maxMemoryUsage/(1024d*1024d)}%.2f MB.") + // Convert mutable maps to immutable ones. + val nodesForGroup: Map[Int, Array[LearningNodeX]] = + mutableNodesForGroup.mapValues(_.toArray).toMap + val treeToNodeToIndexInfo = mutableTreeToNodeToIndexInfo.mapValues(_.toMap).toMap + (nodesForGroup, treeToNodeToIndexInfo) + } + + /** + * Get the number of values to be stored for this node in the bin aggregates. + * + * @param featureSubset Indices of features which may be split at this node. + * If None, then use all features. + */ + private def aggregateSizeForNode( + metadata: DecisionTreeMetadata, + featureSubset: Option[Array[Int]]): Long = { + val totalBins = if (featureSubset.nonEmpty) { + featureSubset.get.map(featureIndex => metadata.numBins(featureIndex).toLong).sum + } else { + metadata.numBins.map(_.toLong).sum + } + if (metadata.isClassification) { + metadata.numClasses * totalBins + } else { + 3 * totalBins + } + } +} + +object DecisionTreeBucketInfo { + var timerResult: String = "" +} diff --git a/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index f54eab6fa83ee9157f537f2959749a0a160ad2eb..d5d5172a6e57d882193c06917e30541c7b4891b4 100644 --- a/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -1,5 +1,5 @@ /* -* Copyright (C) 2021. Huawei Technologies Co., Ltd. +* Copyright (C) 2021-2022. Huawei Technologies Co., Ltd. 
* This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. @@ -84,6 +84,30 @@ private[spark] trait DecisionTreeModel { /** Convert to spark.mllib DecisionTreeModel (losing some information) */ private[spark] def toOld: OldDecisionTreeModel + + /** + * @return an iterator that traverses (DFS, left to right) the leaves + * in the subtree of this node. + */ + private def leafIterator(node: Node): Iterator[LeafNode] = { + node match { + case l: LeafNode => Iterator.single(l) + case n: InternalNode => + leafIterator(n.leftChild) ++ leafIterator(n.rightChild) + } + } + + @transient private lazy val leafIndices: Map[LeafNode, Int] = { + leafIterator(rootNode).zipWithIndex.toMap + } + + /** + * @return The index of the leaf corresponding to the feature vector. + * Leaves are indexed in pre-order from 0. + */ + def predictLeaf(features: Vector): Double = { + leafIndices(rootNode.predictImpl(features)).toDouble + } } /** @@ -124,6 +148,15 @@ private[ml] trait TreeEnsembleModel[M <: DecisionTreeModel] { /** Total number of nodes, summed over all trees in the ensemble. */ lazy val totalNumNodes: Int = trees.map(_.numNodes).sum + + /** + * @return The indices of the leaves corresponding to the feature vector. + * Leaves are indexed in pre-order from 0. 
+ */ + def predictLeaf(features: Vector): Vector = { + val indices = trees.map(_.predictLeaf(features)) + Vectors.dense(indices) + } } private[ml] object TreeEnsembleModel { diff --git a/ml-accelerator/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/ml-accelerator/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala new file mode 100644 index 0000000000000000000000000000000000000000..f4d8b9fc4af15f9ee5c415abd56f677a3eba475a --- /dev/null +++ b/ml-accelerator/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -0,0 +1,411 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.feature + +import java.lang.{Iterable => JavaIterable} + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.SparkConf +import org.apache.spark.annotation.Since +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.internal.Logging +import org.apache.spark.rdd._ +import org.apache.spark.util.Utils + +/** + * Entry in vocabulary + */ +private case class VocabWord( + var word: String, + var cn: Int, + var point: Array[Int], + var code: Array[Int], + var codeLen: Int +) + +/** + * Word2Vec creates vector representation of words in a text corpus. + * The algorithm first constructs a vocabulary from the corpus + * and then learns vector representation of words in the vocabulary. + * The vector representation can be used as features in + * natural language processing and machine learning algorithms. + * + * We used skip-gram model in our implementation and hierarchical softmax + * method to train the model. The variable names in the implementation + * matches the original C implementation. + * + * For original C implementation, see https://code.google.com/p/word2vec/ + * For research papers, see + * Efficient Estimation of Word Representations in Vector Space + * and + * Distributed Representations of Words and Phrases and their Compositionality. + */ +@Since("1.1.0") +class Word2Vec extends Serializable with Logging { + + private var vectorSize = 100 + private var learningRate = 0.025 + private var regularization = 0.05f + private var repetition = 2 + private var numPartitions = 1 + private var numIterations = 1 + private var seed = Utils.random.nextLong() + private var minCount = 5 + private var maxSentenceLength = 1000 + + /** + * Sets the regularization coefficient. 
+ */ + def setRegularization(regularization: Float): this.type = { + require(regularization >= 0, + s"The value of regularization must not be smaller than 0 but got ${regularization}") + this.regularization = regularization + this + } + + /** + * Sets the number of repetitions of data. + */ + def setRepetition(repetition: Int): this.type = { + require(repetition >= 0, + s"The number of repetitions of data must not be smaller than 0 but got ${repetition}") + this.repetition = repetition + this + } + + /** + * Sets the maximum length (in words) of each sentence in the input data. + * Any sentence longer than this threshold will be divided into chunks of + * up to `maxSentenceLength` size (default: 1000) + */ + @Since("2.0.0") + def setMaxSentenceLength(maxSentenceLength: Int): this.type = { + require(maxSentenceLength > 0, + s"Maximum length of sentences must be positive but got ${maxSentenceLength}") + this.maxSentenceLength = maxSentenceLength + this + } + + /** + * Sets vector size (default: 100). + */ + @Since("1.1.0") + def setVectorSize(vectorSize: Int): this.type = { + require(vectorSize > 0, + s"vector size must be positive but got ${vectorSize}") + this.vectorSize = vectorSize + this + } + + /** + * Sets initial learning rate (default: 0.025). + */ + @Since("1.1.0") + def setLearningRate(learningRate: Double): this.type = { + require(learningRate > 0, + s"Initial learning rate must be positive but got ${learningRate}") + this.learningRate = learningRate + this + } + + /** + * Sets number of partitions (default: 1). Use a small number for accuracy. + */ + @Since("1.1.0") + def setNumPartitions(numPartitions: Int): this.type = { + require(numPartitions > 0, + s"Number of partitions must be positive but got ${numPartitions}") + this.numPartitions = numPartitions + this + } + + /** + * Sets number of iterations (default: 1), which should be smaller than or equal to number of + * partitions. 
*/
  @Since("1.1.0")
  def setNumIterations(numIterations: Int): this.type = {
    require(
      numIterations >= 0,
      s"Number of iterations must be nonnegative but got ${numIterations}")
    this.numIterations = numIterations
    this
  }

  /**
   * Sets random seed (default: a random long integer).
   *
   * @param seed seed value; any long is accepted
   * @return this instance, to allow call chaining
   */
  @Since("1.1.0")
  def setSeed(seed: Long): this.type = {
    this.seed = seed
    this
  }

  /**
   * Sets the window of words (default: 5)
   *
   * @param window window size; must be strictly positive
   * @return this instance, to allow call chaining
   */
  @Since("1.6.0")
  def setWindowSize(window: Int): this.type = {
    require(window > 0, s"Window of words must be positive but got ${window}")
    this.window = window
    this
  }

  /**
   * Sets minCount, the minimum number of times a token must appear to be included in the word2vec
   * model's vocabulary (default: 5).
   *
   * @param minCount minimum occurrence count; must be greater than or equal to 0
   * @return this instance, to allow call chaining
   */
  @Since("1.3.0")
  def setMinCount(minCount: Int): this.type = {
    require(
      minCount >= 0,
      s"Minimum number of times must be nonnegative but got ${minCount}")
    this.minCount = minCount
    this
  }

  /**
   * Set parameters of algorithm.
*/
  private def setParams(conf: SparkConf): Unit = {
    // A value present in SparkConf takes precedence. When the key is absent, fall back to the
    // value currently held by this instance (set through the setter, or the built-in default)
    // instead of a hard-coded constant, so that setter calls are not silently discarded.
    setRegularization(Word2Vec.getSparkConfFloatValue(
      conf, "spark.boostkit.mllib.feature.word2vec.regularization", regularization))
    setRepetition(Word2Vec.getSparkConfIntValue(
      conf, "spark.boostkit.mllib.feature.word2vec.repetition", repetition))
  }

  private val EXP_TABLE_SIZE = 1000
  private val MAX_EXP = 6
  private val MAX_CODE_LENGTH = 40

  /** context words from [-window, window] */
  private var window = 5

  private var trainWordsCount = 0L
  private var vocabSize = 0
  @transient private var vocab: Array[VocabWord] = null
  @transient private var vocabHash = mutable.HashMap.empty[String, Int]

  /**
   * Builds the vocabulary from the corpus: counts every word, drops the words that occur fewer
   * than `minCount` times and sorts the remainder by descending count.
   *
   * Populates `vocab`, `vocabSize`, `vocabHash` (word to index in `vocab`) and
   * `trainWordsCount` (sum of the counts of all retained words).
   *
   * @param dataset an RDD of sentences, each an iterable collection of words
   * @throws IllegalArgumentException if no word survives the `minCount` filter
   */
  private def learnVocab[S <: Iterable[String]](dataset: RDD[S]): Unit = {
    val words = dataset.flatMap(x => x)

    vocab = words.map(w => (w, 1))
      .reduceByKey(_ + _)
      .filter(_._2 >= minCount)
      .map(x => VocabWord(
        x._1,
        x._2,
        new Array[Int](MAX_CODE_LENGTH),
        new Array[Int](MAX_CODE_LENGTH),
        0))
      .collect()
      .sortWith((a, b) => a.cn > b.cn)

    vocabSize = vocab.length
    require(vocabSize > 0, "The vocabulary size should be > 0. You may need to check " +
      "the setting of minCount, which could be large enough to remove all your words in sentences.")

    // Start from a clean slate so that calling fit() more than once on the same instance does
    // not accumulate counts or keep word -> index entries from a previous vocabulary.
    trainWordsCount = 0L
    vocabHash = mutable.HashMap.empty[String, Int]

    var a = 0
    while (a < vocabSize) {
      vocabHash += vocab(a).word -> a
      trainWordsCount += vocab(a).cn
      a += 1
    }
    logInfo(s"vocabSize = $vocabSize, trainWordsCount = $trainWordsCount")
  }

  /**
   * Precomputes `EXP_TABLE_SIZE` samples of exp(x) / (exp(x) + 1) for x evenly spaced over
   * [-MAX_EXP, MAX_EXP).
   */
  private def createExpTable(): Array[Float] = {
    Array.tabulate(EXP_TABLE_SIZE) { i =>
      val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP)
      (tmp / (tmp + 1.0)).toFloat
    }
  }

  /**
   * Builds a binary (Huffman) tree over the vocabulary by repeatedly merging the two nodes with
   * the smallest counts, then stores for every word its code digits, its path of inner-node
   * indices and the code length in the corresponding `VocabWord`.
   *
   * Relies on `vocab` being sorted by descending count, as `learnVocab` leaves it.
   */
  private def createBinaryTree(): Unit = {
    val count = new Array[Long](vocabSize * 2 + 1)
    val binary = new Array[Int](vocabSize * 2 + 1)
    val parentNode = new Array[Int](vocabSize * 2 + 1)
    val code = new Array[Int](MAX_CODE_LENGTH)
    val point = new Array[Int](MAX_CODE_LENGTH)
    // Slots [0, vocabSize) are the leaves (word counts); the remaining slots are inner nodes,
    // initialised to a large sentinel so they are never picked before being created.
    var a = 0
    while (a < vocabSize) {
      count(a) = vocab(a).cn
      a += 1
    }
    while (a < 2 * vocabSize) {
      count(a) = 1e9.toInt
      a += 1
    }
    // pos1 walks the leaves from the least frequent word backwards;
    // pos2 walks the inner nodes in creation order.
    var pos1 = vocabSize - 1
    var pos2 = vocabSize

    var min1i = 0
    var min2i = 0

    a = 0
    while (a < vocabSize - 1) {
      // Pick the smallest remaining node ...
      if (pos1 >= 0) {
        if (count(pos1) < count(pos2)) {
          min1i = pos1
          pos1 -= 1
        } else {
          min1i = pos2
          pos2 += 1
        }
      } else {
        min1i = pos2
        pos2 += 1
      }
      // ... and the second smallest.
      if (pos1 >= 0) {
        if (count(pos1) < count(pos2)) {
          min2i = pos1
          pos1 -= 1
        } else {
          min2i = pos2
          pos2 += 1
        }
      } else {
        min2i = pos2
        pos2 += 1
      }
      // Merge both under a new inner node; the second child gets digit 1.
      count(vocabSize + a) = count(min1i) + count(min2i)
      parentNode(min1i) = vocabSize + a
      parentNode(min2i) = vocabSize + a
      binary(min2i) = 1
      a += 1
    }
    // Now assign binary code to each vocabulary word
    var i = 0
    a = 0
    while (a < vocabSize) {
      var b = a
      i = 0
      // Climb from the leaf up to the root (index vocabSize * 2 - 2), recording digits and nodes.
      while (b != vocabSize * 2 - 2) {
        code(i) = binary(b)
        point(i) = b
        i += 1
        b = parentNode(b)
      }
      vocab(a).codeLen = i
      vocab(a).point(0) = vocabSize - 2
      // Store the path reversed (root first), with node indices relative to the first inner node.
      b = 0
      while (b < i) {
        vocab(a).code(i - b - 1) = code(b)
        vocab(a).point(i - b) = point(b) - vocabSize
        b += 1
      }
      a += 1
    }
  }

  /**
   * Computes the vector representation of each word in vocabulary.
   * @param dataset an RDD of sentences,
   *                each sentence is expressed as an iterable collection of words
   * @return a Word2VecModel
   */
  @Since("1.1.0")
  def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = {

    learnVocab(dataset)

    createBinaryTree()

    val sc = dataset.context

    val expTable = sc.broadcast(createExpTable())
    val bcVocab = sc.broadcast(vocab)
    val bcVocabHash = sc.broadcast(vocabHash)
    try {
      setParams(dataset.sparkContext.conf)

      // NOTE(review): numPartitions is not forwarded to Word2VecSGHS here — confirm this is
      // intended.
      val modelSGHS = new Word2VecSGHS(minCount, window, vectorSize, vocabSize, trainWordsCount,
        learningRate, numIterations, seed, maxSentenceLength, regularization, repetition)
      val vectors = modelSGHS.fit(dataset, expTable, bcVocab, bcVocabHash)

      new Word2VecModel(vocabHash.toMap, vectors)
    } finally {
      // Release the broadcast variables whether or not training succeeded.
      expTable.destroy(blocking = false)
      bcVocab.destroy(blocking = false)
      bcVocabHash.destroy(blocking = false)
    }
  }

  /**
   * Computes the vector representation of each word in vocabulary (Java version).
   * @param dataset a JavaRDD of words
   * @return a Word2VecModel
   */
  @Since("1.1.0")
  def fit[S <: JavaIterable[String]](dataset: JavaRDD[S]): Word2VecModel = {
    fit(dataset.rdd.map(_.asScala))
  }
}

object Word2Vec {
  /**
   * Get the float value from spark conf.
   *
   * @param conf the configuration to read from
   * @param valueName the configuration key
   * @param defaultValue value returned when the key is not set
   * @return the parsed value, or `defaultValue` when the key is absent
   * @throws IllegalArgumentException if the configured value cannot be parsed as a Float
   */
  def getSparkConfFloatValue(
      conf: SparkConf,
      valueName: String,
      defaultValue: Float): Float = {
    conf.getOption(valueName) match {
      case Some(raw) =>
        try {
          raw.toFloat
        } catch {
          case e: Exception =>
            // Keep the original failure as the cause to ease debugging.
            throw new IllegalArgumentException(s"$valueName should be Float, " +
              s"but was ${raw}", e)
        }
      case None => defaultValue
    }
  }

  /**
   * Get the int value from spark conf.
+ */ + def getSparkConfIntValue( + conf: SparkConf, + valueName: String, + defaultValue: Int): Int = { + var value = defaultValue + val valueOption = conf.getOption(valueName) + + if(valueOption.nonEmpty) { + try { + value = valueOption.get.toInt + } catch { + case _: Exception => + throw new IllegalArgumentException(s"$valueName should be Int and greater than 0, " + + s"but was ${valueOption.get}") + } + } + + value + } +} diff --git a/ml-core/pom.xml b/ml-core/pom.xml index 6aad5f95bcd1293ee567cdf2a63c457ab929c4e1..21634bd58dd05881068483dfdd69351e18908f9a 100644 --- a/ml-core/pom.xml +++ b/ml-core/pom.xml @@ -2,12 +2,12 @@ org.apache.spark boostkit-ml - 2.1.0 + 2.2.0 4.0.0 boostkit-ml-core_2.11 - 2.1.0 + 2.2.0 ${project.artifactId} Spark ml core @@ -15,7 +15,7 @@ org.apache.spark boostkit-ml-kernel-client-core_2.11 - 2.1.0 + 2.2.0 ${spark.version} compile @@ -67,6 +67,25 @@ + + org.codehaus.mojo + build-helper-maven-plugin + 3.0.0 + + + generate-sources + + add-source + + + + src/main/java + src/main/native + + + + + diff --git a/ml-core/src/main/java/dev/ludovic/netlib/BLAS.java b/ml-core/src/main/java/dev/ludovic/netlib/BLAS.java new file mode 100644 index 0000000000000000000000000000000000000000..e6cc5fe0f6504378b383b16f3a3d13a9f35076e1 --- /dev/null +++ b/ml-core/src/main/java/dev/ludovic/netlib/BLAS.java @@ -0,0 +1,240 @@ +/* + * Copyright 2020, 2021, Ludovic Henry + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * Please contact git@ludovic.dev or visit ludovic.dev if you need additional + * information or have any questions. + */ + +package dev.ludovic.netlib; + +public interface BLAS { + + public static BLAS getInstance() { + return InstanceBuilder.BLAS.getInstance(); + } + + public double dasum(int n, double[] x, int incx); + public double dasum(int n, double[] x, int offsetx, int incx); + + public float sasum(int n, float[] x, int incx); + public float sasum(int n, float[] x, int offsetx, int incx); + + public void daxpy(int n, double alpha, double[] x, int incx, double[] y, int incy); + public void daxpy(int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy); + + public void saxpy(int n, float alpha, float[] x, int incx, float[] y, int incy); + public void saxpy(int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy); + + public void dcopy(int n, double[] x, int incx, double[] y, int incy); + public void dcopy(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy); + + public void scopy(int n, float[] x, int incx, float[] y, int incy); + public void scopy(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy); + + public double ddot(int n, double[] x, int incx, double[] y, int incy); + public double ddot(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy); + + public float sdot(int n, float[] x, int incx, float[] y, int incy); + public float 
sdot(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy); + + public float sdsdot(int n, float sb, float[] sx, int incx, float[] sy, int incy); + public float sdsdot(int n, float sb, float[] sx, int offsetsx, int incx, float[] sy, int offsetsy, int incy); + + public void dgbmv(String trans, int m, int n, int kl, int ku, double alpha, double[] a, int lda, double[] x, int incx, double beta, double[] y, int incy); + public void dgbmv(String trans, int m, int n, int kl, int ku, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy); + + public void sgbmv(String trans, int m, int n, int kl, int ku, float alpha, float[] a, int lda, float[] x, int incx, float beta, float[] y, int incy); + public void sgbmv(String trans, int m, int n, int kl, int ku, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy); + + public void dgemm(String transa, String transb, int m, int n, int k, double alpha, double[] a, int lda, double[] b, int ldb, double beta, double[] c, int ldc); + public void dgemm(String transa, String transb, int m, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc); + + public void sgemm(String transa, String transb, int m, int n, int k, float alpha, float[] a, int lda, float[] b, int ldb, float beta, float[] c, int Ldc); + public void sgemm(String transa, String transb, int m, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int Ldc); + + public void dgemv(String trans, int m, int n, double alpha, double[] a, int lda, double[] x, int incx, double beta, double[] y, int incy); + public void dgemv(String trans, int m, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, 
int offsety, int incy); + + public void sgemv(String trans, int m, int n, float alpha, float[] a, int lda, float[] x, int incx, float beta, float[] y, int incy); + public void sgemv(String trans, int m, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy); + + public void dger(int m, int n, double alpha, double[] x, int incx, double[] y, int incy, double[] a, int lda); + public void dger(int m, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta, int lda); + + public void sger(int m, int n, float alpha, float[] x, int incx, float[] y, int incy, float[] a, int lda); + public void sger(int m, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta, int lda); + + public double dnrm2(int n, double[] x, int incx); + public double dnrm2(int n, double[] x, int offsetx, int incx); + + public float snrm2(int n, float[] x, int incx); + public float snrm2(int n, float[] x, int offsetx, int incx); + + public void drot(int n, double[] dx, int incx, double[] dy, int incy, double c, double s); + public void drot(int n, double[] dx, int offsetdx, int incx, double[] dy, int offsetdy, int incy, double c, double s); + + public void srot(int n, float[] sx, int incx, float[] sy, int incy, float c, float s); + public void srot(int n, float[] sx, int offsetsx, int incx, float[] sy, int offsetsy, int incy, float c, float s); + + public void drotg(org.netlib.util.doubleW da, org.netlib.util.doubleW db, org.netlib.util.doubleW c, org.netlib.util.doubleW s); + + public void srotg(org.netlib.util.floatW sa, org.netlib.util.floatW sb, org.netlib.util.floatW c, org.netlib.util.floatW s); + + public void drotm(int n, double[] dx, int incx, double[] dy, int incy, double[] dparam); + public void drotm(int n, double[] dx, int offsetdx, int incx, double[] dy, int offsetdy, int incy, double[] dparam, int 
offsetdparam); + + public void srotm(int n, float[] sx, int incx, float[] sy, int incy, float[] sparam); + public void srotm(int n, float[] sx, int offsetsx, int incx, float[] sy, int offsetsy, int incy, float[] sparam, int offsetsparam); + + public void drotmg(org.netlib.util.doubleW dd1, org.netlib.util.doubleW dd2, org.netlib.util.doubleW dx1, double dy1, double[] dparam); + public void drotmg(org.netlib.util.doubleW dd1, org.netlib.util.doubleW dd2, org.netlib.util.doubleW dx1, double dy1, double[] dparam, int offsetdparam); + + public void srotmg(org.netlib.util.floatW sd1, org.netlib.util.floatW sd2, org.netlib.util.floatW sx1, float sy1, float[] sparam); + public void srotmg(org.netlib.util.floatW sd1, org.netlib.util.floatW sd2, org.netlib.util.floatW sx1, float sy1, float[] sparam, int offsetsparam); + + public void dsbmv(String uplo, int n, int k, double alpha, double[] a, int lda, double[] x, int incx, double beta, double[] y, int incy); + public void dsbmv(String uplo, int n, int k, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy); + + public void ssbmv(String uplo, int n, int k, float alpha, float[] a, int lda, float[] x, int incx, float beta, float[] y, int incy); + public void ssbmv(String uplo, int n, int k, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy); + + public void dscal(int n, double alpha, double[] x, int incx); + public void dscal(int n, double alpha, double[] x, int offsetx, int incx); + + public void sscal(int n, float alpha, float[] x, int incx); + public void sscal(int n, float alpha, float[] x, int offsetx, int incx); + + public void dspmv(String uplo, int n, double alpha, double[] a, double[] x, int incx, double beta, double[] y, int incy); + public void dspmv(String uplo, int n, double alpha, double[] a, int offseta, double[] x, int offsetx, int incx, double beta, double[] y, 
int offsety, int incy); + + public void sspmv(String uplo, int n, float alpha, float[] ap, float[] x, int incx, float beta, float[] y, int incy); + public void sspmv(String uplo, int n, float alpha, float[] ap, int offsetap, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy); + + public void dspr(String uplo, int n, double alpha, double[] x, int incx, double[] a); + public void dspr(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] a, int offseta); + + public void sspr(String uplo, int n, float alpha, float[] x, int incx, float[] ap); + public void sspr(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] ap, int offsetap); + + public void dspr2(String uplo, int n, double alpha, double[] x, int incx, double[] y, int incy, double[] ap); + public void dspr2(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] ap, int offsetap); + + public void sspr2(String uplo, int n, float alpha, float[] x, int incx, float[] y, int incy, float[] ap); + public void sspr2(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] ap, int offsetap); + + public void dswap(int n, double[] x, int incx, double[] y, int incy); + public void dswap(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy); + + public void sswap(int n, float[] x, int incx, float[] y, int incy); + public void sswap(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy); + + public void dsymm(String side, String uplo, int m, int n, double alpha, double[] a, int lda, double[] b, int ldb, double beta, double[] c, int Ldc); + public void dsymm(String side, String uplo, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int Ldc); + + public void ssymm(String side, String uplo, int m, int n, float alpha, 
float[] a, int lda, float[] b, int ldb, float beta, float[] c, int Ldc); + public void ssymm(String side, String uplo, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int Ldc); + + public void dsymv(String uplo, int n, double alpha, double[] a, int lda, double[] x, int incx, double beta, double[] y, int incy); + public void dsymv(String uplo, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy); + + public void ssymv(String uplo, int n, float alpha, float[] a, int lda, float[] x, int incx, float beta, float[] y, int incy); + public void ssymv(String uplo, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy); + + public void dsyr(String uplo, int n, double alpha, double[] x, int incx, double[] a, int lda); + public void dsyr(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] a, int offseta, int lda); + + public void ssyr(String uplo, int n, float alpha, float[] x, int incx, float[] a, int lda); + public void ssyr(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] a, int offseta, int lda); + + public void dsyr2(String uplo, int n, double alpha, double[] x, int incx, double[] y, int incy, double[] a, int lda); + public void dsyr2(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta, int lda); + + public void ssyr2(String uplo, int n, float alpha, float[] x, int incx, float[] y, int incy, float[] a, int lda); + public void ssyr2(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta, int lda); + + public void dsyr2k(String uplo, String trans, int n, int k, double alpha, double[] a, int lda, double[] b, int ldb, double beta, double[] c, 
int Ldc); + public void dsyr2k(String uplo, String trans, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int Ldc); + + public void ssyr2k(String uplo, String trans, int n, int k, float alpha, float[] a, int lda, float[] b, int ldb, float beta, float[] c, int Ldc); + public void ssyr2k(String uplo, String trans, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int Ldc); + + public void dsyrk(String uplo, String trans, int n, int k, double alpha, double[] a, int lda, double beta, double[] c, int Ldc); + public void dsyrk(String uplo, String trans, int n, int k, double alpha, double[] a, int offseta, int lda, double beta, double[] c, int offsetc, int Ldc); + + public void ssyrk(String uplo, String trans, int n, int k, float alpha, float[] a, int lda, float beta, float[] c, int Ldc); + public void ssyrk(String uplo, String trans, int n, int k, float alpha, float[] a, int offseta, int lda, float beta, float[] c, int offsetc, int Ldc); + + public void dtbmv(String uplo, String trans, String diag, int n, int k, double[] a, int lda, double[] x, int incx); + public void dtbmv(String uplo, String trans, String diag, int n, int k, double[] a, int offseta, int lda, double[] x, int offsetx, int incx); + + public void stbmv(String uplo, String trans, String diag, int n, int k, float[] a, int lda, float[] x, int incx); + public void stbmv(String uplo, String trans, String diag, int n, int k, float[] a, int offseta, int lda, float[] x, int offsetx, int incx); + + public void dtbsv(String uplo, String trans, String diag, int n, int k, double[] a, int lda, double[] x, int incx); + public void dtbsv(String uplo, String trans, String diag, int n, int k, double[] a, int offseta, int lda, double[] x, int offsetx, int incx); + + public void stbsv(String uplo, String trans, String diag, int n, int k, float[] a, int lda, 
float[] x, int incx); + public void stbsv(String uplo, String trans, String diag, int n, int k, float[] a, int offseta, int lda, float[] x, int offsetx, int incx); + + public void dtpmv(String uplo, String trans, String diag, int n, double[] ap, double[] x, int incx); + public void dtpmv(String uplo, String trans, String diag, int n, double[] ap, int offsetap, double[] x, int offsetx, int incx); + + public void stpmv(String uplo, String trans, String diag, int n, float[] ap, float[] x, int incx); + public void stpmv(String uplo, String trans, String diag, int n, float[] ap, int offsetap, float[] x, int offsetx, int incx); + + public void dtpsv(String uplo, String trans, String diag, int n, double[] ap, double[] x, int incx); + public void dtpsv(String uplo, String trans, String diag, int n, double[] ap, int offsetap, double[] x, int offsetx, int incx); + + public void stpsv(String uplo, String trans, String diag, int n, float[] ap, float[] x, int incx); + public void stpsv(String uplo, String trans, String diag, int n, float[] ap, int offsetap, float[] x, int offsetx, int incx); + + public void dtrmm(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int lda, double[] b, int ldb); + public void dtrmm(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb); + + public void strmm(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int lda, float[] b, int ldb); + public void strmm(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb); + + public void dtrmv(String uplo, String trans, String diag, int n, double[] a, int lda, double[] x, int incx); + public void dtrmv(String uplo, String trans, String diag, int n, double[] a, int offseta, int lda, double[] x, int offsetx, int incx); + + public void 
strmv(String uplo, String trans, String diag, int n, float[] a, int lda, float[] x, int incx); + public void strmv(String uplo, String trans, String diag, int n, float[] a, int offseta, int lda, float[] x, int offsetx, int incx); + + public void dtrsm(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int lda, double[] b, int ldb); + public void dtrsm(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb); + + public void strsm(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int lda, float[] b, int ldb); + public void strsm(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb); + + public void dtrsv(String uplo, String trans, String diag, int n, double[] a, int lda, double[] x, int incx); + public void dtrsv(String uplo, String trans, String diag, int n, double[] a, int offseta, int lda, double[] x, int offsetx, int incx); + + public void strsv(String uplo, String trans, String diag, int n, float[] a, int lda, float[] x, int incx); + public void strsv(String uplo, String trans, String diag, int n, float[] a, int offseta, int lda, float[] x, int offsetx, int incx); + + public int idamax(int n, double[] x, int incx); + public int idamax(int n, double[] x, int offsetx, int incx); + + public int isamax(int n, float[] sx, int incx); + public int isamax(int n, float[] sx, int offsetsx, int incx); + + public boolean lsame(String ca, String cb); +} diff --git a/ml-core/src/main/java/dev/ludovic/netlib/InstanceBuilder.java b/ml-core/src/main/java/dev/ludovic/netlib/InstanceBuilder.java new file mode 100644 index 0000000000000000000000000000000000000000..0d3eee6d2770f81cf4194aba30f92d73186ae7be --- /dev/null +++ b/ml-core/src/main/java/dev/ludovic/netlib/InstanceBuilder.java @@ -0,0 +1,77 @@ +/* + * 
Copyright 2020, 2021, Ludovic Henry + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * Please contact git@ludovic.dev or visit ludovic.dev if you need additional + * information or have any questions. 
+ */ + +package dev.ludovic.netlib; + +import java.util.logging.Logger; + +final class InstanceBuilder { + + public static final class BLAS { + private static final dev.ludovic.netlib.BLAS instance = getInstanceImpl(); + + public static dev.ludovic.netlib.BLAS getInstance() { + return instance; + } + + private static dev.ludovic.netlib.BLAS getInstanceImpl() { + try { + return dev.ludovic.netlib.NativeBLAS.getInstance(); + } catch (Throwable t) { + Logger.getLogger(BLAS.class.getName()).warning("Failed to load implementation from:" + dev.ludovic.netlib.NativeBLAS.class.getName()); + } + return dev.ludovic.netlib.JavaBLAS.getInstance(); + } + } + + public static final class NativeBLAS { + private static final dev.ludovic.netlib.NativeBLAS instance = getInstanceImpl(); + + public static dev.ludovic.netlib.NativeBLAS getInstance() { + return instance; + } + + private static dev.ludovic.netlib.NativeBLAS getInstanceImpl() { + try { + return dev.ludovic.netlib.blas.JNIBLAS.getInstance(); + } catch (Throwable t) { + Logger.getLogger(NativeBLAS.class.getName()).warning("Failed to load implementation from:" + dev.ludovic.netlib.blas.JNIBLAS.class.getName()); + } + throw new RuntimeException("Unable to load native implementation"); + } + } + + public static final class JavaBLAS { + private static final dev.ludovic.netlib.JavaBLAS instance = getInstanceImpl(); + + public static dev.ludovic.netlib.JavaBLAS getInstance() { + return instance; + } + + private static dev.ludovic.netlib.JavaBLAS getInstanceImpl() { + return dev.ludovic.netlib.blas.Java8BLAS.getInstance(); + } + } +} diff --git a/ml-core/src/main/java/dev/ludovic/netlib/JavaBLAS.java b/ml-core/src/main/java/dev/ludovic/netlib/JavaBLAS.java new file mode 100644 index 0000000000000000000000000000000000000000..834aa3986348ad8f88690ac0f0948c8c2b439452 --- /dev/null +++ b/ml-core/src/main/java/dev/ludovic/netlib/JavaBLAS.java @@ -0,0 +1,33 @@ +/* + * Copyright 2020, 2021, Ludovic Henry + * + * Permission is hereby 
granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * Please contact git@ludovic.dev or visit ludovic.dev if you need additional + * information or have any questions. 
+ */ + +package dev.ludovic.netlib; + +public interface JavaBLAS extends BLAS { + + public static JavaBLAS getInstance() { + return InstanceBuilder.JavaBLAS.getInstance(); + } +} diff --git a/ml-core/src/main/java/dev/ludovic/netlib/NativeBLAS.java b/ml-core/src/main/java/dev/ludovic/netlib/NativeBLAS.java new file mode 100644 index 0000000000000000000000000000000000000000..a6fd83a4b2980a83b6e56cf08de2cfd27c165cfd --- /dev/null +++ b/ml-core/src/main/java/dev/ludovic/netlib/NativeBLAS.java @@ -0,0 +1,33 @@ +/* + * Copyright 2020, 2021, Ludovic Henry + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * Please contact git@ludovic.dev or visit ludovic.dev if you need additional + * information or have any questions. 
+ */ + +package dev.ludovic.netlib; + +public interface NativeBLAS extends BLAS { + + public static NativeBLAS getInstance() { + return InstanceBuilder.NativeBLAS.getInstance(); + } +} diff --git a/ml-core/src/main/java/dev/ludovic/netlib/blas/AbstractBLAS.java b/ml-core/src/main/java/dev/ludovic/netlib/blas/AbstractBLAS.java new file mode 100644 index 0000000000000000000000000000000000000000..00f3c64c5e8135c7ce7f60c67bba5785c2f5467a --- /dev/null +++ b/ml-core/src/main/java/dev/ludovic/netlib/blas/AbstractBLAS.java @@ -0,0 +1,1689 @@ +/* + * Copyright 2020, 2021, Ludovic Henry + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * Please contact git@ludovic.dev or visit ludovic.dev if you need additional + * information or have any questions. 
+ */ + +package dev.ludovic.netlib.blas; + +import java.util.Objects; + +import dev.ludovic.netlib.BLAS; + +abstract class AbstractBLAS implements BLAS { + + private final static boolean debug = System.getProperty("dev.ludovic.netlib.blas.debug", "false").equals("true"); + + protected int loopAlign(int index, int max, int size) { + return Math.min(loopBound(index + size - 1, size), max); + } + + protected int loopBound(int index, int size) { + return index - (index % size); + } + + private void checkArgument(String method, int arg, boolean check) { + if (!check) { + throw new IllegalArgumentException(String.format("** On entry to '%s' parameter number %d had an illegal value", method, arg)); + } + } + + private void checkIndex(int index, int length) { + //FIXME: switch to Objects.checkIndex when the minimum version becomes JDK 11 + if (index < 0 || index >= length) { + throw new IndexOutOfBoundsException(String.format("Index %s out of bounds for length %s", index, length)); + } + } + + private void requireNonNull(T obj) { + Objects.requireNonNull(obj); + } + + public double dasum(int n, double[] x, int incx) { + if (debug) System.err.println("dasum"); + return dasum(n, x, 0, incx); + } + + public double dasum(int n, double[] x, int offsetx, int incx) { + if (debug) System.err.println("dasum"); + if (n <= 0) { + return 0.0; + } + requireNonNull(x); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + return dasumK(n, x, offsetx, incx); + } + + protected abstract double dasumK(int n, double[] x, int offsetx, int incx); + + public float sasum(int n, float[] x, int incx) { + if (debug) System.err.println("sasum"); + return sasum(n, x, 0, incx); + } + + public float sasum(int n, float[] x, int offsetx, int incx) { + if (debug) System.err.println("sasum"); + if (n <= 0) { + return 0.0f; + } + requireNonNull(x); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + return sasumK(n, x, offsetx, incx); + } + + protected abstract float sasumK(int n, float[] 
x, int offsetx, int incx); + + public void daxpy(int n, double alpha, double[] x, int incx, double[] y, int incy) { + if (debug) System.err.println("daxpy"); + daxpy(n, alpha, x, 0, incx, y, 0, incy); + } + + // y += alpha * x + public void daxpy(int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy) { + if (debug) System.err.println("daxpy"); + if (n <= 0) { + return; + } + if (alpha == 0.0) { + return; + } + requireNonNull(x); + requireNonNull(y); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + checkIndex(offsety + (n - 1) * Math.abs(incy), y.length); + daxpyK(n, alpha, x, offsetx, incx, y, offsety, incy); + } + + protected abstract void daxpyK(int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy); + + public void saxpy(int n, float alpha, float[] x, int incx, float[] y, int incy) { + if (debug) System.err.println("saxpy"); + saxpy(n, alpha, x, 0, incx, y, 0, incy); + } + + // y += alpha * x + public void saxpy(int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) { + if (debug) System.err.println("saxpy"); + if (n <= 0) { + return; + } + if (alpha == 0.0f) { + return; + } + requireNonNull(x); + requireNonNull(y); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + checkIndex(offsety + (n - 1) * Math.abs(incy), y.length); + saxpyK(n, alpha, x, offsetx, incx, y, offsety, incy); + } + + protected abstract void saxpyK(int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy); + + public void dcopy(int n, double[] x, int incx, double[] y, int incy) { + if (debug) System.err.println("dcopy"); + dcopy(n, x, 0, incx, y, 0, incy); + } + + public void dcopy(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy) { + if (debug) System.err.println("dcopy"); + if (n <= 0) { + return; + } + requireNonNull(x); + requireNonNull(y); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + 
checkIndex(offsety + (n - 1) * Math.abs(incy), y.length); + dcopyK(n, x, offsetx, incx, y, offsety, incy); + } + + protected abstract void dcopyK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy); + + public void scopy(int n, float[] x, int incx, float[] y, int incy) { + if (debug) System.err.println("scopy"); + scopy(n, x, 0, incx, y, 0, incy); + } + + public void scopy(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) { + if (debug) System.err.println("scopy"); + if (n <= 0) { + return; + } + requireNonNull(x); + requireNonNull(y); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + checkIndex(offsety + (n - 1) * Math.abs(incy), y.length); + scopyK(n, x, offsetx, incx, y, offsety, incy); + } + + protected abstract void scopyK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy); + + public double ddot(int n, double[] x, int incx, double[] y, int incy) { + if (debug) System.err.println("ddot"); + return ddot(n, x, 0, incx, y, 0, incy); + } + + // sum(x * y) + public double ddot(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy) { + if (debug) System.err.println("ddot"); + if (n <= 0) { + return 0.0; + } + requireNonNull(x); + requireNonNull(y); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + checkIndex(offsety + (n - 1) * Math.abs(incy), y.length); + return ddotK(n, x, offsetx, incx, y, offsety, incy); + } + + protected abstract double ddotK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy); + + public float sdot(int n, float[] x, int incx, float[] y, int incy) { + if (debug) System.err.println("sdot"); + return sdot(n, x, 0, incx, y, 0, incy); + } + + // sum(x * y) + public float sdot(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) { + if (debug) System.err.println("sdot"); + if (n <= 0) { + return 0.0f; + } + requireNonNull(x); + requireNonNull(y); + checkIndex(offsetx + (n - 1) * 
Math.abs(incx), x.length); + checkIndex(offsety + (n - 1) * Math.abs(incy), y.length); + return sdotK(n, x, offsetx, incx, y, offsety, incy); + } + + protected abstract float sdotK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy); + + public float sdsdot(int n, float sb, float[] x, int incx, float[] y, int incy) { + if (debug) System.err.println("sdsdot"); + return sdsdot(n, sb, x, 0, incx, y, 0, incy); + } + + public float sdsdot(int n, float sb, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) { + if (debug) System.err.println("sdsdot"); + if (n <= 0) { + return 0.0f; + } + requireNonNull(x); + requireNonNull(y); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + checkIndex(offsety + (n - 1) * Math.abs(incy), y.length); + return sdsdotK(n, sb, x, offsetx, incx, y, offsety, incy); + } + + protected abstract float sdsdotK(int n, float sb, float[] x, int offsetx, int incx, float[] y, int offsety, int incy); + + public void dgbmv(String trans, int m, int n, int kl, int ku, double alpha, double[] a, int lda, double[] x, int incx, double beta, double[] y, int incy) { + if (debug) System.err.println("dgbmv"); + dgbmv(trans, m, n, kl, ku, alpha, a, 0, lda, x, 0, incx, beta, y, 0, incy); + } + + public void dgbmv(String trans, int m, int n, int kl, int ku, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) { + if (debug) System.err.println("dgbmv"); + requireNonNull(a); + requireNonNull(x); + requireNonNull(y); + checkIndex(offseta + n * lda - 1, a.length); + checkIndex(offsetx + ((lsame("N", trans) ? n : m) - 1) * Math.abs(incx), x.length); + checkIndex(offsety + ((lsame("N", trans) ? 
m : n) - 1) * Math.abs(incy), y.length); + dgbmvK(trans, m, n, kl, ku, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy); + } + + protected abstract void dgbmvK(String trans, int m, int n, int kl, int ku, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy); + + public void sgbmv(String trans, int m, int n, int kl, int ku, float alpha, float[] a, int lda, float[] x, int incx, float beta, float[] y, int incy) { + if (debug) System.err.println("sgbmv"); + sgbmv(trans, m, n, kl, ku, alpha, a, 0, lda, x, 0, incx, beta, y, 0, incy); + } + + public void sgbmv(String trans, int m, int n, int kl, int ku, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) { + if (debug) System.err.println("sgbmv"); + requireNonNull(a); + requireNonNull(x); + requireNonNull(y); + checkIndex(offseta + n * lda - 1, a.length); + checkIndex(offsetx + ((lsame("N", trans) ? n : m) - 1) * Math.abs(incx), x.length); + checkIndex(offsety + ((lsame("N", trans) ? 
m : n) - 1) * Math.abs(incy), y.length); + sgbmvK(trans, m, n, kl, ku, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy); + } + + protected abstract void sgbmvK(String trans, int m, int n, int kl, int ku, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy); + + public void dgemm(String transa, String transb, int m, int n, int k, double alpha, double[] a, int lda, double[] b, int ldb, double beta, double[] c, int ldc) { + if (debug) System.err.println("dgemm"); + dgemm(transa, transb, m, n, k, alpha, a, 0, lda, b, 0, ldb, beta, c, 0, ldc); + } + + // c = alpha * a * b + beta * c + public void dgemm(String transa, String transb, int m, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) { + if (debug) System.err.println("dgemm"); + checkArgument("DGEMM", 1, lsame("T", transa) || lsame("N", transa) || lsame("C", transa)); + checkArgument("DGEMM", 2, lsame("T", transb) || lsame("N", transb) || lsame("C", transb)); + checkArgument("DGEMM", 3, m >= 0); + checkArgument("DGEMM", 4, n >= 0); + checkArgument("DGEMM", 5, k >= 0); + checkArgument("DGEMM", 8, lda >= Math.max(1, lsame("N", transa) ? m : k)); + checkArgument("DGEMM", 10, ldb >= Math.max(1, lsame("N", transb) ? k : n)); + checkArgument("DGEMM", 13, ldc >= Math.max(1, m)); + if (m == 0 || n == 0 || ((alpha == 0.0 || k == 0) && beta == 1.0)) { + return; + } + requireNonNull(a); + requireNonNull(b); + requireNonNull(c); + checkIndex(offseta + (lsame("N", transa) ? k : m) * lda - 1, a.length); + checkIndex(offsetb + (lsame("N", transb) ? 
n : k) * ldb - 1, b.length); + checkIndex(offsetc + m * n - 1, c.length); + dgemmK(transa, transb, m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc); + } + + protected abstract void dgemmK(String transa, String transb, int m, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc); + + public void sgemm(String transa, String transb, int m, int n, int k, float alpha, float[] a, int lda, float[] b, int ldb, float beta, float[] c, int ldc) { + if (debug) System.err.println("sgemm"); + sgemm(transa, transb, m, n, k, alpha, a, 0, lda, b, 0, ldb, beta, c, 0, ldc); + } + + // c = alpha * a * b + beta * c + public void sgemm(String transa, String transb, int m, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) { + if (debug) System.err.println("sgemm"); + checkArgument("SGEMM", 1, lsame("T", transa) || lsame("N", transa) || lsame("C", transa)); + checkArgument("SGEMM", 2, lsame("T", transb) || lsame("N", transb) || lsame("C", transb)); + checkArgument("SGEMM", 3, m >= 0); + checkArgument("SGEMM", 4, n >= 0); + checkArgument("SGEMM", 5, k >= 0); + checkArgument("SGEMM", 8, lda >= Math.max(1, lsame("N", transa) ? m : k)); + checkArgument("SGEMM", 10, ldb >= Math.max(1, lsame("N", transb) ? k : n)); + checkArgument("SGEMM", 13, ldc >= Math.max(1, m)); + if (m == 0 || n == 0 || ((alpha == 0.0f || k == 0) && beta == 1.0f)) { + return; + } + requireNonNull(a); + requireNonNull(b); + requireNonNull(c); + checkIndex(offseta + (lsame("N", transa) ? k : m) * lda - 1, a.length); + checkIndex(offsetb + (lsame("N", transb) ? 
n : k) * ldb - 1, b.length); + checkIndex(offsetc + m * n - 1, c.length); + sgemmK(transa, transb, m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc); + } + + protected abstract void sgemmK(String transa, String transb, int m, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc); + + public void dgemv(String trans, int m, int n, double alpha, double[] a, int lda, double[] x, int incx, double beta, double[] y, int incy) { + if (debug) System.err.println("dgemv"); + dgemv(trans, m, n, alpha, a, 0, lda, x, 0, incx, beta, y, 0, incy); + } + + // y = alpha * A * x + beta * y + public void dgemv(String trans, int m, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) { + if (debug) System.err.println("dgemv"); + checkArgument("DGEMV", 1, lsame("N", trans) || lsame("T", trans) || lsame("C", trans)); + checkArgument("DGEMV", 2, m >= 0); + checkArgument("DGEMV", 3, n >= 0); + checkArgument("DGEMV", 6, lda >= Math.max(1, m)); + checkArgument("DGEMV", 8, incx != 0); + checkArgument("DGEMV", 11, incy != 0); + if (m == 0 || n == 0) { + return; + } + requireNonNull(a); + requireNonNull(x); + requireNonNull(y); + checkIndex(offseta + n * lda - 1, a.length); + checkIndex(offsetx + ((lsame("N", trans) ? n : m) - 1) * Math.abs(incx), x.length); + checkIndex(offsety + ((lsame("N", trans) ? 
m : n) - 1) * Math.abs(incy), y.length); + dgemvK(trans, m, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy); + } + + protected abstract void dgemvK(String trans, int m, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy); + + public void sgemv(String trans, int m, int n, float alpha, float[] a, int lda, float[] x, int incx, float beta, float[] y, int incy) { + if (debug) System.err.println("sgemv"); + sgemv(trans, m, n, alpha, a, 0, lda, x, 0, incx, beta, y, 0, incy); + } + + // y = alpha * A * x + beta * y + public void sgemv(String trans, int m, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) { + if (debug) System.err.println("sgemv"); + checkArgument("SGEMV", 1, lsame("N", trans) || lsame("T", trans) || lsame("C", trans)); + checkArgument("SGEMV", 2, m >= 0); + checkArgument("SGEMV", 3, n >= 0); + checkArgument("SGEMV", 6, lda >= Math.max(1, m)); + checkArgument("SGEMV", 8, incx != 0); + checkArgument("SGEMV", 11, incy != 0); + if (m == 0 || n == 0) { + return; + } + requireNonNull(a); + requireNonNull(x); + requireNonNull(y); + checkIndex(offseta + n * lda - 1, a.length); + checkIndex(offsetx + ((lsame("N", trans) ? n : m) - 1) * Math.abs(incx), x.length); + checkIndex(offsety + ((lsame("N", trans) ? 
m : n) - 1) * Math.abs(incy), y.length); + sgemvK(trans, m, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy); + } + + protected abstract void sgemvK(String trans, int m, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy); + + // A += alpha * x * y.t + public void dger(int m, int n, double alpha, double[] x, int incx, double[] y, int incy, double[] a, int lda) { + if (debug) System.err.println("dger"); + dger(m, n, alpha, x, 0, incx, y, 0, incy, a, 0, lda); + } + + public void dger(int m, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta, int lda) { + if (debug) System.err.println("dger"); + checkArgument("DGER", 1, m >= 0); + checkArgument("DGER", 2, n >= 0); + checkArgument("DGER", 5, incx != 0); + checkArgument("DGER", 7, incy != 0); + checkArgument("DGER", 9, lda >= Math.max(1, m)); + if (m == 0 || n == 0) { + return; + } + requireNonNull(x); + requireNonNull(y); + requireNonNull(a); + checkIndex(offsetx + (m - 1) * Math.abs(incx), x.length); + checkIndex(offsety + (n - 1) * Math.abs(incy), y.length); + checkIndex(offseta + n * lda - 1, a.length); + if (alpha != 0.0) { + dgerK(m, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta, lda); + } + } + + protected abstract void dgerK(int m, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta, int lda); + + public void sger(int m, int n, float alpha, float[] x, int incx, float[] y, int incy, float[] a, int lda) { + if (debug) System.err.println("sger"); + sger(m, n, alpha, x, 0, incx, y, 0, incy, a, 0, lda); + } + + public void sger(int m, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta, int lda) { + if (debug) System.err.println("sger"); + checkArgument("SGER", 1, m >= 0); + checkArgument("SGER", 2, n >= 0); + 
checkArgument("SGER", 5, incx != 0); + checkArgument("SGER", 7, incy != 0); + checkArgument("SGER", 9, lda >= Math.max(1, m)); + if (m == 0 || n == 0) { + return; + } + requireNonNull(x); + requireNonNull(y); + requireNonNull(a); + checkIndex(offsetx + (m - 1) * Math.abs(incx), x.length); + checkIndex(offsety + (n - 1) * Math.abs(incy), y.length); + checkIndex(offseta + n * lda - 1, a.length); + if (alpha != 0.0f) { + sgerK(m, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta, lda); + } + } + + protected abstract void sgerK(int m, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta, int lda); + + public double dnrm2(int n, double[] x, int incx) { + if (debug) System.err.println("dnrm2"); + return dnrm2(n, x, 0, incx); + } + + public double dnrm2(int n, double[] x, int offsetx, int incx) { + if (debug) System.err.println("dnrm2"); + if (n <= 0) { + return 0.0; + } + if (incx <= 0) { + return 0.0; + } + if (n == 1) { + return Math.abs(x[offsetx + 0]); + } + requireNonNull(x); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + return dnrm2K(n, x, offsetx, incx); + } + + protected abstract double dnrm2K(int n, double[] x, int offsetx, int incx); + + public float snrm2(int n, float[] x, int incx) { + if (debug) System.err.println("snrm2"); + return snrm2(n, x, 0, incx); + } + + public float snrm2(int n, float[] x, int offsetx, int incx) { + if (debug) System.err.println("snrm2"); + if (n <= 0) { + return 0.0f; + } + if (incx <= 0) { + return 0.0f; + } + if (n == 1) { + return Math.abs(x[offsetx + 0]); + } + requireNonNull(x); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + return snrm2K(n, x, offsetx, incx); + } + + protected abstract float snrm2K(int n, float[] x, int offsetx, int incx); + + public void drot(int n, double[] x, int incx, double[] y, int incy, double c, double s) { + if (debug) System.err.println("drot"); + drot(n, x, 0, incx, y, 0, incy, c, s); + } + + public void 
drot(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double c, double s) { + if (debug) System.err.println("drot"); + if (n <= 0) { + return; + } + requireNonNull(x); + requireNonNull(y); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + checkIndex(offsety + (n - 1) * Math.abs(incy), y.length); + drotK(n, x, offsetx, incx, y, offsety, incy, c, s); + } + + protected abstract void drotK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double c, double s); + + public void srot(int n, float[] x, int incx, float[] y, int incy, float c, float s) { + if (debug) System.err.println("srot"); + srot(n, x, 0, incx, y, 0, incy, c, s); + } + + public void srot(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float c, float s) { + if (debug) System.err.println("srot"); + if (n <= 0) { + return; + } + requireNonNull(x); + requireNonNull(y); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + checkIndex(offsety + (n - 1) * Math.abs(incy), y.length); + srotK(n, x, offsetx, incx, y, offsety, incy, c, s); + } + + protected abstract void srotK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float c, float s); + + public void drotg(org.netlib.util.doubleW da, org.netlib.util.doubleW db, org.netlib.util.doubleW c, org.netlib.util.doubleW s) { + if (debug) System.err.println("drotg"); + double scale = Math.abs(da.val) + Math.abs(db.val); + if (scale == 0.0) { + c.val = 1.0; + s.val = 0.0; + da.val = 0.0; + db.val = 0.0; + } else { + double r = scale * Math.sqrt(Math.pow(da.val / scale, 2) + Math.pow(db.val / scale, 2)) + * ((Math.abs(da.val) > Math.abs(db.val) ? da.val : db.val) >= 0.0 ? 
1.0 : -1.0); + c.val = da.val / r; + s.val = db.val / r; + double z = 1.0; + if (Math.abs(da.val) > Math.abs(db.val)) { + z = s.val; + } else if (c.val != 0.0) { + z = 1.0 / c.val; + } + da.val = r; + db.val = z; + } + } + + public void srotg(org.netlib.util.floatW sa, org.netlib.util.floatW sb, org.netlib.util.floatW c, org.netlib.util.floatW s) { + if (debug) System.err.println("srotg"); + float scale = Math.abs(sa.val) + Math.abs(sb.val); + if (scale == 0.0f) { + c.val = 1.0f; + s.val = 0.0f; + sa.val = 0.0f; + sb.val = 0.0f; + } else { + float r = (float)(scale * Math.sqrt(Math.pow(sa.val / scale, 2) + Math.pow(sb.val / scale, 2)) + * ((Math.abs(sa.val) > Math.abs(sb.val) ? sa.val : sb.val) >= 0.0f ? 1.0 : -1.0)); + c.val = sa.val / r; + s.val = sb.val / r; + float z = 1.0f; + if (Math.abs(sa.val) > Math.abs(sb.val)) { + z = s.val; + } else if (c.val != 0.0f) { + z = 1.0f / c.val; + } + sa.val = r; + sb.val = z; + } + } + + public void drotm(int n, double[] x, int incx, double[] y, int incy, double[] param) { + if (debug) System.err.println("drotm"); + drotm(n, x, 0, incx, y, 0, incy, param, 0); + } + + public void drotm(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] param, int offsetparam) { + if (debug) System.err.println("drotm"); + requireNonNull(x); + requireNonNull(y); + requireNonNull(param); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + checkIndex(offsety + (n - 1) * Math.abs(incy), y.length); + checkIndex(offsetparam + 4, param.length); /* param.length == 5 */ + drotmK(n, x, offsetx, incx, y, offsety, incy, param, offsetparam); + } + + protected abstract void drotmK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] param, int offsetparam); + + public void srotm(int n, float[] x, int incx, float[] y, int incy, float[] param) { + if (debug) System.err.println("srotm"); + srotm(n, x, 0, incx, y, 0, incy, param, 0); + } + + public void srotm(int n, float[] x, int 
offsetx, int incx, float[] y, int offsety, int incy, float[] param, int offsetparam) { + if (debug) System.err.println("srotm"); + requireNonNull(x); + requireNonNull(y); + requireNonNull(param); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + checkIndex(offsety + (n - 1) * Math.abs(incy), y.length); + checkIndex(offsetparam + 4, param.length); /* param.length == 5 */ + srotmK(n, x, offsetx, incx, y, offsety, incy, param, offsetparam); + } + + protected abstract void srotmK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] param, int offsetparam); + + public void drotmg(org.netlib.util.doubleW dd1, org.netlib.util.doubleW dd2, org.netlib.util.doubleW dx1, double dy1, double[] param) { + if (debug) System.err.println("drotmg"); + drotmg(dd1, dd2, dx1, dy1, param, 0); + } + + public void drotmg(org.netlib.util.doubleW dd1, org.netlib.util.doubleW dd2, org.netlib.util.doubleW dx1, double dy1, double[] param, int offsetparam) { + if (debug) System.err.println("drotmg"); + requireNonNull(dd1); + requireNonNull(dd2); + requireNonNull(dx1); + requireNonNull(param); + checkIndex(offsetparam + 4, param.length); + drotmgK(dd1, dd2, dx1, dy1, param, offsetparam); + } + + protected abstract void drotmgK(org.netlib.util.doubleW dd1, org.netlib.util.doubleW dd2, org.netlib.util.doubleW dx1, double dy1, double[] param, int offsetparam); + + public void srotmg(org.netlib.util.floatW sd1, org.netlib.util.floatW sd2, org.netlib.util.floatW sx1, float sy1, float[] param) { + if (debug) System.err.println("srotmg"); + srotmg(sd1, sd2, sx1, sy1, param, 0); + } + + public void srotmg(org.netlib.util.floatW sd1, org.netlib.util.floatW sd2, org.netlib.util.floatW sx1, float sy1, float[] param, int offsetparam) { + if (debug) System.err.println("srotmg"); + requireNonNull(sd1); + requireNonNull(sd2); + requireNonNull(sx1); + requireNonNull(param); + checkIndex(offsetparam + 4, param.length); + srotmgK(sd1, sd2, sx1, sy1, param, offsetparam); + 
} + + protected abstract void srotmgK(org.netlib.util.floatW sd1, org.netlib.util.floatW sd2, org.netlib.util.floatW sx1, float sy1, float[] param, int offsetparam); + + public void dsbmv(String uplo, int n, int k, double alpha, double[] a, int lda, double[] x, int incx, double beta, double[] y, int incy) { + if (debug) System.err.println("dsbmv"); + dsbmv(uplo, n, k, alpha, a, 0, lda, x, 0, incx, beta, y, 0, incy); + } + + public void dsbmv(String uplo, int n, int k, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) { + if (debug) System.err.println("dsbmv"); + requireNonNull(a); + requireNonNull(x); + requireNonNull(y); + checkIndex(offseta + n * lda - 1, a.length); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + checkIndex(offsety + (n - 1) * Math.abs(incy), y.length); + dsbmvK(uplo, n, k, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy); + } + + protected abstract void dsbmvK(String uplo, int n, int k, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy); + + public void ssbmv(String uplo, int n, int k, float alpha, float[] a, int lda, float[] x, int incx, float beta, float[] y, int incy) { + if (debug) System.err.println("ssbmv"); + ssbmv(uplo, n, k, alpha, a, 0, lda, x, 0, incx, beta, y, 0, incy); + } + + public void ssbmv(String uplo, int n, int k, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) { + if (debug) System.err.println("ssbmv"); + requireNonNull(a); + requireNonNull(x); + requireNonNull(y); + checkIndex(offseta + n * lda - 1, a.length); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + checkIndex(offsety + (n - 1) * Math.abs(incy), y.length); + ssbmvK(uplo, n, k, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy); + } + + protected abstract void ssbmvK(String 
uplo, int n, int k, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy); + + public void dscal(int n, double alpha, double[] x, int incx) { + if (debug) System.err.println("dscal"); + dscal(n, alpha, x, 0, incx); + } + + // x = alpha * x + public void dscal(int n, double alpha, double[] x, int offsetx, int incx) { + if (debug) System.err.println("dscal"); + if (n <= 0) { + return; + } + if (incx <= 0) { + return; + } + if (alpha == 1.0) { + return; + } + requireNonNull(x); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + dscalK(n, alpha, x, offsetx, incx); + } + + protected abstract void dscalK(int n, double alpha, double[] x, int offsetx, int incx); + + public void sscal(int n, float alpha, float[] x, int incx) { + if (debug) System.err.println("sscal"); + sscal(n, alpha, x, 0, incx); + } + + // x = alpha * x + public void sscal(int n, float alpha, float[] x, int offsetx, int incx) { + if (debug) System.err.println("sscal"); + if (n <= 0) { + return; + } + if (incx <= 0) { + return; + } + if (alpha == 1.0f) { + return; + } + requireNonNull(x); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + sscalK(n, alpha, x, offsetx, incx); + } + + protected abstract void sscalK(int n, float alpha, float[] x, int offsetx, int incx); + + public void dspmv(String uplo, int n, double alpha, double[] a, double[] x, int incx, double beta, double[] y, int incy) { + if (debug) System.err.println("dspmv"); + dspmv(uplo, n, alpha, a, 0, x, 0, incx, beta, y, 0, incy); + } + + // y = alpha * a * x + beta * y + public void dspmv(String uplo, int n, double alpha, double[] a, int offseta, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) { + if (debug) System.err.println("dspmv"); + checkArgument("DSPMV", 1, lsame("U", uplo) || lsame("L", uplo)); + checkArgument("DSPMV", 2, n >= 0); + checkArgument("DSPMV", 6, incx != 0); + checkArgument("DSPMV", 9, incy != 0); 
+ if (n == 0) { + return; + } + requireNonNull(a); + requireNonNull(x); + requireNonNull(y); + checkIndex(offseta + (n * (n + 1) / 2) - 1, a.length); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + checkIndex(offsety + (n - 1) * Math.abs(incy), y.length); + dspmvK(uplo, n, alpha, a, offseta, x, offsetx, incx, beta, y, offsety, incy); + } + + protected abstract void dspmvK(String uplo, int n, double alpha, double[] a, int offseta, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy); + + public void sspmv(String uplo, int n, float alpha, float[] a, float[] x, int incx, float beta, float[] y, int incy) { + if (debug) System.err.println("sspmv"); + sspmv(uplo, n, alpha, a, 0, x, 0, incx, beta, y, 0, incy); + } + + public void sspmv(String uplo, int n, float alpha, float[] a, int offseta, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) { + if (debug) System.err.println("sspmv"); + checkArgument("SSPMV", 1, lsame("U", uplo) || lsame("L", uplo)); + checkArgument("SSPMV", 2, n >= 0); + checkArgument("SSPMV", 6, incx != 0); + checkArgument("SSPMV", 9, incy != 0); + if (n == 0) { + return; + } + requireNonNull(a); + requireNonNull(x); + requireNonNull(y); + checkIndex(offseta + (n * (n + 1) / 2) - 1, a.length); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + checkIndex(offsety + (n - 1) * Math.abs(incy), y.length); + sspmvK(uplo, n, alpha, a, offseta, x, offsetx, incx, beta, y, offsety, incy); + } + + protected abstract void sspmvK(String uplo, int n, float alpha, float[] a, int offseta, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy); + + public void dspr(String uplo, int n, double alpha, double[] x, int incx, double[] a) { + if (debug) System.err.println("dspr"); + dspr(uplo, n, alpha, x, 0, incx, a, 0); + } + + // a += alpha * x * x.t + public void dspr(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] a, int offseta) { + 
if (debug) System.err.println("dspr"); + checkArgument("DSPR", 1, lsame("U", uplo) || lsame("L", uplo)); + checkArgument("DSPR", 2, n >= 0); + checkArgument("DSPR", 5, incx != 0); + if (n == 0) { + return; + } + requireNonNull(x); + requireNonNull(a); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + checkIndex(offseta + (n * (n + 1) / 2) - 1, a.length); + dsprK(uplo, n, alpha, x, offsetx, incx, a, offseta); + } + + protected abstract void dsprK(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] a, int offseta); + + public void sspr(String uplo, int n, float alpha, float[] x, int incx, float[] a) { + if (debug) System.err.println("sspr"); + sspr(uplo, n, alpha, x, 0, incx, a, 0); + } + + // a += alpha * x * x.t + public void sspr(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] a, int offseta) { + if (debug) System.err.println("sspr"); + checkArgument("SSPR", 1, lsame("U", uplo) || lsame("L", uplo)); + checkArgument("SSPR", 2, n >= 0); + checkArgument("SSPR", 5, incx != 0); + if (n == 0) { + return; + } + requireNonNull(x); + requireNonNull(a); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + checkIndex(offseta + (n * (n + 1) / 2) - 1, a.length); + ssprK(uplo, n, alpha, x, offsetx, incx, a, offseta); + } + + protected abstract void ssprK(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] a, int offseta); + + public void dspr2(String uplo, int n, double alpha, double[] x, int incx, double[] y, int incy, double[] a) { + if (debug) System.err.println("dspr2"); + dspr2(uplo, n, alpha, x, 0, incx, y, 0, incy, a, 0); + } + + // a += alpha * x * y.t + alpha * y * x.t + public void dspr2(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta) { + if (debug) System.err.println("dspr2"); + checkArgument("DSPR2", 1, lsame("U", uplo) || lsame("L", uplo)); + checkArgument("DSPR2", 2, n >= 0); + 
checkArgument("DSPR2", 5, incx != 0); + checkArgument("DSPR2", 7, incy != 0); + if (n == 0) { + return; + } + requireNonNull(x); + requireNonNull(y); + requireNonNull(a); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + checkIndex(offsety + (n - 1) * Math.abs(incy), y.length); + checkIndex(offseta + (n * (n + 1) / 2) - 1, a.length); + dspr2K(uplo, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta); + } + + protected abstract void dspr2K(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta); + + public void sspr2(String uplo, int n, float alpha, float[] x, int incx, float[] y, int incy, float[] a) { + if (debug) System.err.println("sspr2"); + sspr2(uplo, n, alpha, x, 0, incx, y, 0, incy, a, 0); + } + + // a += alpha * x * y.t + alpha * y * x.t + public void sspr2(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta) { + if (debug) System.err.println("sspr2"); + checkArgument("SSPR2", 1, lsame("U", uplo) || lsame("L", uplo)); + checkArgument("SSPR2", 2, n >= 0); + checkArgument("SSPR2", 5, incx != 0); + checkArgument("SSPR2", 7, incy != 0); + if (n == 0) { + return; + } + requireNonNull(x); + requireNonNull(y); + requireNonNull(a); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + checkIndex(offsety + (n - 1) * Math.abs(incy), y.length); + checkIndex(offseta + (n * (n + 1) / 2) - 1, a.length); + sspr2K(uplo, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta); + } + + protected abstract void sspr2K(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta); + + public void dswap(int n, double[] x, int incx, double[] y, int incy) { + if (debug) System.err.println("dswap"); + dswap(n, x, 0, incx, y, 0, incy); + } + + public void dswap(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy) { + if 
(debug) System.err.println("dswap"); + if (n <= 0) { + return; + } + requireNonNull(x); + requireNonNull(y); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + checkIndex(offsety + (n - 1) * Math.abs(incy), y.length); + dswapK(n, x, offsetx, incx, y, offsety, incy); + } + + protected abstract void dswapK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy); + + public void sswap(int n, float[] x, int incx, float[] y, int incy) { + if (debug) System.err.println("sswap"); + sswap(n, x, 0, incx, y, 0, incy); + } + + public void sswap(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) { + if (debug) System.err.println("sswap"); + if (n <= 0) { + return; + } + requireNonNull(x); + requireNonNull(y); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + checkIndex(offsety + (n - 1) * Math.abs(incy), y.length); + sswapK(n, x, offsetx, incx, y, offsety, incy); + } + + protected abstract void sswapK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy); + + public void dsymm(String side, String uplo, int m, int n, double alpha, double[] a, int lda, double[] b, int ldb, double beta, double[] c, int ldc) { + if (debug) System.err.println("dsymm"); + dsymm(side, uplo, m, n, alpha, a, 0, lda, b, 0, ldb, beta, c, 0, ldc); + } + + public void dsymm(String side, String uplo, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) { + if (debug) System.err.println("dsymm"); + checkArgument("DSYMM", 1, lsame("L", side) || lsame("R", side)); + checkArgument("DSYMM", 2, lsame("U", uplo) || lsame("L", uplo)); + checkArgument("DSYMM", 3, m >= 0); + checkArgument("DSYMM", 4, n >= 0); + checkArgument("DSYMM", 7, lda >= Math.max(1, lsame("L", side) ? 
m : n)); + checkArgument("DSYMM", 9, ldb >= Math.max(1, m)); + checkArgument("DSYMM", 12, ldc >= Math.max(1, m)); + if (m == 0 || n == 0 || (alpha == 0.0 && beta == 1.0)) { + return; + } + requireNonNull(a); + requireNonNull(b); + requireNonNull(c); + checkIndex(offseta + (lsame("L", side) ? m : n) * lda - 1, a.length); + checkIndex(offsetb + n * ldb - 1, b.length); + checkIndex(offsetc + n * ldc - 1, c.length); + dsymmK(side, uplo, m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc); + } + + protected abstract void dsymmK(String side, String uplo, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc); + + public void ssymm(String side, String uplo, int m, int n, float alpha, float[] a, int lda, float[] b, int ldb, float beta, float[] c, int ldc) { + if (debug) System.err.println("ssymm"); + ssymm(side, uplo, m, n, alpha, a, 0, lda, b, 0, ldb, beta, c, 0, ldc); + } + + public void ssymm(String side, String uplo, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) { + if (debug) System.err.println("ssymm"); + checkArgument("SSYMM", 1, lsame("L", side) || lsame("R", side)); + checkArgument("SSYMM", 2, lsame("U", uplo) || lsame("L", uplo)); + checkArgument("SSYMM", 3, m >= 0); + checkArgument("SSYMM", 4, n >= 0); + checkArgument("SSYMM", 7, lda >= Math.max(1, lsame("L", side) ? m : n)); + checkArgument("SSYMM", 9, ldb >= Math.max(1, m)); + checkArgument("SSYMM", 12, ldc >= Math.max(1, m)); + if (m == 0 || n == 0 || (alpha == 0.0f && beta == 1.0f)) { + return; + } + requireNonNull(a); + requireNonNull(b); + requireNonNull(c); + checkIndex(offseta + (lsame("L", side) ? 
m : n) * lda - 1, a.length); + checkIndex(offsetb + n * ldb - 1, b.length); + checkIndex(offsetc + n * ldc - 1, c.length); + ssymmK(side, uplo, m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc); + } + + protected abstract void ssymmK(String side, String uplo, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc); + + public void dsymv(String uplo, int n, double alpha, double[] a, int lda, double[] x, int incx, double beta, double[] y, int incy) { + if (debug) System.err.println("dsymv"); + dsymv(uplo, n, alpha, a, 0, lda, x, 0, incx, beta, y, 0, incy); + } + + public void dsymv(String uplo, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) { + if (debug) System.err.println("dsymv"); + checkArgument("DSYMV", 1, lsame("U", uplo) || lsame("L", uplo)); + checkArgument("DSYMV", 2, n >= 0); + checkArgument("DSYMV", 5, lda >= Math.max(1, n)); + checkArgument("DSYMV", 7, incx != 0); + checkArgument("DSYMV", 10, incy != 0); + if (n == 0) { + return; + } + requireNonNull(a); + requireNonNull(x); + requireNonNull(y); + checkIndex(offseta + n * lda - 1, a.length); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + checkIndex(offsety + (n - 1) * Math.abs(incy), y.length); + dsymvK(uplo, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy); + } + + protected abstract void dsymvK(String uplo, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy); + + public void ssymv(String uplo, int n, float alpha, float[] a, int lda, float[] x, int incx, float beta, float[] y, int incy) { + if (debug) System.err.println("ssymv"); + ssymv(uplo, n, alpha, a, 0, lda, x, 0, incx, beta, y, 0, incy); + } + + public void ssymv(String uplo, int n, float alpha, float[] a, int offseta, int lda, float[] x, 
int offsetx, int incx, float beta, float[] y, int offsety, int incy) { + if (debug) System.err.println("ssymv"); + checkArgument("SSYMV", 1, lsame("U", uplo) || lsame("L", uplo)); + checkArgument("SSYMV", 2, n >= 0); + checkArgument("SSYMV", 5, lda >= Math.max(1, n)); + checkArgument("SSYMV", 7, incx != 0); + checkArgument("SSYMV", 10, incy != 0); + if (n == 0) { + return; + } + requireNonNull(a); + requireNonNull(x); + requireNonNull(y); + checkIndex(offseta + n * lda - 1, a.length); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + checkIndex(offsety + (n - 1) * Math.abs(incy), y.length); + ssymvK(uplo, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy); + } + + protected abstract void ssymvK(String uplo, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy); + + public void dsyr(String uplo, int n, double alpha, double[] x, int incx, double[] a, int lda) { + if (debug) System.err.println("dsyr"); + dsyr(uplo, n, alpha, x, 0, incx, a, 0, lda); + } + + // a += alpha * x * x.t + public void dsyr(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] a, int offseta, int lda) { + if (debug) System.err.println("dsyr"); + checkArgument("DSYR", 1, lsame("U", uplo) || lsame("L", uplo)); + checkArgument("DSYR", 2, n >= 0); + checkArgument("DSYR", 5, incx != 0); + checkArgument("DSYR", 7, lda >= Math.max(1, n)); + if (n == 0) { + return; + } + requireNonNull(x); + requireNonNull(a); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + checkIndex(offseta + n * lda - 1, a.length); + dsyrK(uplo, n, alpha, x, offsetx, incx, a, offseta, lda); + } + + protected abstract void dsyrK(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] a, int offseta, int lda); + + public void ssyr(String uplo, int n, float alpha, float[] x, int incx, float[] a, int lda) { + if (debug) System.err.println("ssyr"); + ssyr(uplo, n, alpha, 
x, 0, incx, a, 0, lda); + } + + public void ssyr(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] a, int offseta, int lda) { + if (debug) System.err.println("ssyr"); + checkArgument("SSYR", 1, lsame("U", uplo) || lsame("L", uplo)); + checkArgument("SSYR", 2, n >= 0); + checkArgument("SSYR", 5, incx != 0); + checkArgument("SSYR", 7, lda >= Math.max(1, n)); + if (n == 0) { + return; + } + requireNonNull(x); + requireNonNull(a); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + checkIndex(offseta + n * lda - 1, a.length); + ssyrK(uplo, n, alpha, x, offsetx, incx, a, offseta, lda); + } + + protected abstract void ssyrK(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] a, int offseta, int lda); + + public void dsyr2(String uplo, int n, double alpha, double[] x, int incx, double[] y, int incy, double[] a, int lda) { + if (debug) System.err.println("dsyr2"); + dsyr2(uplo, n, alpha, x, 0, incx, y, 0, incy, a, 0, lda); + } + + public void dsyr2(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta, int lda) { + if (debug) System.err.println("dsyr2"); + checkArgument("DSYR2", 1, lsame("U", uplo) || lsame("L", uplo)); + checkArgument("DSYR2", 2, n >= 0); + checkArgument("DSYR2", 5, incx != 0); + checkArgument("DSYR2", 7, incy != 0); + checkArgument("DSYR2", 9, lda >= Math.max(1, n)); + if (n == 0) { + return; + } + requireNonNull(x); + requireNonNull(y); + requireNonNull(a); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + checkIndex(offsety + (n - 1) * Math.abs(incy), y.length); + checkIndex(offseta + n * lda - 1, a.length); + dsyr2K(uplo, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta, lda); + } + + protected abstract void dsyr2K(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta, int lda); + + public void ssyr2(String uplo, int n, float 
alpha, float[] x, int incx, float[] y, int incy, float[] a, int lda) { + if (debug) System.err.println("ssyr2"); + ssyr2(uplo, n, alpha, x, 0, incx, y, 0, incy, a, 0, lda); + } + + public void ssyr2(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta, int lda) { + if (debug) System.err.println("ssyr2"); + checkArgument("SSYR2", 1, lsame("U", uplo) || lsame("L", uplo)); + checkArgument("SSYR2", 2, n >= 0); + checkArgument("SSYR2", 5, incx != 0); + checkArgument("SSYR2", 7, incy != 0); + checkArgument("SSYR2", 9, lda >= Math.max(1, n)); + if (n == 0) { + return; + } + requireNonNull(x); + requireNonNull(y); + requireNonNull(a); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + checkIndex(offsety + (n - 1) * Math.abs(incy), y.length); + checkIndex(offseta + n * lda - 1, a.length); + ssyr2K(uplo, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta, lda); + } + + protected abstract void ssyr2K(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta, int lda); + + public void dsyr2k(String uplo, String trans, int n, int k, double alpha, double[] a, int lda, double[] b, int ldb, double beta, double[] c, int ldc) { + if (debug) System.err.println("dsyr2k"); + dsyr2k(uplo, trans, n, k, alpha, a, 0, lda, b, 0, ldb, beta, c, 0, ldc); + } + + public void dsyr2k(String uplo, String trans, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) { + if (debug) System.err.println("dsyr2k"); + checkArgument("DSYR2K", 1, lsame("U", uplo) || lsame("L", uplo)); + checkArgument("DSYR2K", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans)); + checkArgument("DSYR2K", 3, n >= 0); + checkArgument("DSYR2K", 4, k >= 0); + checkArgument("DSYR2K", 7, lda >= Math.max(1, lsame("N", trans) ? 
n : k)); + checkArgument("DSYR2K", 9, ldb >= Math.max(1, lsame("N", trans) ? n : k)); + checkArgument("DSYR2K", 12, ldc >= Math.max(1, n)); + if (n == 0 || ((alpha == 0 || k == 0) && beta == 1.0)) + return; + requireNonNull(a); + requireNonNull(b); + requireNonNull(c); + checkIndex(offseta + (lsame("N", trans) ? k : n) * lda - 1, a.length); + checkIndex(offsetb + (lsame("N", trans) ? k : n) * ldb - 1, b.length); + checkIndex(offsetc + n * ldc - 1, c.length); + dsyr2kK(uplo, trans, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc); + } + + protected abstract void dsyr2kK(String uplo, String trans, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc); + + public void ssyr2k(String uplo, String trans, int n, int k, float alpha, float[] a, int lda, float[] b, int ldb, float beta, float[] c, int ldc) { + if (debug) System.err.println("ssyr2k"); + ssyr2k(uplo, trans, n, k, alpha, a, 0, lda, b, 0, ldb, beta, c, 0, ldc); + } + + public void ssyr2k(String uplo, String trans, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) { + if (debug) System.err.println("ssyr2k"); + checkArgument("SSYR2K", 1, lsame("U", uplo) || lsame("L", uplo)); + checkArgument("SSYR2K", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans)); + checkArgument("SSYR2K", 3, n >= 0); + checkArgument("SSYR2K", 4, k >= 0); + checkArgument("SSYR2K", 7, lda >= Math.max(1, lsame("N", trans) ? n : k)); + checkArgument("SSYR2K", 9, ldb >= Math.max(1, lsame("N", trans) ? n : k)); + checkArgument("SSYR2K", 12, ldc >= Math.max(1, n)); + if (n == 0 || ((alpha == 0 || k == 0) && beta == 1.0f)) + return; + requireNonNull(a); + requireNonNull(b); + requireNonNull(c); + checkIndex(offseta + (lsame("N", trans) ? k : n) * lda - 1, a.length); + checkIndex(offsetb + (lsame("N", trans) ? 
k : n) * ldb - 1, b.length); + checkIndex(offsetc + n * ldc - 1, c.length); + ssyr2kK(uplo, trans, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc); + } + + protected abstract void ssyr2kK(String uplo, String trans, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc); + + public void dsyrk(String uplo, String trans, int n, int k, double alpha, double[] a, int lda, double beta, double[] c, int ldc) { + if (debug) System.err.println("dsyrk"); + dsyrk(uplo, trans, n, k, alpha, a, 0, lda, beta, c, 0, ldc); + } + + public void dsyrk(String uplo, String trans, int n, int k, double alpha, double[] a, int offseta, int lda, double beta, double[] c, int offsetc, int ldc) { + if (debug) System.err.println("dsyrk"); + checkArgument("DSYRK", 1, lsame("U", uplo) || lsame("L", uplo)); + checkArgument("DSYRK", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans)); + checkArgument("DSYRK", 3, n >= 0); + checkArgument("DSYRK", 4, k >= 0); + checkArgument("DSYRK", 7, lda >= Math.max(1, lsame("N", trans) ? n : k)); + checkArgument("DSYRK", 10, ldc >= Math.max(1, n)); + if (n == 0 || ((alpha == 0 || k == 0) && beta == 1.0)) + return; + requireNonNull(a); + requireNonNull(c); + checkIndex(offseta + (lsame("N", trans) ? 
k : n) * lda - 1, a.length); + checkIndex(offsetc + n * ldc - 1, c.length); + dsyrkK(uplo, trans, n, k, alpha, a, offseta, lda, beta, c, offsetc, ldc); + } + + protected abstract void dsyrkK(String uplo, String trans, int n, int k, double alpha, double[] a, int offseta, int lda, double beta, double[] c, int offsetc, int ldc); + + public void ssyrk(String uplo, String trans, int n, int k, float alpha, float[] a, int lda, float beta, float[] c, int ldc) { + if (debug) System.err.println("ssyrk"); + ssyrk(uplo, trans, n, k, alpha, a, 0, lda, beta, c, 0, ldc); + } + + public void ssyrk(String uplo, String trans, int n, int k, float alpha, float[] a, int offseta, int lda, float beta, float[] c, int offsetc, int ldc) { + if (debug) System.err.println("ssyrk"); + checkArgument("SSYRK", 1, lsame("U", uplo) || lsame("L", uplo)); + checkArgument("SSYRK", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans)); + checkArgument("SSYRK", 3, n >= 0); + checkArgument("SSYRK", 4, k >= 0); + checkArgument("SSYRK", 7, lda >= Math.max(1, lsame("N", trans) ? n : k)); + checkArgument("SSYRK", 10, ldc >= Math.max(1, n)); + if (n == 0 || ((alpha == 0 || k == 0) && beta == 1.0f)) + return; + requireNonNull(a); + requireNonNull(c); + checkIndex(offseta + (lsame("N", trans) ? 
k : n) * lda - 1, a.length); + checkIndex(offsetc + n * ldc - 1, c.length); + ssyrkK(uplo, trans, n, k, alpha, a, offseta, lda, beta, c, offsetc, ldc); + } + + protected abstract void ssyrkK(String uplo, String trans, int n, int k, float alpha, float[] a, int offseta, int lda, float beta, float[] c, int offsetc, int ldc); + + public void dtbmv(String uplo, String trans, String diag, int n, int k, double[] a, int lda, double[] x, int incx) { + if (debug) System.err.println("dtbmv"); + dtbmv(uplo, trans, diag, n, k, a, 0, lda, x, 0, incx); + } + + public void dtbmv(String uplo, String trans, String diag, int n, int k, double[] a, int offseta, int lda, double[] x, int offsetx, int incx) { + if (debug) System.err.println("dtbmv"); + checkArgument("DTBMV", 1, lsame("U", uplo) || lsame("L", uplo)); + checkArgument("DTBMV", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans)); + checkArgument("DTBMV", 3, lsame("U", diag) || lsame("N", diag)); + checkArgument("DTBMV", 4, n >= 0); + checkArgument("DTBMV", 5, k >= 0); + checkArgument("DTBMV", 7, lda >= Math.max(1, k)); + checkArgument("DTBMV", 9, incx != 0); + if (n == 0) { + return; + } + requireNonNull(a); + requireNonNull(x); + checkIndex(offseta + n * lda - 1, a.length); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + dtbmvK(uplo, trans, diag, n, k, a, offseta, lda, x, offsetx, incx); + } + + protected abstract void dtbmvK(String uplo, String trans, String diag, int n, int k, double[] a, int offseta, int lda, double[] x, int offsetx, int incx); + + public void stbmv(String uplo, String trans, String diag, int n, int k, float[] a, int lda, float[] x, int incx) { + if (debug) System.err.println("stbmv"); + stbmv(uplo, trans, diag, n, k, a, 0, lda, x, 0, incx); + } + + public void stbmv(String uplo, String trans, String diag, int n, int k, float[] a, int offseta, int lda, float[] x, int offsetx, int incx) { + if (debug) System.err.println("stbmv"); + checkArgument("STBMV", 1, lsame("U", uplo) || 
lsame("L", uplo)); + checkArgument("STBMV", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans)); + checkArgument("STBMV", 3, lsame("U", diag) || lsame("N", diag)); + checkArgument("STBMV", 4, n >= 0); + checkArgument("STBMV", 5, k >= 0); + checkArgument("STBMV", 7, lda >= Math.max(1, k)); + checkArgument("STBMV", 9, incx != 0); + if (n == 0) { + return; + } + requireNonNull(a); + requireNonNull(x); + checkIndex(offseta + n * lda - 1, a.length); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + stbmvK(uplo, trans, diag, n, k, a, offseta, lda, x, offsetx, incx); + } + + protected abstract void stbmvK(String uplo, String trans, String diag, int n, int k, float[] a, int offseta, int lda, float[] x, int offsetx, int incx); + + public void dtbsv(String uplo, String trans, String diag, int n, int k, double[] a, int lda, double[] x, int incx) { + if (debug) System.err.println("dtbsv"); + dtbsv(uplo, trans, diag, n, k, a, 0, lda, x, 0, incx); + } + + public void dtbsv(String uplo, String trans, String diag, int n, int k, double[] a, int offseta, int lda, double[] x, int offsetx, int incx) { + if (debug) System.err.println("dtbsv"); + checkArgument("DTBSV", 1, lsame("U", uplo) || lsame("L", uplo)); + checkArgument("DTBSV", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans)); + checkArgument("DTBSV", 3, lsame("U", diag) || lsame("N", diag)); + checkArgument("DTBSV", 4, n >= 0); + checkArgument("DTBSV", 5, k >= 0); + checkArgument("DTBSV", 7, lda >= Math.max(1, k)); + checkArgument("DTBSV", 9, incx != 0); + if (n == 0) { + return; + } + requireNonNull(a); + requireNonNull(x); + checkIndex(offseta + n * lda - 1, a.length); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + dtbsvK(uplo, trans, diag, n, k, a, offseta, lda, x, offsetx, incx); + } + + protected abstract void dtbsvK(String uplo, String trans, String diag, int n, int k, double[] a, int offseta, int lda, double[] x, int offsetx, int incx); + + public void stbsv(String 
uplo, String trans, String diag, int n, int k, float[] a, int lda, float[] x, int incx) { + if (debug) System.err.println("stbsv"); + stbsv(uplo, trans, diag, n, k, a, 0, lda, x, 0, incx); + } + + public void stbsv(String uplo, String trans, String diag, int n, int k, float[] a, int offseta, int lda, float[] x, int offsetx, int incx) { + if (debug) System.err.println("stbsv"); + checkArgument("STBSV", 1, lsame("U", uplo) || lsame("L", uplo)); + checkArgument("STBSV", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans)); + checkArgument("STBSV", 3, lsame("U", diag) || lsame("N", diag)); + checkArgument("STBSV", 4, n >= 0); + checkArgument("STBSV", 5, k >= 0); + checkArgument("STBSV", 7, lda >= Math.max(1, k)); + checkArgument("STBSV", 9, incx != 0); + if (n == 0) { + return; + } + requireNonNull(a); + requireNonNull(x); + checkIndex(offseta + n * lda - 1, a.length); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + stbsvK(uplo, trans, diag, n, k, a, offseta, lda, x, offsetx, incx); + } + + protected abstract void stbsvK(String uplo, String trans, String diag, int n, int k, float[] a, int offseta, int lda, float[] x, int offsetx, int incx); + + public void dtpmv(String uplo, String trans, String diag, int n, double[] a, double[] x, int incx) { + if (debug) System.err.println("dtpmv"); + dtpmv(uplo, trans, diag, n, a, 0, x, 0, incx); + } + + public void dtpmv(String uplo, String trans, String diag, int n, double[] a, int offseta, double[] x, int offsetx, int incx) { + if (debug) System.err.println("dtpmv"); + checkArgument("DTPMV", 1, lsame("U", uplo) || lsame("L", uplo)); + checkArgument("DTPMV", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans)); + checkArgument("DTPMV", 3, lsame("U", diag) || lsame("N", diag)); + checkArgument("DTPMV", 4, n >= 0); + checkArgument("DTPMV", 7, incx != 0); + if (n == 0) { + return; + } + requireNonNull(a); + requireNonNull(x); + checkIndex(offseta + n * (n + 1) / 2 - 1, a.length); + 
checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + dtpmvK(uplo, trans, diag, n, a, offseta, x, offsetx, incx); + } + + protected abstract void dtpmvK(String uplo, String trans, String diag, int n, double[] a, int offseta, double[] x, int offsetx, int incx); + + public void stpmv(String uplo, String trans, String diag, int n, float[] a, float[] x, int incx) { + if (debug) System.err.println("stpmv"); + stpmv(uplo, trans, diag, n, a, 0, x, 0, incx); + } + + public void stpmv(String uplo, String trans, String diag, int n, float[] a, int offseta, float[] x, int offsetx, int incx) { + if (debug) System.err.println("stpmv"); + checkArgument("STPMV", 1, lsame("U", uplo) || lsame("L", uplo)); + checkArgument("STPMV", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans)); + checkArgument("STPMV", 3, lsame("U", diag) || lsame("N", diag)); + checkArgument("STPMV", 4, n >= 0); + checkArgument("STPMV", 7, incx != 0); + if (n == 0) { + return; + } + requireNonNull(a); + requireNonNull(x); + checkIndex(offseta + n * (n + 1) / 2 - 1, a.length); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + stpmvK(uplo, trans, diag, n, a, offseta, x, offsetx, incx); + } + + protected abstract void stpmvK(String uplo, String trans, String diag, int n, float[] a, int offseta, float[] x, int offsetx, int incx); + + public void dtpsv(String uplo, String trans, String diag, int n, double[] a, double[] x, int incx) { + if (debug) System.err.println("dtpsv"); + dtpsv(uplo, trans, diag, n, a, 0, x, 0, incx); + } + + public void dtpsv(String uplo, String trans, String diag, int n, double[] a, int offseta, double[] x, int offsetx, int incx) { + if (debug) System.err.println("dtpsv"); + checkArgument("DTPSV", 1, lsame("U", uplo) || lsame("L", uplo)); + checkArgument("DTPSV", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans)); + checkArgument("DTPSV", 3, lsame("U", diag) || lsame("N", diag)); + checkArgument("DTPSV", 4, n >= 0); + checkArgument("DTPSV", 7, incx 
!= 0); + if (n == 0) { + return; + } + requireNonNull(a); + requireNonNull(x); + checkIndex(offseta + n * (n + 1) / 2 - 1, a.length); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + dtpsvK(uplo, trans, diag, n, a, offseta, x, offsetx, incx); + } + + protected abstract void dtpsvK(String uplo, String trans, String diag, int n, double[] a, int offseta, double[] x, int offsetx, int incx); + + public void stpsv(String uplo, String trans, String diag, int n, float[] a, float[] x, int incx) { + if (debug) System.err.println("stpsv"); + stpsv(uplo, trans, diag, n, a, 0, x, 0, incx); + } + + public void stpsv(String uplo, String trans, String diag, int n, float[] a, int offseta, float[] x, int offsetx, int incx) { + if (debug) System.err.println("stpsv"); + checkArgument("STPSV", 1, lsame("U", uplo) || lsame("L", uplo)); + checkArgument("STPSV", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans)); + checkArgument("STPSV", 3, lsame("U", diag) || lsame("N", diag)); + checkArgument("STPSV", 4, n >= 0); + checkArgument("STPSV", 7, incx != 0); + if (n == 0) { + return; + } + requireNonNull(a); + requireNonNull(x); + checkIndex(offseta + n * (n + 1) / 2 - 1, a.length); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + stpsvK(uplo, trans, diag, n, a, offseta, x, offsetx, incx); + } + + protected abstract void stpsvK(String uplo, String trans, String diag, int n, float[] a, int offseta, float[] x, int offsetx, int incx); + + public void dtrmm(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int lda, double[] b, int ldb) { + if (debug) System.err.println("dtrmm"); + dtrmm(side, uplo, transa, diag, m, n, alpha, a, 0, lda, b, 0, ldb); + } + + public void dtrmm(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb) { + if (debug) System.err.println("dtrmm"); + checkArgument("DTRMM", 1, lsame("L", side) || 
lsame("R", side)); + checkArgument("DTRMM", 2, lsame("U", uplo) || lsame("L", uplo)); + checkArgument("DTRMM", 3, lsame("N", transa) || lsame("T", transa) || lsame("C", transa)); + checkArgument("DTRMM", 4, lsame("U", diag) || lsame("N", diag)); + checkArgument("DTRMM", 5, m >= 0); + checkArgument("DTRMM", 6, n >= 0); + checkArgument("DTRMM", 9, lda >= Math.max(1, lsame("L", side) ? m : n)); + checkArgument("DTRMM", 11, ldb >= Math.max(1, m)); + if (n == 0) { + return; + } + requireNonNull(a); + requireNonNull(b); + checkIndex(offseta + (lsame("L", side) ? m : n) * lda - 1, a.length); + checkIndex(offsetb + n * ldb - 1, b.length); + dtrmmK(side, uplo, transa, diag, m, n, alpha, a, offseta, lda, b, offsetb, ldb); + } + + protected abstract void dtrmmK(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb); + + public void strmm(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int lda, float[] b, int ldb) { + if (debug) System.err.println("strmm"); + strmm(side, uplo, transa, diag, m, n, alpha, a, 0, lda, b, 0, ldb); + } + + public void strmm(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb) { + if (debug) System.err.println("strmm"); + checkArgument("STRMM", 1, lsame("L", side) || lsame("R", side)); + checkArgument("STRMM", 2, lsame("U", uplo) || lsame("L", uplo)); + checkArgument("STRMM", 3, lsame("N", transa) || lsame("T", transa) || lsame("C", transa)); + checkArgument("STRMM", 4, lsame("U", diag) || lsame("N", diag)); + checkArgument("STRMM", 5, m >= 0); + checkArgument("STRMM", 6, n >= 0); + checkArgument("STRMM", 9, lda >= Math.max(1, lsame("L", side) ? m : n)); + checkArgument("STRMM", 11, ldb >= Math.max(1, m)); + if (n == 0) { + return; + } + requireNonNull(a); + requireNonNull(b); + checkIndex(offseta + (lsame("L", side) ? 
m : n) * lda - 1, a.length); + checkIndex(offsetb + n * ldb - 1, b.length); + strmmK(side, uplo, transa, diag, m, n, alpha, a, offseta, lda, b, offsetb, ldb); + } + + protected abstract void strmmK(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb); + + public void dtrmv(String uplo, String trans, String diag, int n, double[] a, int lda, double[] x, int incx) { + if (debug) System.err.println("dtrmv"); + dtrmv(uplo, trans, diag, n, a, 0, lda, x, 0, incx); + } + + public void dtrmv(String uplo, String trans, String diag, int n, double[] a, int offseta, int lda, double[] x, int offsetx, int incx) { + if (debug) System.err.println("dtrmv"); + checkArgument("DTRMV", 1, lsame("U", uplo) || lsame("L", uplo)); + checkArgument("DTRMV", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans)); + checkArgument("DTRMV", 3, lsame("U", diag) || lsame("N", diag)); + checkArgument("DTRMV", 4, n >= 0); + checkArgument("DTRMV", 6, lda >= Math.max(1, n)); + checkArgument("DTRMV", 8, incx != 0); + if (n == 0) { + return; + } + requireNonNull(a); + requireNonNull(x); + checkIndex(offseta + n * lda - 1, a.length); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + dtrmvK(uplo, trans, diag, n, a, offseta, lda, x, offsetx, incx); + } + + protected abstract void dtrmvK(String uplo, String trans, String diag, int n, double[] a, int offseta, int lda, double[] x, int offsetx, int incx); + + public void strmv(String uplo, String trans, String diag, int n, float[] a, int lda, float[] x, int incx) { + if (debug) System.err.println("strmv"); + strmv(uplo, trans, diag, n, a, 0, lda, x, 0, incx); + } + + public void strmv(String uplo, String trans, String diag, int n, float[] a, int offseta, int lda, float[] x, int offsetx, int incx) { + if (debug) System.err.println("strmv"); + checkArgument("STRMV", 1, lsame("U", uplo) || lsame("L", uplo)); + checkArgument("STRMV", 2, lsame("N", 
trans) || lsame("T", trans) || lsame("C", trans)); + checkArgument("STRMV", 3, lsame("U", diag) || lsame("N", diag)); + checkArgument("STRMV", 4, n >= 0); + checkArgument("STRMV", 6, lda >= Math.max(1, n)); + checkArgument("STRMV", 8, incx != 0); + if (n == 0) { + return; + } + requireNonNull(a); + requireNonNull(x); + checkIndex(offseta + n * lda - 1, a.length); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + strmvK(uplo, trans, diag, n, a, offseta, lda, x, offsetx, incx); + } + + protected abstract void strmvK(String uplo, String trans, String diag, int n, float[] a, int offseta, int lda, float[] x, int offsetx, int incx); + + public void dtrsm(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int lda, double[] b, int ldb) { + if (debug) System.err.println("dtrsm"); + dtrsm(side, uplo, transa, diag, m, n, alpha, a, 0, lda, b, 0, ldb); + } + + public void dtrsm(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb) { + if (debug) System.err.println("dtrsm"); + checkArgument("DTRSM", 1, lsame("L", side) || lsame("R", side)); + checkArgument("DTRSM", 2, lsame("U", uplo) || lsame("L", uplo)); + checkArgument("DTRSM", 3, lsame("N", transa) || lsame("T", transa) || lsame("C", transa)); + checkArgument("DTRSM", 4, lsame("U", diag) || lsame("N", diag)); + checkArgument("DTRSM", 5, m >= 0); + checkArgument("DTRSM", 6, n >= 0); + checkArgument("DTRSM", 9, lda >= Math.max(1, lsame("L", side) ? m : n)); + checkArgument("DTRSM", 11, ldb >= Math.max(1, m)); + if (m == 0 || n == 0) { /* quick return on any empty dimension; testing n alone let m == 0 reach the index checks below with offseta - 1 */ + return; + } + requireNonNull(a); + requireNonNull(b); + checkIndex(offseta + (lsame("L", side) ? 
m : n) * lda - 1, a.length); + checkIndex(offsetb + n * ldb - 1, b.length); + dtrsmK(side, uplo, transa, diag, m, n, alpha, a, offseta, lda, b, offsetb, ldb); + } + + protected abstract void dtrsmK(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb); + + public void strsm(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int lda, float[] b, int ldb) { + if (debug) System.err.println("strsm"); + strsm(side, uplo, transa, diag, m, n, alpha, a, 0, lda, b, 0, ldb); + } + + public void strsm(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb) { + if (debug) System.err.println("strsm"); + checkArgument("STRSM", 1, lsame("L", side) || lsame("R", side)); + checkArgument("STRSM", 2, lsame("U", uplo) || lsame("L", uplo)); + checkArgument("STRSM", 3, lsame("N", transa) || lsame("T", transa) || lsame("C", transa)); + checkArgument("STRSM", 4, lsame("U", diag) || lsame("N", diag)); + checkArgument("STRSM", 5, m >= 0); + checkArgument("STRSM", 6, n >= 0); + checkArgument("STRSM", 9, lda >= Math.max(1, lsame("L", side) ? m : n)); + checkArgument("STRSM", 11, ldb >= Math.max(1, m)); + if (m == 0 || n == 0) { /* quick return on any empty dimension; testing n alone let m == 0 reach the index checks below with offseta - 1 */ + return; + } + requireNonNull(a); + requireNonNull(b); + checkIndex(offseta + (lsame("L", side) ? 
m : n) * lda - 1, a.length); + checkIndex(offsetb + n * ldb - 1, b.length); + strsmK(side, uplo, transa, diag, m, n, alpha, a, offseta, lda, b, offsetb, ldb); + } + + protected abstract void strsmK(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb); + + public void dtrsv(String uplo, String trans, String diag, int n, double[] a, int lda, double[] x, int incx) { + if (debug) System.err.println("dtrsv"); + dtrsv(uplo, trans, diag, n, a, 0, lda, x, 0, incx); + } + + public void dtrsv(String uplo, String trans, String diag, int n, double[] a, int offseta, int lda, double[] x, int offsetx, int incx) { + if (debug) System.err.println("dtrsv"); + checkArgument("DTRSV", 1, lsame("U", uplo) || lsame("L", uplo)); + checkArgument("DTRSV", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans)); + checkArgument("DTRSV", 3, lsame("U", diag) || lsame("N", diag)); + checkArgument("DTRSV", 4, n >= 0); + checkArgument("DTRSV", 6, lda >= Math.max(1, n)); + checkArgument("DTRSV", 8, incx != 0); + if (n == 0) { + return; + } + requireNonNull(a); + requireNonNull(x); + checkIndex(offseta + n * lda - 1, a.length); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + dtrsvK(uplo, trans, diag, n, a, offseta, lda, x, offsetx, incx); + } + + protected abstract void dtrsvK(String uplo, String trans, String diag, int n, double[] a, int offseta, int lda, double[] x, int offsetx, int incx); + + public void strsv(String uplo, String trans, String diag, int n, float[] a, int lda, float[] x, int incx) { + if (debug) System.err.println("strsv"); + strsv(uplo, trans, diag, n, a, 0, lda, x, 0, incx); + } + + public void strsv(String uplo, String trans, String diag, int n, float[] a, int offseta, int lda, float[] x, int offsetx, int incx) { + if (debug) System.err.println("strsv"); + checkArgument("STRSV", 1, lsame("U", uplo) || lsame("L", uplo)); + checkArgument("STRSV", 2, lsame("N", 
trans) || lsame("T", trans) || lsame("C", trans)); + checkArgument("STRSV", 3, lsame("U", diag) || lsame("N", diag)); + checkArgument("STRSV", 4, n >= 0); + checkArgument("STRSV", 6, lda >= Math.max(1, n)); + checkArgument("STRSV", 8, incx != 0); + if (n == 0) { + return; + } + requireNonNull(a); + requireNonNull(x); + checkIndex(offseta + n * lda - 1, a.length); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + strsvK(uplo, trans, diag, n, a, offseta, lda, x, offsetx, incx); + } + + protected abstract void strsvK(String uplo, String trans, String diag, int n, float[] a, int offseta, int lda, float[] x, int offsetx, int incx); + + public int idamax(int n, double[] x, int incx) { + if (debug) System.err.println("idamax"); + return idamax(n, x, 0, incx); + } + + public int idamax(int n, double[] x, int offsetx, int incx) { + if (debug) System.err.println("idamax"); + if (n <= 0) { + return -1; + } + if (incx <= 0) { + return -1; + } + if (n == 1) { + return 0; + } + requireNonNull(x); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + // Fortran arrays use 1-based index + return idamaxK(n, x, offsetx, incx) - 1; + } + + protected abstract int idamaxK(int n, double[] x, int offsetx, int incx); + + public int isamax(int n, float[] x, int incx) { + if (debug) System.err.println("isamax"); + return isamax(n, x, 0, incx); + } + + public int isamax(int n, float[] x, int offsetx, int incx) { + if (debug) System.err.println("isamax"); + if (n <= 0) { + return -1; + } + if (incx <= 0) { + return -1; + } + if (n == 1) { + return 0; + } + requireNonNull(x); + checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length); + // Fortran arrays use 1-based index + return isamaxK(n, x, offsetx, incx) - 1; + } + + protected abstract int isamaxK(int n, float[] x, int offsetx, int incx); + + public boolean lsame(String ca, String cb) { + if (debug) System.err.println("lsame"); + return ca != null && ca.regionMatches(true, 0, cb, 0, ca.length()); + } +} diff --git 
a/ml-core/src/main/java/dev/ludovic/netlib/blas/F2jBLAS.java b/ml-core/src/main/java/dev/ludovic/netlib/blas/F2jBLAS.java new file mode 100644 index 0000000000000000000000000000000000000000..d0a8dab6a7b62edb52772167a7cd272bb7f076be --- /dev/null +++ b/ml-core/src/main/java/dev/ludovic/netlib/blas/F2jBLAS.java @@ -0,0 +1,241 @@ +/* + * Copyright 2020, 2021, Ludovic Henry + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * Please contact git@ludovic.dev or visit ludovic.dev if you need additional + * information or have any questions. 
+ */ + +package dev.ludovic.netlib.blas; + +import dev.ludovic.netlib.BLAS; + +public final class F2jBLAS extends AbstractBLAS implements dev.ludovic.netlib.JavaBLAS { + + private static final F2jBLAS instance = new F2jBLAS(); + + protected F2jBLAS() {} + + public static dev.ludovic.netlib.JavaBLAS getInstance() { + return instance; + } + + protected double dasumK(int n, double[] x, int offsetx, int incx) { + return org.netlib.blas.Dasum.dasum(n, x, offsetx, incx); + } + protected float sasumK(int n, float[] x, int offsetx, int incx) { + return org.netlib.blas.Sasum.sasum(n, x, offsetx, incx); + } + protected void daxpyK(int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy) { + org.netlib.blas.Daxpy.daxpy(n, alpha, x, offsetx, incx, y, offsety, incy); + } + protected void saxpyK(int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) { + org.netlib.blas.Saxpy.saxpy(n, alpha, x, offsetx, incx, y, offsety, incy); + } + protected void dcopyK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy) { + org.netlib.blas.Dcopy.dcopy(n, x, offsetx, incx, y, offsety, incy); + } + protected void scopyK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) { + org.netlib.blas.Scopy.scopy(n, x, offsetx, incx, y, offsety, incy); + } + protected double ddotK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy) { + return org.netlib.blas.Ddot.ddot(n, x, offsetx, incx, y, offsety, incy); + } + protected float sdotK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) { + return org.netlib.blas.Sdot.sdot(n, x, offsetx, incx, y, offsety, incy); + } + protected float sdsdotK(int n, float sb, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) { + return org.netlib.blas.Sdsdot.sdsdot(n, sb, x, offsetx, incx, y, offsety, incy); + } + protected void dgbmvK(String trans, int m, int n, int kl, int ku, double alpha, 
double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) { + org.netlib.blas.Dgbmv.dgbmv(trans, m, n, kl, ku, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy); + } + protected void sgbmvK(String trans, int m, int n, int kl, int ku, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) { + org.netlib.blas.Sgbmv.sgbmv(trans, m, n, kl, ku, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy); + } + protected void dgemmK(String transa, String transb, int m, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) { + org.netlib.blas.Dgemm.dgemm(transa, transb, m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc); + } + protected void sgemmK(String transa, String transb, int m, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) { + org.netlib.blas.Sgemm.sgemm(transa, transb, m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc); + } + protected void dgemvK(String trans, int m, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) { + org.netlib.blas.Dgemv.dgemv(trans, m, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy); + } + protected void sgemvK(String trans, int m, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) { + org.netlib.blas.Sgemv.sgemv(trans, m, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy); + } + protected void dgerK(int m, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta, int lda) { + org.netlib.blas.Dger.dger(m, n, 
alpha, x, offsetx, incx, y, offsety, incy, a, offseta, lda); + } + protected void sgerK(int m, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta, int lda) { + org.netlib.blas.Sger.sger(m, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta, lda); + } + protected double dnrm2K(int n, double[] x, int offsetx, int incx) { + return org.netlib.blas.Dnrm2.dnrm2(n, x, offsetx, incx); + } + protected float snrm2K(int n, float[] x, int offsetx, int incx) { + return org.netlib.blas.Snrm2.snrm2(n, x, offsetx, incx); + } + protected void drotK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double c, double s) { + org.netlib.blas.Drot.drot(n, x, offsetx, incx, y, offsety, incy, c, s); + } + protected void srotK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float c, float s) { + org.netlib.blas.Srot.srot(n, x, offsetx, incx, y, offsety, incy, c, s); + } + protected void drotmK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] param, int offsetparam) { + org.netlib.blas.Drotm.drotm(n, x, offsetx, incx, y, offsety, incy, param, offsetparam); + } + protected void srotmK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] param, int offsetparam) { + org.netlib.blas.Srotm.srotm(n, x, offsetx, incx, y, offsety, incy, param, offsetparam); + } + protected void drotmgK(org.netlib.util.doubleW dd1, org.netlib.util.doubleW dd2, org.netlib.util.doubleW dx1, double dy1, double[] param, int offsetparam) { + org.netlib.blas.Drotmg.drotmg(dd1, dd2, dx1, dy1, param, offsetparam); + } + protected void srotmgK(org.netlib.util.floatW sd1, org.netlib.util.floatW sd2, org.netlib.util.floatW sx1, float sy1, float[] param, int offsetparam) { + org.netlib.blas.Srotmg.srotmg(sd1, sd2, sx1, sy1, param, offsetparam); + } + protected void dsbmvK(String uplo, int n, int k, double alpha, double[] a, int offseta, int lda, 
double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) { + org.netlib.blas.Dsbmv.dsbmv(uplo, n, k, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy); + } + protected void ssbmvK(String uplo, int n, int k, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) { + org.netlib.blas.Ssbmv.ssbmv(uplo, n, k, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy); + } + protected void dscalK(int n, double alpha, double[] x, int offsetx, int incx) { + org.netlib.blas.Dscal.dscal(n, alpha, x, offsetx, incx); + } + protected void sscalK(int n, float alpha, float[] x, int offsetx, int incx) { + org.netlib.blas.Sscal.sscal(n, alpha, x, offsetx, incx); + } + protected void dspmvK(String uplo, int n, double alpha, double[] a, int offseta, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) { + org.netlib.blas.Dspmv.dspmv(uplo, n, alpha, a, offseta, x, offsetx, incx, beta, y, offsety, incy); + } + protected void sspmvK(String uplo, int n, float alpha, float[] a, int offseta, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) { + org.netlib.blas.Sspmv.sspmv(uplo, n, alpha, a, offseta, x, offsetx, incx, beta, y, offsety, incy); + } + protected void dsprK(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] a, int offseta) { + org.netlib.blas.Dspr.dspr(uplo, n, alpha, x, offsetx, incx, a, offseta); + } + protected void ssprK(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] a, int offseta) { + org.netlib.blas.Sspr.sspr(uplo, n, alpha, x, offsetx, incx, a, offseta); + } + protected void dspr2K(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta) { + org.netlib.blas.Dspr2.dspr2(uplo, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta); + } + protected void sspr2K(String 
uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta) { + org.netlib.blas.Sspr2.sspr2(uplo, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta); + } + protected void dswapK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy) { + org.netlib.blas.Dswap.dswap(n, x, offsetx, incx, y, offsety, incy); + } + protected void sswapK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) { + org.netlib.blas.Sswap.sswap(n, x, offsetx, incx, y, offsety, incy); + } + protected void dsymmK(String side, String uplo, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) { + org.netlib.blas.Dsymm.dsymm(side, uplo, m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc); + } + protected void ssymmK(String side, String uplo, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) { + org.netlib.blas.Ssymm.ssymm(side, uplo, m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc); + } + protected void dsymvK(String uplo, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) { + org.netlib.blas.Dsymv.dsymv(uplo, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy); + } + protected void ssymvK(String uplo, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) { + org.netlib.blas.Ssymv.ssymv(uplo, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy); + } + protected void dsyrK(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] a, int offseta, int lda) { + org.netlib.blas.Dsyr.dsyr(uplo, n, alpha, x, offsetx, incx, a, offseta, lda); + } + protected void 
ssyrK(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] a, int offseta, int lda) { + org.netlib.blas.Ssyr.ssyr(uplo, n, alpha, x, offsetx, incx, a, offseta, lda); + } + protected void dsyr2K(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta, int lda) { + org.netlib.blas.Dsyr2.dsyr2(uplo, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta, lda); + } + protected void ssyr2K(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta, int lda) { + org.netlib.blas.Ssyr2.ssyr2(uplo, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta, lda); + } + protected void dsyr2kK(String uplo, String trans, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) { + org.netlib.blas.Dsyr2k.dsyr2k(uplo, trans, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc); + } + protected void ssyr2kK(String uplo, String trans, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) { + org.netlib.blas.Ssyr2k.ssyr2k(uplo, trans, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc); + } + protected void dsyrkK(String uplo, String trans, int n, int k, double alpha, double[] a, int offseta, int lda, double beta, double[] c, int offsetc, int ldc) { + org.netlib.blas.Dsyrk.dsyrk(uplo, trans, n, k, alpha, a, offseta, lda, beta, c, offsetc, ldc); + } + protected void ssyrkK(String uplo, String trans, int n, int k, float alpha, float[] a, int offseta, int lda, float beta, float[] c, int offsetc, int ldc) { + org.netlib.blas.Ssyrk.ssyrk(uplo, trans, n, k, alpha, a, offseta, lda, beta, c, offsetc, ldc); + } + protected void dtbmvK(String uplo, String trans, String diag, int n, int k, double[] a, int offseta, int lda, double[] x, 
int offsetx, int incx) { + org.netlib.blas.Dtbmv.dtbmv(uplo, trans, diag, n, k, a, offseta, lda, x, offsetx, incx); + } + protected void stbmvK(String uplo, String trans, String diag, int n, int k, float[] a, int offseta, int lda, float[] x, int offsetx, int incx) { + org.netlib.blas.Stbmv.stbmv(uplo, trans, diag, n, k, a, offseta, lda, x, offsetx, incx); + } + protected void dtbsvK(String uplo, String trans, String diag, int n, int k, double[] a, int offseta, int lda, double[] x, int offsetx, int incx) { + org.netlib.blas.Dtbsv.dtbsv(uplo, trans, diag, n, k, a, offseta, lda, x, offsetx, incx); + } + protected void stbsvK(String uplo, String trans, String diag, int n, int k, float[] a, int offseta, int lda, float[] x, int offsetx, int incx) { + org.netlib.blas.Stbsv.stbsv(uplo, trans, diag, n, k, a, offseta, lda, x, offsetx, incx); + } + protected void dtpmvK(String uplo, String trans, String diag, int n, double[] a, int offseta, double[] x, int offsetx, int incx) { + org.netlib.blas.Dtpmv.dtpmv(uplo, trans, diag, n, a, offseta, x, offsetx, incx); + } + protected void stpmvK(String uplo, String trans, String diag, int n, float[] a, int offseta, float[] x, int offsetx, int incx) { + org.netlib.blas.Stpmv.stpmv(uplo, trans, diag, n, a, offseta, x, offsetx, incx); + } + protected void dtpsvK(String uplo, String trans, String diag, int n, double[] a, int offseta, double[] x, int offsetx, int incx) { + org.netlib.blas.Dtpsv.dtpsv(uplo, trans, diag, n, a, offseta, x, offsetx, incx); + } + protected void stpsvK(String uplo, String trans, String diag, int n, float[] a, int offseta, float[] x, int offsetx, int incx) { + org.netlib.blas.Stpsv.stpsv(uplo, trans, diag, n, a, offseta, x, offsetx, incx); + } + protected void dtrmmK(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb) { + org.netlib.blas.Dtrmm.dtrmm(side, uplo, transa, diag, m, n, alpha, a, offseta, lda, b, offsetb, 
ldb); + } + protected void strmmK(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb) { + org.netlib.blas.Strmm.strmm(side, uplo, transa, diag, m, n, alpha, a, offseta, lda, b, offsetb, ldb); + } + protected void dtrmvK(String uplo, String trans, String diag, int n, double[] a, int offseta, int lda, double[] x, int offsetx, int incx) { + org.netlib.blas.Dtrmv.dtrmv(uplo, trans, diag, n, a, offseta, lda, x, offsetx, incx); + } + protected void strmvK(String uplo, String trans, String diag, int n, float[] a, int offseta, int lda, float[] x, int offsetx, int incx) { + org.netlib.blas.Strmv.strmv(uplo, trans, diag, n, a, offseta, lda, x, offsetx, incx); + } + protected void dtrsmK(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb) { + org.netlib.blas.Dtrsm.dtrsm(side, uplo, transa, diag, m, n, alpha, a, offseta, lda, b, offsetb, ldb); + } + protected void strsmK(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb) { + org.netlib.blas.Strsm.strsm(side, uplo, transa, diag, m, n, alpha, a, offseta, lda, b, offsetb, ldb); + } + protected void dtrsvK(String uplo, String trans, String diag, int n, double[] a, int offseta, int lda, double[] x, int offsetx, int incx) { + org.netlib.blas.Dtrsv.dtrsv(uplo, trans, diag, n, a, offseta, lda, x, offsetx, incx); + } + protected void strsvK(String uplo, String trans, String diag, int n, float[] a, int offseta, int lda, float[] x, int offsetx, int incx) { + org.netlib.blas.Strsv.strsv(uplo, trans, diag, n, a, offseta, lda, x, offsetx, incx); + } + protected int idamaxK(int n, double[] x, int offsetx, int incx) { + return org.netlib.blas.Idamax.idamax(n, x, offsetx, incx); + } + protected int isamaxK(int n, float[] x, int offsetx, int incx) { + return 
org.netlib.blas.Isamax.isamax(n, x, offsetx, incx); + } +} diff --git a/ml-core/src/main/java/dev/ludovic/netlib/blas/JNIBLAS.java b/ml-core/src/main/java/dev/ludovic/netlib/blas/JNIBLAS.java new file mode 100644 index 0000000000000000000000000000000000000000..34378855fbe88de4c153ad9c1dfd7d2afb0e5805 --- /dev/null +++ b/ml-core/src/main/java/dev/ludovic/netlib/blas/JNIBLAS.java @@ -0,0 +1,201 @@ +/* + * Copyright 2020, 2021, Ludovic Henry + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * Please contact git@ludovic.dev or visit ludovic.dev if you need additional + * information or have any questions. 
+ */ + +package dev.ludovic.netlib.blas; + +import java.io.InputStream; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardCopyOption; +import java.nio.file.attribute.PosixFilePermissions; + +public final class JNIBLAS extends AbstractBLAS implements dev.ludovic.netlib.NativeBLAS { + + private static final JNIBLAS instance = new JNIBLAS(); + + protected JNIBLAS() { + String osName = System.getProperty("os.name"); + if (osName == null || osName.isEmpty()) { + throw new RuntimeException("Unable to load native implementation"); + } + String osArch = System.getProperty("os.arch"); + if (osArch == null || osArch.isEmpty()) { + throw new RuntimeException("Unable to load native implementation"); + } + + Path temp; + try (InputStream resource = this.getClass().getClassLoader().getResourceAsStream( + String.format("resources/native/%s-%s/libnetlibblasjni.so", osName, osArch))) { + if (resource == null) throw new RuntimeException("Unable to load native implementation"); /* explicit check: assert is skipped unless -ea is set, which let a missing resource create a temp file and then fail in Files.copy with a bare NullPointerException */ + Files.copy(resource, temp = Files.createTempFile("libnetlibblasjni.so", "", + PosixFilePermissions.asFileAttribute(PosixFilePermissions.fromString("rwxr-x---"))), + StandardCopyOption.REPLACE_EXISTING); + temp.toFile().deleteOnExit(); + } catch (IOException e) { + throw new RuntimeException("Unable to load native implementation", e); + } + + System.load(temp.toString()); + } + + public static dev.ludovic.netlib.NativeBLAS getInstance() { + return instance; + } + + protected native double dasumK(int n, double[] x, int offsetx, int incx); + + protected native float sasumK(int n, float[] x, int offsetx, int incx); + + protected native void daxpyK(int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy); + + protected native void saxpyK(int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy); + + protected native void dcopyK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy); + + protected native void scopyK(int n, 
float[] x, int offsetx, int incx, float[] y, int offsety, int incy); + + protected native double ddotK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy); + + protected native float sdotK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy); + + protected native float sdsdotK(int n, float sb, float[] sx, int offsetsx, int incsx, float[] sy, int offsetsy, int incsy); + + protected native void dgbmvK(String trans, int m, int n, int kl, int ku, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy); + + protected native void sgbmvK(String trans, int m, int n, int kl, int ku, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy); + + protected native void dgemmK(String transa, String transb, int m, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc); + + protected native void sgemmK(String transa, String transb, int m, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc); + + protected native void dgemvK(String trans, int m, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy); + + protected native void sgemvK(String trans, int m, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy); + + protected native void dgerK(int m, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta, int lda); + + protected native void sgerK(int m, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta, int lda); + + protected native double dnrm2K(int 
n, double[] x, int offsetx, int incx); + + protected native float snrm2K(int n, float[] x, int offsetx, int incx); + + protected native void drotK(int n, double[] dx, int offsetdx, int incx, double[] dy, int offsetdy, int incy, double c, double s); + + protected native void srotK(int n, float[] sx, int offsetsx, int incx, float[] sy, int offsetsy, int incy, float c, float s); + + protected native void drotmK(int n, double[] dx, int offsetdx, int incx, double[] dy, int offsetdy, int incy, double[] dparam, int offsetdparam); + + protected native void srotmK(int n, float[] sx, int offsetsx, int incx, float[] sy, int offsetsy, int incy, float[] sparam, int offsetsparam); + + protected native void drotmgK(org.netlib.util.doubleW dd1, org.netlib.util.doubleW dd2, org.netlib.util.doubleW dx1, double dy1, double[] dparam, int offsetdparam); + + protected native void srotmgK(org.netlib.util.floatW sd1, org.netlib.util.floatW sd2, org.netlib.util.floatW sx1, float sy1, float[] sparam, int offsetsparam); + + protected native void dsbmvK(String uplo, int n, int k, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy); + + protected native void ssbmvK(String uplo, int n, int k, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy); + + protected native void dscalK(int n, double alpha, double[] x, int offsetx, int incx); + + protected native void sscalK(int n, float alpha, float[] x, int offsetx, int incx); + + protected native void dspmvK(String uplo, int n, double alpha, double[] a, int offseta, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy); + + protected native void sspmvK(String uplo, int n, float alpha, float[] a, int offseta, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy); + + protected native void dsprK(String uplo, int n, double alpha, double[] x, int 
offsetx, int incx, double[] a, int offseta); + + protected native void ssprK(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] a, int offseta); + + protected native void dspr2K(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta); + + protected native void sspr2K(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta); + + protected native void dswapK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy); + + protected native void sswapK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy); + + protected native void dsymmK(String side, String uplo, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc); + + protected native void ssymmK(String side, String uplo, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc); + + protected native void dsymvK(String uplo, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy); + + protected native void ssymvK(String uplo, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy); + + protected native void dsyrK(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] a, int offseta, int lda); + + protected native void ssyrK(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] a, int offseta, int lda); + + protected native void dsyr2K(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta, int lda); + + protected native void ssyr2K(String uplo, int n, float alpha, float[] x, int 
offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta, int lda); + + protected native void dsyr2kK(String uplo, String trans, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc); + + protected native void ssyr2kK(String uplo, String trans, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc); + + protected native void dsyrkK(String uplo, String trans, int n, int k, double alpha, double[] a, int offseta, int lda, double beta, double[] c, int offsetc, int ldc); + + protected native void ssyrkK(String uplo, String trans, int n, int k, float alpha, float[] a, int offseta, int lda, float beta, float[] c, int offsetc, int ldc); + + protected native void dtbmvK(String uplo, String trans, String diag, int n, int k, double[] a, int offseta, int lda, double[] x, int offsetx, int incx); + + protected native void stbmvK(String uplo, String trans, String diag, int n, int k, float[] a, int offseta, int lda, float[] x, int offsetx, int incx); + + protected native void dtbsvK(String uplo, String trans, String diag, int n, int k, double[] a, int offseta, int lda, double[] x, int offsetx, int incx); + + protected native void stbsvK(String uplo, String trans, String diag, int n, int k, float[] a, int offseta, int lda, float[] x, int offsetx, int incx); + + protected native void dtpmvK(String uplo, String trans, String diag, int n, double[] a, int offseta, double[] x, int offsetx, int incx); + + protected native void stpmvK(String uplo, String trans, String diag, int n, float[] a, int offseta, float[] x, int offsetx, int incx); + + protected native void dtpsvK(String uplo, String trans, String diag, int n, double[] a, int offseta, double[] x, int offsetx, int incx); + + protected native void stpsvK(String uplo, String trans, String diag, int n, float[] a, int offseta, float[] x, int 
offsetx, int incx); + + protected native void dtrmmK(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb); + + protected native void strmmK(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb); + + protected native void dtrmvK(String uplo, String trans, String diag, int n, double[] a, int offseta, int lda, double[] x, int offsetx, int incx); + + protected native void strmvK(String uplo, String trans, String diag, int n, float[] a, int offseta, int lda, float[] x, int offsetx, int incx); + + protected native void dtrsmK(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb); + + protected native void strsmK(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb); + + protected native void dtrsvK(String uplo, String trans, String diag, int n, double[] a, int offseta, int lda, double[] x, int offsetx, int incx); + + protected native void strsvK(String uplo, String trans, String diag, int n, float[] a, int offseta, int lda, float[] x, int offsetx, int incx); + + protected native int idamaxK(int n, double[] dx, int offsetdx, int incdx); + + protected native int isamaxK(int n, float[] sx, int offsetsx, int incx); +} diff --git a/ml-core/src/main/java/dev/ludovic/netlib/blas/Java8BLAS.java b/ml-core/src/main/java/dev/ludovic/netlib/blas/Java8BLAS.java new file mode 100644 index 0000000000000000000000000000000000000000..443a6329f32a2cd7e70144b87328dd752230eaf9 --- /dev/null +++ b/ml-core/src/main/java/dev/ludovic/netlib/blas/Java8BLAS.java @@ -0,0 +1,5157 @@ +/* + * Copyright 2020, 2021, Ludovic Henry + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this 
software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * Please contact git@ludovic.dev or visit ludovic.dev if you need additional + * information or have any questions. 
+ */ + +package dev.ludovic.netlib.blas; + +import dev.ludovic.netlib.BLAS; + +public class Java8BLAS extends AbstractBLAS implements dev.ludovic.netlib.JavaBLAS { + + private static final Java8BLAS instance = new Java8BLAS(); + + protected Java8BLAS() {} + + public static dev.ludovic.netlib.JavaBLAS getInstance() { + return instance; + } + + protected double dasumK(int n, double[] x, int offsetx, int incx) { + double sum = 0.0; + if (incx == 1) { + int ix = 0; + double sum0 = 0.0; + double sum1 = 0.0; + double sum2 = 0.0; + double sum3 = 0.0; + for (; ix < loopBound(n, 4); ix += 4) { + sum0 += Math.abs(x[offsetx + ix + 0]); + sum1 += Math.abs(x[offsetx + ix + 1]); + sum2 += Math.abs(x[offsetx + ix + 2]); + sum3 += Math.abs(x[offsetx + ix + 3]); + } + sum += sum0 + sum1 + sum2 + sum3; + for (; ix < n; ix += 1) { + sum += Math.abs(x[offsetx + ix]); + } + } else { + for (int ix = incx < 0 ? (n - 1) * -incx : 0; incx < 0 ? ix >= 0 : ix < n * incx; ix += incx) { + sum += Math.abs(x[offsetx + ix]); + } + } + return sum; + } + + protected float sasumK(int n, float[] x, int offsetx, int incx) { + float sum = 0.0f; + if (incx == 1) { + int ix = 0; + float sum0 = 0.0f; + float sum1 = 0.0f; + float sum2 = 0.0f; + float sum3 = 0.0f; + for (; ix < loopBound(n, 4); ix += 4) { + sum0 += Math.abs(x[offsetx + ix + 0]); + sum1 += Math.abs(x[offsetx + ix + 1]); + sum2 += Math.abs(x[offsetx + ix + 2]); + sum3 += Math.abs(x[offsetx + ix + 3]); + } + sum += sum0 + sum1 + sum2 + sum3; + for (; ix < n; ix += 1) { + sum += Math.abs(x[offsetx + ix]); + } + } else { + for (int ix = incx < 0 ? (n - 1) * -incx : 0; incx < 0 ? 
ix >= 0 : ix < n * incx; ix += incx) { + sum += Math.abs(x[offsetx + ix]); + } + } + return sum; + } + + protected void daxpyK(int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy) { + if (incx == 1 && incy == 1) { + for (int ix = 0, iy = 0; ix < n && iy < n; ix += 1, iy += 1) { + y[offsety + iy] += alpha * x[offsetx + ix]; + } + } else { + for (int ix = incx < 0 ? (n - 1) * -incx : 0, + iy = incy < 0 ? (n - 1) * -incy : 0; + (incx < 0 ? ix >= 0 : ix < n * incx) + && (incy < 0 ? iy >= 0 : iy < n * incy); + ix += incx, iy += incy) { + y[offsety + iy] += alpha * x[offsetx + ix]; + } + } + } + + protected void saxpyK(int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) { + if (incx == 1 && incy == 1) { + for (int ix = 0, iy = 0; ix < n && iy < n; ix += 1, iy += 1) { + y[offsety + iy] += alpha * x[offsetx + ix]; + } + } else { + for (int ix = incx < 0 ? (n - 1) * -incx : 0, + iy = incy < 0 ? (n - 1) * -incy : 0; + (incx < 0 ? ix >= 0 : ix < n * incx) + && (incy < 0 ? iy >= 0 : iy < n * incy); + ix += incx, iy += incy) { + y[offsety + iy] += alpha * x[offsetx + ix]; + } + } + } + + protected void dcopyK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy) { + if (incx == 1 && incy == 1) { + System.arraycopy(x, offsetx, y, offsety, n); + } else { + for (int ix = incx < 0 ? (n - 1) * -incx : 0, + iy = incy < 0 ? (n - 1) * -incy : 0; + (incx < 0 ? ix >= 0 : ix < n * incx) + && (incy < 0 ? iy >= 0 : iy < n * incy); + ix += incx, iy += incy) { + y[offsety + iy] = x[offsetx + ix]; + } + } + } + + protected void scopyK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) { + if (incx == 1 && incy == 1) { + System.arraycopy(x, offsetx, y, offsety, n); + } else { + for (int ix = incx < 0 ? (n - 1) * -incx : 0, + iy = incy < 0 ? (n - 1) * -incy : 0; + (incx < 0 ? ix >= 0 : ix < n * incx) + && (incy < 0 ? 
iy >= 0 : iy < n * incy); + ix += incx, iy += incy) { + y[offsety + iy] = x[offsetx + ix]; + } + } + } + + protected double ddotK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy) { + double sum = 0.0; + if (incx == 1 && incy == 1) { + int ix = 0, iy = 0; + double sum0 = 0.0; + double sum1 = 0.0; + double sum2 = 0.0; + double sum3 = 0.0; + for (; ix < loopBound(n, 4) && iy < loopBound(n, 4); ix += 4, iy += 4) { + sum0 += x[offsetx + ix + 0] * y[offsety + iy + 0]; + sum1 += x[offsetx + ix + 1] * y[offsety + iy + 1]; + sum2 += x[offsetx + ix + 2] * y[offsety + iy + 2]; + sum3 += x[offsetx + ix + 3] * y[offsety + iy + 3]; + } + sum += sum0 + sum1 + sum2 + sum3; + for (; ix < n && iy < n; ix += 1, iy += 1) { + sum += x[offsetx + ix] * y[offsety + iy]; + } + } else { + for (int ix = incx < 0 ? (n - 1) * -incx : 0, + iy = incy < 0 ? (n - 1) * -incy : 0; + (incx < 0 ? ix >= 0 : ix < n * incx) + && (incy < 0 ? iy >= 0 : iy < n * incy); + ix += incx, iy += incy) { + sum += x[offsetx + ix] * y[offsety + iy]; + } + } + return sum; + } + + protected float sdotK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) { + float sum = 0.0f; + if (incx == 1 && incy == 1) { + int ix = 0, iy = 0; + float sum0 = 0.0f; + float sum1 = 0.0f; + float sum2 = 0.0f; + float sum3 = 0.0f; + for (; ix < loopBound(n, 4) && iy < loopBound(n, 4); ix += 4, iy += 4) { + sum0 += x[offsetx + ix + 0] * y[offsety + iy + 0]; + sum1 += x[offsetx + ix + 1] * y[offsety + iy + 1]; + sum2 += x[offsetx + ix + 2] * y[offsety + iy + 2]; + sum3 += x[offsetx + ix + 3] * y[offsety + iy + 3]; + } + sum += sum0 + sum1 + sum2 + sum3; + for (; ix < n && iy < n; ix += 1, iy += 1) { + sum += x[offsetx + ix] * y[offsety + iy]; + } + } else { + for (int ix = incx < 0 ? (n - 1) * -incx : 0, + iy = incy < 0 ? (n - 1) * -incy : 0; + (incx < 0 ? ix >= 0 : ix < n * incx) + && (incy < 0 ? 
iy >= 0 : iy < n * incy); + ix += incx, iy += incy) { + sum += x[offsetx + ix] * y[offsety + iy]; + } + } + return sum; + } + + protected float sdsdotK(int n, float sb, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) { + double sum = sb; + if (incx == 1 && incy == 1) { + int ix = 0, iy = 0; + double sum0 = 0.0; + double sum1 = 0.0; + double sum2 = 0.0; + double sum3 = 0.0; + for (; ix < loopBound(n, 4) && iy < loopBound(n, 4); ix += 4, iy += 4) { + sum0 += (double)x[offsetx + ix + 0] * (double)y[offsety + iy + 0]; + sum1 += (double)x[offsetx + ix + 1] * (double)y[offsety + iy + 1]; + sum2 += (double)x[offsetx + ix + 2] * (double)y[offsety + iy + 2]; + sum3 += (double)x[offsetx + ix + 3] * (double)y[offsety + iy + 3]; + } + sum += sum0 + sum1 + sum2 + sum3; + for (; ix < n && iy < n; ix += 1, iy += 1) { + sum += (double)(x[offsetx + ix]) * (double)(y[offsety + iy]); + } + } else { + for (int ix = incx < 0 ? (n - 1) * -incx : 0, + iy = incy < 0 ? (n - 1) * -incy : 0; + (incx < 0 ? ix >= 0 : ix < n * incx) + && (incy < 0 ? 
iy >= 0 : iy < n * incy); + ix += incx, iy += incy) { + sum += (double)(x[offsetx + ix]) * (double)(y[offsety + iy]); + } + } + return (float)sum; + } + + protected void dgbmvK(String trans, int m, int n, int kl, int ku, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) { + org.netlib.blas.Dgbmv.dgbmv(trans, m, n, kl, ku, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy); + } + protected void sgbmvK(String trans, int m, int n, int kl, int ku, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) { + org.netlib.blas.Sgbmv.sgbmv(trans, m, n, kl, ku, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy); + } + + protected void dgemmK(String transa, String transb, int m, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) { + if (alpha == 0.0) { + dgemmBeta(0, m, 0, n, beta, c, offsetc, ldc); + } else if ((long) m * n * k < 100L * 100 * 100) { + // The matrices are small (size is computed in long so that m * n * k cannot wrap around int and misroute large inputs here) and it's faster to do the non-copying version + if (lsame("N", transa) && lsame("N", transb)) { + dgemmNN(m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc); + } else if (lsame("N", transa)) { + dgemmNT(m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc); + } else if (lsame("N", transb)) { + dgemmTN(m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc); + } else { + dgemmTT(m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc); + } + } else { + final int Krow = (int)(Math.ceil((double)(Math.min(60, m)) / 3) * 3), + Kcol = (int)(Math.ceil((double)(Math.min(1000, n)) / 3) * 3), + Ki = (int)(Math.ceil((double)(Math.min(500, k)) / 4) * 4); + + assert Krow > 0; + assert Kcol > 0; + assert Ki > 0; + + double[] packeda = new double[Krow * Ki]; + double[] packedb = new 
double[Kcol * Ki]; + double[] packedc = new double[Kcol * Krow]; + + // c = beta * c + dgemmBeta(0, m, 0, n, beta, c, offsetc, ldc); + // c += alpha * a * b + for (int col = 0; col < n; col += Kcol) { + int cols = col, cole = Math.min(col + Kcol, n); + for (int i = 0; i < k; i += Ki) { + int is = i, ie = Math.min(i + Ki, k); + // pack b + if (lsame("N", transb)) { + dgecpyNN(ie - is, cole - cols, b, offsetb, ldb, is, cols, packedb, 0, Ki, 0, 0); + } else { + dgecpyTN(ie - is, cole - cols, b, offsetb, ldb, is, cols, packedb, 0, Ki, 0, 0); + } + // GEPP + for (int row = 0; row < m; row += Krow) { + int rows = row, rowe = Math.min(row + Krow, m); + // pack A + if (lsame("N", transa)) { + dgecpyNT(rowe - rows, ie - is, a, offseta, lda, rows, is, packeda, 0, Ki, 0, 0); + } else { + dgecpyTT(rowe - rows, ie - is, a, offseta, lda, rows, is, packeda, 0, Ki, 0, 0); + } + // pack C + dgecpyNN(rowe - rows, cole - cols, c, offsetc, ldc, rows, cols, packedc, 0, Krow, 0, 0); + // GEBP + dgebpTN(Krow, 0, rowe - rows, Kcol, 0, cole - cols, Ki, 0, ie - is, + alpha, packeda, 0, Ki, packedb, 0, Ki, beta, packedc, 0, Krow); + // unpack C + dgecpyNN(rowe - rows, cole - cols, packedc, 0, Krow, 0, 0, c, offsetc, ldc, rows, cols); + } + } + } + } + } + + protected void dgemmBeta(int rows, int rowe, int cols, int cole, double beta, double[] c, int offsetc, int ldc) { + if (beta != 1.0) { + int col = cols; + for (; col < loopAlign(cols, cole, 4); col += 1) { + int row = rows; + for (; row < rowe; row += 1) { + if (beta != 0.0) { + c[offsetc + row + (col + 0) * ldc] = beta * c[offsetc + row + (col + 0) * ldc]; + } else { + c[offsetc + row + (col + 0) * ldc] = 0.0; + } + } + } + for (; col < loopBound(cole, 4); col += 4) { + int row = rows; + for (; row < rowe; row += 1) { + if (beta != 0.0) { + c[offsetc + row + (col + 0) * ldc] = beta * c[offsetc + row + (col + 0) * ldc]; + c[offsetc + row + (col + 1) * ldc] = beta * c[offsetc + row + (col + 1) * ldc]; + c[offsetc + row + (col + 2) * ldc] = 
beta * c[offsetc + row + (col + 2) * ldc]; + c[offsetc + row + (col + 3) * ldc] = beta * c[offsetc + row + (col + 3) * ldc]; + } else { + c[offsetc + row + (col + 0) * ldc] = 0.0; + c[offsetc + row + (col + 1) * ldc] = 0.0; + c[offsetc + row + (col + 2) * ldc] = 0.0; + c[offsetc + row + (col + 3) * ldc] = 0.0; + } + } + } + for (; col < cole; col += 1) { + int row = rows; + for (; row < rowe; row += 1) { + if (beta != 0.0) { + c[offsetc + row + (col + 0) * ldc] = beta * c[offsetc + row + (col + 0) * ldc]; + } else { + c[offsetc + row + (col + 0) * ldc] = 0.0; + } + } + } + } + } + + protected void dgecpyNN(int m, int n, double[] src, int offsetsrc, int ldsrc, int rowssrc, int colssrc, double[] dst, int offsetdst, int lddst, int rowsdst, int colsdst) { + int col = 0; + for (; col < loopBound(n, 4); col += 4) { + System.arraycopy(src, offsetsrc + rowssrc + (colssrc + col + 0) * ldsrc, dst, offsetdst + rowsdst + (colsdst + col + 0) * lddst, m); + System.arraycopy(src, offsetsrc + rowssrc + (colssrc + col + 1) * ldsrc, dst, offsetdst + rowsdst + (colsdst + col + 1) * lddst, m); + System.arraycopy(src, offsetsrc + rowssrc + (colssrc + col + 2) * ldsrc, dst, offsetdst + rowsdst + (colsdst + col + 2) * lddst, m); + System.arraycopy(src, offsetsrc + rowssrc + (colssrc + col + 3) * ldsrc, dst, offsetdst + rowsdst + (colsdst + col + 3) * lddst, m); + } + for (; col < n; col += 1) { + System.arraycopy(src, offsetsrc + rowssrc + (colssrc + col + 0) * ldsrc, dst, offsetdst + rowsdst + (colsdst + col + 0) * lddst, m); + } + } + + protected void dgecpyNT(int m, int n, double[] src, int offsetsrc, int ldsrc, int rowssrc, int colssrc, double[] dst, int offsetdst, int lddst, int rowsdst, int colsdst) { + int col = 0; + for (; col < loopBound(n, 3); col += 3) { + int row = 0; + for (; row < loopBound(m, 3); row += 3) { + dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 0) * ldsrc]; + dst[offsetdst + (colsdst + 
col + 0) + (rowsdst + row + 1) * lddst] = src[offsetsrc + (rowssrc + row + 1) + (colssrc + col + 0) * ldsrc]; + dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 2) * lddst] = src[offsetsrc + (rowssrc + row + 2) + (colssrc + col + 0) * ldsrc]; + dst[offsetdst + (colsdst + col + 1) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 1) * ldsrc]; + dst[offsetdst + (colsdst + col + 1) + (rowsdst + row + 1) * lddst] = src[offsetsrc + (rowssrc + row + 1) + (colssrc + col + 1) * ldsrc]; + dst[offsetdst + (colsdst + col + 1) + (rowsdst + row + 2) * lddst] = src[offsetsrc + (rowssrc + row + 2) + (colssrc + col + 1) * ldsrc]; + dst[offsetdst + (colsdst + col + 2) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 2) * ldsrc]; + dst[offsetdst + (colsdst + col + 2) + (rowsdst + row + 1) * lddst] = src[offsetsrc + (rowssrc + row + 1) + (colssrc + col + 2) * ldsrc]; + dst[offsetdst + (colsdst + col + 2) + (rowsdst + row + 2) * lddst] = src[offsetsrc + (rowssrc + row + 2) + (colssrc + col + 2) * ldsrc]; + } + for (; row < m; row += 1) { + dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 0) * ldsrc]; + dst[offsetdst + (colsdst + col + 1) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 1) * ldsrc]; + dst[offsetdst + (colsdst + col + 2) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 2) * ldsrc]; + } + } + for (; col < n; col += 1) { + int row = 0; + for (; row < loopBound(m, 3); row += 3) { + dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 0) * ldsrc]; + dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 1) * lddst] = src[offsetsrc + (rowssrc + row + 1) + (colssrc + col + 0) * ldsrc]; + dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 2) * lddst] = src[offsetsrc + 
(rowssrc + row + 2) + (colssrc + col + 0) * ldsrc]; + } + for (; row < m; row += 1) { + dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 0) * ldsrc]; + } + } + } + + protected void dgecpyTN(int m, int n, double[] src, int offsetsrc, int ldsrc, int rowssrc, int colssrc, double[] dst, int offsetdst, int lddst, int rowsdst, int colsdst) { + int row = 0; + for (; row < loopBound(m, 3); row += 3) { + int col = 0; + for (; col < loopBound(n, 3); col += 3) { + dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 0) * ldsrc]; + dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 1) * lddst] = src[offsetsrc + (colssrc + col + 1) + (rowssrc + row + 0) * ldsrc]; + dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 2) * lddst] = src[offsetsrc + (colssrc + col + 2) + (rowssrc + row + 0) * ldsrc]; + dst[offsetdst + (rowsdst + row + 1) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 1) * ldsrc]; + dst[offsetdst + (rowsdst + row + 1) + (colsdst + col + 1) * lddst] = src[offsetsrc + (colssrc + col + 1) + (rowssrc + row + 1) * ldsrc]; + dst[offsetdst + (rowsdst + row + 1) + (colsdst + col + 2) * lddst] = src[offsetsrc + (colssrc + col + 2) + (rowssrc + row + 1) * ldsrc]; + dst[offsetdst + (rowsdst + row + 2) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 2) * ldsrc]; + dst[offsetdst + (rowsdst + row + 2) + (colsdst + col + 1) * lddst] = src[offsetsrc + (colssrc + col + 1) + (rowssrc + row + 2) * ldsrc]; + dst[offsetdst + (rowsdst + row + 2) + (colsdst + col + 2) * lddst] = src[offsetsrc + (colssrc + col + 2) + (rowssrc + row + 2) * ldsrc]; + } + for (; col < n; col += 1) { + dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 0) * ldsrc]; + dst[offsetdst + (rowsdst + row + 1) + 
(colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 1) * ldsrc]; + dst[offsetdst + (rowsdst + row + 2) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 2) * ldsrc]; + } + } + for (; row < m; row += 1) { + int col = 0; + for (; col < loopBound(n, 3); col += 3) { + dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 0) * ldsrc]; + dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 1) * lddst] = src[offsetsrc + (colssrc + col + 1) + (rowssrc + row + 0) * ldsrc]; + dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 2) * lddst] = src[offsetsrc + (colssrc + col + 2) + (rowssrc + row + 0) * ldsrc]; + } + for (; col < n; col += 1) { + dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 0) * ldsrc]; + } + } + } + + protected void dgecpyTT(int m, int n, double[] src, int offsetsrc, int ldsrc, int rowssrc, int colssrc, double[] dst, int offsetdst, int lddst, int rowsdst, int colsdst) { + int row = 0; + for (; row < loopBound(m, 4); row += 4) { + System.arraycopy(src, offsetsrc + colssrc + (rowssrc + row + 0) * ldsrc, dst, offsetdst + colsdst + (rowsdst + row + 0) * lddst, n); + System.arraycopy(src, offsetsrc + colssrc + (rowssrc + row + 1) * ldsrc, dst, offsetdst + colsdst + (rowsdst + row + 1) * lddst, n); + System.arraycopy(src, offsetsrc + colssrc + (rowssrc + row + 2) * ldsrc, dst, offsetdst + colsdst + (rowsdst + row + 2) * lddst, n); + System.arraycopy(src, offsetsrc + colssrc + (rowssrc + row + 3) * ldsrc, dst, offsetdst + colsdst + (rowsdst + row + 3) * lddst, n); + } + for (; row < m; row += 1) { + System.arraycopy(src, offsetsrc + colssrc + (rowssrc + row + 0) * ldsrc, dst, offsetdst + colsdst + (rowsdst + row + 0) * lddst, n); + } + } + + protected void dgebpTN(int m, int rows, int rowe, int n, int cols, int cole, int k, int is, int ie, double 
alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) { + final int Tcol = 3, Trow = 3; + + int col = cols; + for (; col < loopAlign(cols, cole, Tcol); col += 1) { + int row = rows; + for (; row < loopAlign(rows, rowe, Trow); row += 1) { + double sum00 = 0.0; + for (int i = is; i < ie; i += 1) { + double a0 = a[offseta + i + (row + 0) * lda]; + double b0 = b[offsetb + i + (col + 0) * ldb]; + sum00 = a0 * b0 + sum00; + } + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc]; + } + for (; row < loopBound(rowe, Trow); row += Trow) { + double sum00 = 0.0; + double sum10 = 0.0; + double sum20 = 0.0; + for (int i = is; i < ie; i += 1) { + double a0 = a[offseta + i + (row + 0) * lda]; + double a1 = a[offseta + i + (row + 1) * lda]; + double a2 = a[offseta + i + (row + 2) * lda]; + double b0 = b[offsetb + i + (col + 0) * ldb]; + sum00 = a0 * b0 + sum00; + sum10 = a1 * b0 + sum10; + sum20 = a2 * b0 + sum20; + } + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc]; + c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + c[offsetc + (row + 1) + (col + 0) * ldc]; + c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + c[offsetc + (row + 2) + (col + 0) * ldc]; + } + for (; row < rowe; row += 1) { + double sum00 = 0.0; + for (int i = is; i < ie; i += 1) { + double a0 = a[offseta + i + (row + 0) * lda]; + double b0 = b[offsetb + i + (col + 0) * ldb]; + sum00 = a0 * b0 + sum00; + } + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc]; + } + } + for (; col < loopBound(cole, Tcol); col += Tcol) { + int row = rows; + for (; row < loopAlign(rows, rowe, Trow); row += 1) { + double sum00 = 0.0; + double sum01 = 0.0; + double sum02 = 0.0; + double sum03 = 0.0; + for (int i = is; i < ie; i += 1) { + double a0 = a[offseta + i + (row + 0) * lda]; + double b0 = b[offsetb 
+ i + (col + 0) * ldb]; + double b1 = b[offsetb + i + (col + 1) * ldb]; + double b2 = b[offsetb + i + (col + 2) * ldb]; + sum00 = a0 * b0 + sum00; + sum01 = a0 * b1 + sum01; + sum02 = a0 * b2 + sum02; + } + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc]; + c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + c[offsetc + (row + 0) + (col + 1) * ldc]; + c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + c[offsetc + (row + 0) + (col + 2) * ldc]; + } + for (; row < loopBound(rowe, Trow); row += Trow) { + dgepdotTN(m, row, row + Trow, n, col, col + Tcol, k, is, ie, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc); + } + for (; row < rowe; row += 1) { + double sum00 = 0.0; + double sum01 = 0.0; + double sum02 = 0.0; + for (int i = is; i < ie; i += 1) { + double a0 = a[offseta + i + (row + 0) * lda]; + double b0 = b[offsetb + i + (col + 0) * ldb]; + double b1 = b[offsetb + i + (col + 1) * ldb]; + double b2 = b[offsetb + i + (col + 2) * ldb]; + sum00 = a0 * b0 + sum00; + sum01 = a0 * b1 + sum01; + sum02 = a0 * b2 + sum02; + } + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc]; + c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + c[offsetc + (row + 0) + (col + 1) * ldc]; + c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + c[offsetc + (row + 0) + (col + 2) * ldc]; + } + } + for (; col < cole; col += 1) { + int row = rows; + for (; row < loopAlign(rows, rowe, Trow); row += 1) { + double sum00 = 0.0; + for (int i = is; i < ie; i += 1) { + double a0 = a[offseta + i + (row + 0) * lda]; + double b0 = b[offsetb + i + (col + 0) * ldb]; + sum00 = a0 * b0 + sum00; + } + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc]; + } + for (; row < loopBound(rowe, Trow); row += Trow) { + double sum00 = 0.0; + double sum10 = 0.0; + double sum20 = 0.0; + for (int i = is; i < ie; i += 1) { + double 
a0 = a[offseta + i + (row + 0) * lda]; + double a1 = a[offseta + i + (row + 1) * lda]; + double a2 = a[offseta + i + (row + 2) * lda]; + double b0 = b[offsetb + i + (col + 0) * ldb]; + sum00 = a0 * b0 + sum00; + sum10 = a1 * b0 + sum10; + sum20 = a2 * b0 + sum20; + } + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc]; + c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + c[offsetc + (row + 1) + (col + 0) * ldc]; + c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + c[offsetc + (row + 2) + (col + 0) * ldc]; + } + for (; row < rowe; row += 1) { + double sum00 = 0.0; + for (int i = is; i < ie; i += 1) { + double a0 = a[offseta + i + (row + 0) * lda]; + double b0 = b[offsetb + i + (col + 0) * ldb]; + sum00 = a0 * b0 + sum00; + } + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc]; + } + } + } + + protected void dgepdotTN(int m, int rows, int rowe, int n, int cols, int cole, int k, int is, int ie, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) { + final int Ti = 2; + + assert rowe - rows == 3; + assert cole - cols == 3; + + int row = rows; + int col = cols; + int i = is; + double sum00 = 0.0; + double sum01 = 0.0; + double sum02 = 0.0; + double sum10 = 0.0; + double sum11 = 0.0; + double sum12 = 0.0; + double sum20 = 0.0; + double sum21 = 0.0; + double sum22 = 0.0; + for (; i < loopAlign(is, ie, Ti); i += 1) { + double a0 = a[offseta + i + (row + 0) * lda]; + double a1 = a[offseta + i + (row + 1) * lda]; + double a2 = a[offseta + i + (row + 2) * lda]; + double b0 = b[offsetb + i + (col + 0) * ldb]; + sum00 = a0 * b0 + sum00; + sum10 = a1 * b0 + sum10; + sum20 = a2 * b0 + sum20; + double b1 = b[offsetb + i + (col + 1) * ldb]; + sum01 = a0 * b1 + sum01; + sum11 = a1 * b1 + sum11; + sum21 = a2 * b1 + sum21; + double b2 = b[offsetb + i + (col + 2) * ldb]; + sum02 = a0 * b2 + 
sum02; + sum12 = a1 * b2 + sum12; + sum22 = a2 * b2 + sum22; + } + for (; i < loopBound(ie, Ti); i += Ti) { + double a00 = a[offseta + (i + 0) + (row + 0) * lda]; + double a01 = a[offseta + (i + 0) + (row + 1) * lda]; + double a02 = a[offseta + (i + 0) + (row + 2) * lda]; + double b00 = b[offsetb + (i + 0) + (col + 0) * ldb]; + sum00 = a00 * b00 + sum00; + sum10 = a01 * b00 + sum10; + sum20 = a02 * b00 + sum20; + double b01 = b[offsetb + (i + 0) + (col + 1) * ldb]; + sum01 = a00 * b01 + sum01; + sum11 = a01 * b01 + sum11; + sum21 = a02 * b01 + sum21; + double b02 = b[offsetb + (i + 0) + (col + 2) * ldb]; + sum02 = a00 * b02 + sum02; + sum12 = a01 * b02 + sum12; + sum22 = a02 * b02 + sum22; + double a10 = a[offseta + (i + 1) + (row + 0) * lda]; + double a11 = a[offseta + (i + 1) + (row + 1) * lda]; + double a12 = a[offseta + (i + 1) + (row + 2) * lda]; + double b10 = b[offsetb + (i + 1) + (col + 0) * ldb]; + sum00 = a10 * b10 + sum00; + sum10 = a11 * b10 + sum10; + sum20 = a12 * b10 + sum20; + double b11 = b[offsetb + (i + 1) + (col + 1) * ldb]; + sum01 = a10 * b11 + sum01; + sum11 = a11 * b11 + sum11; + sum21 = a12 * b11 + sum21; + double b12 = b[offsetb + (i + 1) + (col + 2) * ldb]; + sum02 = a10 * b12 + sum02; + sum12 = a11 * b12 + sum12; + sum22 = a12 * b12 + sum22; + } + for (; i < ie; i += 1) { + double a0 = a[offseta + i + (row + 0) * lda]; + double a1 = a[offseta + i + (row + 1) * lda]; + double a2 = a[offseta + i + (row + 2) * lda]; + double b0 = b[offsetb + i + (col + 0) * ldb]; + sum00 = a0 * b0 + sum00; + sum10 = a1 * b0 + sum10; + sum20 = a2 * b0 + sum20; + double b1 = b[offsetb + i + (col + 1) * ldb]; + sum01 = a0 * b1 + sum01; + sum11 = a1 * b1 + sum11; + sum21 = a2 * b1 + sum21; + double b2 = b[offsetb + i + (col + 2) * ldb]; + sum02 = a0 * b2 + sum02; + sum12 = a1 * b2 + sum12; + sum22 = a2 * b2 + sum22; + } + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc]; + c[offsetc + (row + 0) + (col + 1) * 
ldc] = alpha * sum01 + c[offsetc + (row + 0) + (col + 1) * ldc]; + c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + c[offsetc + (row + 0) + (col + 2) * ldc]; + c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + c[offsetc + (row + 1) + (col + 0) * ldc]; + c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + c[offsetc + (row + 1) + (col + 1) * ldc]; + c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + c[offsetc + (row + 1) + (col + 2) * ldc]; + c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + c[offsetc + (row + 2) + (col + 0) * ldc]; + c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + c[offsetc + (row + 2) + (col + 1) * ldc]; + c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + c[offsetc + (row + 2) + (col + 2) * ldc]; + } + + protected void dgemmNN(int m, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) { + final int Trow = 3, Tcol = 3, Ti = 2; + + int col = 0; + for (; col < loopBound(n, Tcol); col += Tcol) { + int row = 0; + for (; row < loopBound(m, Trow); row += Trow) { + int i = 0; + double sum00 = 0.0; + double sum01 = 0.0; + double sum02 = 0.0; + double sum10 = 0.0; + double sum11 = 0.0; + double sum12 = 0.0; + double sum20 = 0.0; + double sum21 = 0.0; + double sum22 = 0.0; + for (; i < loopBound(k, Ti); i += Ti) { + double a00 = a[offseta + (row + 0) + (i + 0) * lda]; + double a10 = a[offseta + (row + 1) + (i + 0) * lda]; + double a20 = a[offseta + (row + 2) + (i + 0) * lda]; + double b00 = b[offsetb + (i + 0) + (col + 0) * ldb]; + double b01 = b[offsetb + (i + 0) + (col + 1) * ldb]; + double b02 = b[offsetb + (i + 0) + (col + 2) * ldb]; + sum00 = a00 * b00 + sum00; + sum01 = a00 * b01 + sum01; + sum02 = a00 * b02 + sum02; + sum10 = a10 * b00 + sum10; + sum11 = a10 * b01 + sum11; + sum12 = a10 * b02 + sum12; + sum20 = a20 * b00 + sum20; + sum21 = a20 * b01 + sum21; + sum22 = a20 * b02 + sum22; + double a01 = 
a[offseta + (row + 0) + (i + 1) * lda]; + double a11 = a[offseta + (row + 1) + (i + 1) * lda]; + double a21 = a[offseta + (row + 2) + (i + 1) * lda]; + double b10 = b[offsetb + (i + 1) + (col + 0) * ldb]; + double b11 = b[offsetb + (i + 1) + (col + 1) * ldb]; + double b12 = b[offsetb + (i + 1) + (col + 2) * ldb]; + sum00 = a01 * b10 + sum00; + sum01 = a01 * b11 + sum01; + sum02 = a01 * b12 + sum02; + sum10 = a11 * b10 + sum10; + sum11 = a11 * b11 + sum11; + sum12 = a11 * b12 + sum12; + sum20 = a21 * b10 + sum20; + sum21 = a21 * b11 + sum21; + sum22 = a21 * b12 + sum22; + } + for (; i < k; i += 1) { + double a00 = a[offseta + (row + 0) + (i + 0) * lda]; + double a10 = a[offseta + (row + 1) + (i + 0) * lda]; + double a20 = a[offseta + (row + 2) + (i + 0) * lda]; + double b00 = b[offsetb + (i + 0) + (col + 0) * ldb]; + double b01 = b[offsetb + (i + 0) + (col + 1) * ldb]; + double b02 = b[offsetb + (i + 0) + (col + 2) * ldb]; + sum00 = a00 * b00 + sum00; + sum01 = a00 * b01 + sum01; + sum02 = a00 * b02 + sum02; + sum10 = a10 * b00 + sum10; + sum11 = a10 * b01 + sum11; + sum12 = a10 * b02 + sum12; + sum20 = a20 * b00 + sum20; + sum21 = a20 * b01 + sum21; + sum22 = a20 * b02 + sum22; + } + if (beta != 0.0) { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc]; + c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc]; + c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc]; + c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc]; + c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + beta * c[offsetc + (row + 1) + (col + 1) * ldc]; + c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + beta * c[offsetc + (row + 1) + (col + 2) * ldc]; + c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc]; + 
c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + beta * c[offsetc + (row + 2) + (col + 1) * ldc]; + c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + beta * c[offsetc + (row + 2) + (col + 2) * ldc]; + } else { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00; + c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01; + c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02; + c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10; + c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11; + c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12; + c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20; + c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21; + c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22; + } + } + for (; row < m; row += 1) { + int i = 0; + double sum00 = 0.0; + double sum01 = 0.0; + double sum02 = 0.0; + for (; i < loopBound(k, Ti); i += Ti) { + double a00 = a[offseta + (row + 0) + (i + 0) * lda]; + double b00 = b[offsetb + (i + 0) + (col + 0) * ldb]; + double b01 = b[offsetb + (i + 0) + (col + 1) * ldb]; + double b02 = b[offsetb + (i + 0) + (col + 2) * ldb]; + sum00 = a00 * b00 + sum00; + sum01 = a00 * b01 + sum01; + sum02 = a00 * b02 + sum02; + double a01 = a[offseta + (row + 0) + (i + 1) * lda]; + double b10 = b[offsetb + (i + 1) + (col + 0) * ldb]; + double b11 = b[offsetb + (i + 1) + (col + 1) * ldb]; + double b12 = b[offsetb + (i + 1) + (col + 2) * ldb]; + sum00 = a01 * b10 + sum00; + sum01 = a01 * b11 + sum01; + sum02 = a01 * b12 + sum02; + } + for (; i < k; i += 1) { + double a00 = a[offseta + (row + 0) + (i + 0) * lda]; + double b00 = b[offsetb + (i + 0) + (col + 0) * ldb]; + double b01 = b[offsetb + (i + 0) + (col + 1) * ldb]; + double b02 = b[offsetb + (i + 0) + (col + 2) * ldb]; + sum00 = a00 * b00 + sum00; + sum01 = a00 * b01 + sum01; + sum02 = a00 * b02 + sum02; + } + if (beta != 0.0) { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + 
(col + 0) * ldc]; + c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc]; + c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc]; + } else { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00; + c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01; + c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02; + } + } + } + for (; col < n; col += 1) { + int row = 0; + for (; row < loopBound(m, Trow); row += Trow) { + int i = 0; + double sum00 = 0.0; + double sum10 = 0.0; + double sum20 = 0.0; + for (; i < loopBound(k, Ti); i += Ti) { + double a00 = a[offseta + (row + 0) + (i + 0) * lda]; + double a10 = a[offseta + (row + 1) + (i + 0) * lda]; + double a20 = a[offseta + (row + 2) + (i + 0) * lda]; + double b00 = b[offsetb + (i + 0) + (col + 0) * ldb]; + sum00 = a00 * b00 + sum00; + sum10 = a10 * b00 + sum10; + sum20 = a20 * b00 + sum20; + double a01 = a[offseta + (row + 0) + (i + 1) * lda]; + double a11 = a[offseta + (row + 1) + (i + 1) * lda]; + double a21 = a[offseta + (row + 2) + (i + 1) * lda]; + double b10 = b[offsetb + (i + 1) + (col + 0) * ldb]; + sum00 = a01 * b10 + sum00; + sum10 = a11 * b10 + sum10; + sum20 = a21 * b10 + sum20; + } + for (; i < k; i += 1) { + double a00 = a[offseta + (row + 0) + (i + 0) * lda]; + double a10 = a[offseta + (row + 1) + (i + 0) * lda]; + double a20 = a[offseta + (row + 2) + (i + 0) * lda]; + double b00 = b[offsetb + (i + 0) + (col + 0) * ldb]; + sum00 = a00 * b00 + sum00; + sum10 = a10 * b00 + sum10; + sum20 = a20 * b00 + sum20; + } + if (beta != 0.0) { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc]; + c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc]; + c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc]; + } else { + c[offsetc + (row + 0) + 
(col + 0) * ldc] = alpha * sum00; + c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10; + c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20; + } + } + for (; row < m; row += 1) { + int i = 0; + double sum00 = 0.0; + for (; i < loopBound(k, Ti); i += Ti) { + double a00 = a[offseta + (row + 0) + (i + 0) * lda]; + double b00 = b[offsetb + (i + 0) + (col + 0) * ldb]; + sum00 = a00 * b00 + sum00; + double a01 = a[offseta + (row + 0) + (i + 1) * lda]; + double b10 = b[offsetb + (i + 1) + (col + 0) * ldb]; + sum00 = a01 * b10 + sum00; + } + for (; i < k; i += 1) { + double a00 = a[offseta + (row + 0) + (i + 0) * lda]; + double b00 = b[offsetb + (i + 0) + (col + 0) * ldb]; + sum00 = a00 * b00 + sum00; + } + if (beta != 0.0) { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc]; + } else { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00; + } + } + } + } + + protected void dgemmNT(int m, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) { + final int Trow = 3, Tcol = 3, Ti = 2; + int col = 0; + for (; col < loopBound(n, Tcol); col += Tcol) { + int row = 0; + for (; row < loopBound(m, Trow); row += Trow) { + int i = 0; + double sum00 = 0.0; + double sum01 = 0.0; + double sum02 = 0.0; + double sum10 = 0.0; + double sum11 = 0.0; + double sum12 = 0.0; + double sum20 = 0.0; + double sum21 = 0.0; + double sum22 = 0.0; + for (; i < loopBound(k, Ti); i += Ti) { + double a00 = a[offseta + (row + 0) + (i + 0) * lda]; + double a10 = a[offseta + (row + 1) + (i + 0) * lda]; + double a20 = a[offseta + (row + 2) + (i + 0) * lda]; + double b00 = b[offsetb + (col + 0) + (i + 0) * ldb]; + double b01 = b[offsetb + (col + 1) + (i + 0) * ldb]; + double b02 = b[offsetb + (col + 2) + (i + 0) * ldb]; + sum00 = a00 * b00 + sum00; + sum01 = a00 * b01 + sum01; + sum02 = a00 * b02 + sum02; + sum10 = a10 * b00 + sum10; + sum11 = a10 
* b01 + sum11; + sum12 = a10 * b02 + sum12; + sum20 = a20 * b00 + sum20; + sum21 = a20 * b01 + sum21; + sum22 = a20 * b02 + sum22; + double a01 = a[offseta + (row + 0) + (i + 1) * lda]; + double a11 = a[offseta + (row + 1) + (i + 1) * lda]; + double a21 = a[offseta + (row + 2) + (i + 1) * lda]; + double b10 = b[offsetb + (col + 0) + (i + 1) * ldb]; + double b11 = b[offsetb + (col + 1) + (i + 1) * ldb]; + double b12 = b[offsetb + (col + 2) + (i + 1) * ldb]; + sum00 = a01 * b10 + sum00; + sum01 = a01 * b11 + sum01; + sum02 = a01 * b12 + sum02; + sum10 = a11 * b10 + sum10; + sum11 = a11 * b11 + sum11; + sum12 = a11 * b12 + sum12; + sum20 = a21 * b10 + sum20; + sum21 = a21 * b11 + sum21; + sum22 = a21 * b12 + sum22; + } + for (; i < k; i += 1) { + double a00 = a[offseta + (row + 0) + (i + 0) * lda]; + double a10 = a[offseta + (row + 1) + (i + 0) * lda]; + double a20 = a[offseta + (row + 2) + (i + 0) * lda]; + double b00 = b[offsetb + (col + 0) + (i + 0) * ldb]; + double b01 = b[offsetb + (col + 1) + (i + 0) * ldb]; + double b02 = b[offsetb + (col + 2) + (i + 0) * ldb]; + sum00 = a00 * b00 + sum00; + sum01 = a00 * b01 + sum01; + sum02 = a00 * b02 + sum02; + sum10 = a10 * b00 + sum10; + sum11 = a10 * b01 + sum11; + sum12 = a10 * b02 + sum12; + sum20 = a20 * b00 + sum20; + sum21 = a20 * b01 + sum21; + sum22 = a20 * b02 + sum22; + } + if (beta != 0.0) { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc]; + c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc]; + c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc]; + c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc]; + c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + beta * c[offsetc + (row + 1) + (col + 1) * ldc]; + c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + beta * c[offsetc + 
(row + 1) + (col + 2) * ldc]; + c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc]; + c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + beta * c[offsetc + (row + 2) + (col + 1) * ldc]; + c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + beta * c[offsetc + (row + 2) + (col + 2) * ldc]; + } else { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00; + c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01; + c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02; + c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10; + c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11; + c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12; + c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20; + c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21; + c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22; + } + } + for (; row < m; row += 1) { + int i = 0; + double sum00 = 0.0; + double sum01 = 0.0; + double sum02 = 0.0; + for (; i < loopBound(k, Ti); i += Ti) { + double a00 = a[offseta + (row + 0) + (i + 0) * lda]; + double b00 = b[offsetb + (col + 0) + (i + 0) * ldb]; + double b01 = b[offsetb + (col + 1) + (i + 0) * ldb]; + double b02 = b[offsetb + (col + 2) + (i + 0) * ldb]; + sum00 = a00 * b00 + sum00; + sum01 = a00 * b01 + sum01; + sum02 = a00 * b02 + sum02; + double a01 = a[offseta + (row + 0) + (i + 1) * lda]; + double b10 = b[offsetb + (col + 0) + (i + 1) * ldb]; + double b11 = b[offsetb + (col + 1) + (i + 1) * ldb]; + double b12 = b[offsetb + (col + 2) + (i + 1) * ldb]; + sum00 = a01 * b10 + sum00; + sum01 = a01 * b11 + sum01; + sum02 = a01 * b12 + sum02; + } + for (; i < k; i += 1) { + double a00 = a[offseta + (row + 0) + (i + 0) * lda]; + double b00 = b[offsetb + (col + 0) + (i + 0) * ldb]; + double b01 = b[offsetb + (col + 1) + (i + 0) * ldb]; + double b02 = b[offsetb + (col + 2) + (i + 0) * ldb]; + sum00 = a00 * b00 + sum00; + sum01 = a00 * b01 + sum01; + sum02 
= a00 * b02 + sum02; + } + if (beta != 0.0) { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc]; + c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc]; + c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc]; + } else { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00; + c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01; + c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02; + } + } + } + for (; col < n; col += 1) { + int row = 0; + for (; row < loopBound(m, Trow); row += Trow) { + int i = 0; + double sum00 = 0.0; + double sum10 = 0.0; + double sum20 = 0.0; + for (; i < loopBound(k, Ti); i += Ti) { + double a00 = a[offseta + (row + 0) + (i + 0) * lda]; + double a10 = a[offseta + (row + 1) + (i + 0) * lda]; + double a20 = a[offseta + (row + 2) + (i + 0) * lda]; + double b00 = b[offsetb + (col + 0) + (i + 0) * ldb]; + sum00 = a00 * b00 + sum00; + sum10 = a10 * b00 + sum10; + sum20 = a20 * b00 + sum20; + double a01 = a[offseta + (row + 0) + (i + 1) * lda]; + double a11 = a[offseta + (row + 1) + (i + 1) * lda]; + double a21 = a[offseta + (row + 2) + (i + 1) * lda]; + double b10 = b[offsetb + (col + 0) + (i + 1) * ldb]; + sum00 = a01 * b10 + sum00; + sum10 = a11 * b10 + sum10; + sum20 = a21 * b10 + sum20; + } + for (; i < k; i += 1) { + double a00 = a[offseta + (row + 0) + (i + 0) * lda]; + double a10 = a[offseta + (row + 1) + (i + 0) * lda]; + double a20 = a[offseta + (row + 2) + (i + 0) * lda]; + double b00 = b[offsetb + (col + 0) + (i + 0) * ldb]; + sum00 = a00 * b00 + sum00; + sum10 = a10 * b00 + sum10; + sum20 = a20 * b00 + sum20; + } + if (beta != 0.0) { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc]; + c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc]; + 
c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc]; + } else { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00; + c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10; + c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20; + } + } + for (; row < m; row += 1) { + int i = 0; + double sum00 = 0.0; + for (; i < loopBound(k, Ti); i += Ti) { + double a00 = a[offseta + (row + 0) + (i + 0) * lda]; + double b00 = b[offsetb + (col + 0) + (i + 0) * ldb]; + sum00 = a00 * b00 + sum00; + double a01 = a[offseta + (row + 0) + (i + 1) * lda]; + double b10 = b[offsetb + (col + 0) + (i + 1) * ldb]; + sum00 = a01 * b10 + sum00; + } + for (; i < k; i += 1) { + double a00 = a[offseta + (row + 0) + (i + 0) * lda]; + double b00 = b[offsetb + (col + 0) + (i + 0) * ldb]; + sum00 = a00 * b00 + sum00; + } + if (beta != 0.0) { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc]; + } else { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00; + } + } + } + } + + protected void dgemmTN(int m, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) { + final int Trow = 3, Tcol = 3, Ti = 2; + + int col = 0; + for (; col < loopBound(n, Tcol); col += Tcol) { + int row = 0; + for (; row < loopBound(m, Trow); row += Trow) { + int i = 0; + double sum00 = 0.0; + double sum01 = 0.0; + double sum02 = 0.0; + double sum10 = 0.0; + double sum11 = 0.0; + double sum12 = 0.0; + double sum20 = 0.0; + double sum21 = 0.0; + double sum22 = 0.0; + for (; i < loopBound(k, Ti); i += Ti) { + double a00 = a[offseta + (i + 0) + (row + 0) * lda]; + double a10 = a[offseta + (i + 0) + (row + 1) * lda]; + double a20 = a[offseta + (i + 0) + (row + 2) * lda]; + double b00 = b[offsetb + (i + 0) + (col + 0) * ldb]; + double b01 = b[offsetb + (i + 0) + (col + 1) * ldb]; + double b02 = b[offsetb + (i + 0) + 
(col + 2) * ldb]; + sum00 = a00 * b00 + sum00; + sum01 = a00 * b01 + sum01; + sum02 = a00 * b02 + sum02; + sum10 = a10 * b00 + sum10; + sum11 = a10 * b01 + sum11; + sum12 = a10 * b02 + sum12; + sum20 = a20 * b00 + sum20; + sum21 = a20 * b01 + sum21; + sum22 = a20 * b02 + sum22; + double a01 = a[offseta + (i + 1) + (row + 0) * lda]; + double a11 = a[offseta + (i + 1) + (row + 1) * lda]; + double a21 = a[offseta + (i + 1) + (row + 2) * lda]; + double b10 = b[offsetb + (i + 1) + (col + 0) * ldb]; + double b11 = b[offsetb + (i + 1) + (col + 1) * ldb]; + double b12 = b[offsetb + (i + 1) + (col + 2) * ldb]; + sum00 = a01 * b10 + sum00; + sum01 = a01 * b11 + sum01; + sum02 = a01 * b12 + sum02; + sum10 = a11 * b10 + sum10; + sum11 = a11 * b11 + sum11; + sum12 = a11 * b12 + sum12; + sum20 = a21 * b10 + sum20; + sum21 = a21 * b11 + sum21; + sum22 = a21 * b12 + sum22; + } + for (; i < k; i += 1) { + double a00 = a[offseta + (i + 0) + (row + 0) * lda]; + double a10 = a[offseta + (i + 0) + (row + 1) * lda]; + double a20 = a[offseta + (i + 0) + (row + 2) * lda]; + double b00 = b[offsetb + (i + 0) + (col + 0) * ldb]; + double b01 = b[offsetb + (i + 0) + (col + 1) * ldb]; + double b02 = b[offsetb + (i + 0) + (col + 2) * ldb]; + sum00 = a00 * b00 + sum00; + sum01 = a00 * b01 + sum01; + sum02 = a00 * b02 + sum02; + sum10 = a10 * b00 + sum10; + sum11 = a10 * b01 + sum11; + sum12 = a10 * b02 + sum12; + sum20 = a20 * b00 + sum20; + sum21 = a20 * b01 + sum21; + sum22 = a20 * b02 + sum22; + } + if (beta != 0.0) { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc]; + c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc]; + c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc]; + c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc]; + c[offsetc + (row + 1) + (col + 1) * ldc] 
= alpha * sum11 + beta * c[offsetc + (row + 1) + (col + 1) * ldc]; + c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + beta * c[offsetc + (row + 1) + (col + 2) * ldc]; + c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc]; + c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + beta * c[offsetc + (row + 2) + (col + 1) * ldc]; + c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + beta * c[offsetc + (row + 2) + (col + 2) * ldc]; + } else { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00; + c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01; + c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02; + c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10; + c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11; + c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12; + c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20; + c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21; + c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22; + } + } + for (; row < m; row += 1) { + int i = 0; + double sum00 = 0.0; + double sum01 = 0.0; + double sum02 = 0.0; + for (; i < loopBound(k, Ti); i += Ti) { + double a00 = a[offseta + (i + 0) + (row + 0) * lda]; + double b00 = b[offsetb + (i + 0) + (col + 0) * ldb]; + double b01 = b[offsetb + (i + 0) + (col + 1) * ldb]; + double b02 = b[offsetb + (i + 0) + (col + 2) * ldb]; + sum00 = a00 * b00 + sum00; + sum01 = a00 * b01 + sum01; + sum02 = a00 * b02 + sum02; + double a01 = a[offseta + (i + 1) + (row + 0) * lda]; + double b10 = b[offsetb + (i + 1) + (col + 0) * ldb]; + double b11 = b[offsetb + (i + 1) + (col + 1) * ldb]; + double b12 = b[offsetb + (i + 1) + (col + 2) * ldb]; + sum00 = a01 * b10 + sum00; + sum01 = a01 * b11 + sum01; + sum02 = a01 * b12 + sum02; + } + for (; i < k; i += 1) { + double a00 = a[offseta + (i + 0) + (row + 0) * lda]; + double b00 = b[offsetb + (i + 0) + (col + 0) * ldb]; + double b01 = b[offsetb + 
(i + 0) + (col + 1) * ldb]; + double b02 = b[offsetb + (i + 0) + (col + 2) * ldb]; + sum00 = a00 * b00 + sum00; + sum01 = a00 * b01 + sum01; + sum02 = a00 * b02 + sum02; + } + if (beta != 0.0) { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc]; + c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc]; + c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc]; + } else { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00; + c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01; + c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02; + } + } + } + for (; col < n; col += 1) { + int row = 0; + for (; row < loopBound(m, Trow); row += Trow) { + int i = 0; + double sum00 = 0.0; + double sum10 = 0.0; + double sum20 = 0.0; + for (; i < loopBound(k, Ti); i += Ti) { + double a00 = a[offseta + (i + 0) + (row + 0) * lda]; + double a10 = a[offseta + (i + 0) + (row + 1) * lda]; + double a20 = a[offseta + (i + 0) + (row + 2) * lda]; + double b00 = b[offsetb + (i + 0) + (col + 0) * ldb]; + sum00 = a00 * b00 + sum00; + sum10 = a10 * b00 + sum10; + sum20 = a20 * b00 + sum20; + double a01 = a[offseta + (i + 1) + (row + 0) * lda]; + double a11 = a[offseta + (i + 1) + (row + 1) * lda]; + double a21 = a[offseta + (i + 1) + (row + 2) * lda]; + double b10 = b[offsetb + (i + 1) + (col + 0) * ldb]; + sum00 = a01 * b10 + sum00; + sum10 = a11 * b10 + sum10; + sum20 = a21 * b10 + sum20; + } + for (; i < k; i += 1) { + double a00 = a[offseta + (i + 0) + (row + 0) * lda]; + double a10 = a[offseta + (i + 0) + (row + 1) * lda]; + double a20 = a[offseta + (i + 0) + (row + 2) * lda]; + double b00 = b[offsetb + (i + 0) + (col + 0) * ldb]; + sum00 = a00 * b00 + sum00; + sum10 = a10 * b00 + sum10; + sum20 = a20 * b00 + sum20; + } + if (beta != 0.0) { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + 
(row + 0) + (col + 0) * ldc]; + c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc]; + c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc]; + } else { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00; + c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10; + c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20; + } + } + for (; row < m; row += 1) { + int i = 0; + double sum00 = 0.0; + for (; i < loopBound(k, Ti); i += Ti) { + double a00 = a[offseta + (i + 0) + (row + 0) * lda]; + double b00 = b[offsetb + (i + 0) + (col + 0) * ldb]; + sum00 = a00 * b00 + sum00; + double a01 = a[offseta + (i + 1) + (row + 0) * lda]; + double b10 = b[offsetb + (i + 1) + (col + 0) * ldb]; + sum00 = a01 * b10 + sum00; + } + for (; i < k; i += 1) { + double a00 = a[offseta + (i + 0) + (row + 0) * lda]; + double b00 = b[offsetb + (i + 0) + (col + 0) * ldb]; + sum00 = a00 * b00 + sum00; + } + if (beta != 0.0) { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc]; + } else { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00; + } + } + } + } + + protected void dgemmTT(int m, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) { + final int Trow = 3, Tcol = 3, Ti = 2; + + int col = 0; + for (; col < loopBound(n, Tcol); col += Tcol) { + int row = 0; + for (; row < loopBound(m, Trow); row += Trow) { + int i = 0; + double sum00 = 0.0; + double sum01 = 0.0; + double sum02 = 0.0; + double sum10 = 0.0; + double sum11 = 0.0; + double sum12 = 0.0; + double sum20 = 0.0; + double sum21 = 0.0; + double sum22 = 0.0; + for (; i < loopBound(k, Ti); i += Ti) { + double a00 = a[offseta + (i + 0) + (row + 0) * lda]; + double a10 = a[offseta + (i + 0) + (row + 1) * lda]; + double a20 = a[offseta + (i + 0) + (row + 2) * lda]; + 
double b00 = b[offsetb + (col + 0) + (i + 0) * ldb]; + double b01 = b[offsetb + (col + 1) + (i + 0) * ldb]; + double b02 = b[offsetb + (col + 2) + (i + 0) * ldb]; + sum00 = a00 * b00 + sum00; + sum01 = a00 * b01 + sum01; + sum02 = a00 * b02 + sum02; + sum10 = a10 * b00 + sum10; + sum11 = a10 * b01 + sum11; + sum12 = a10 * b02 + sum12; + sum20 = a20 * b00 + sum20; + sum21 = a20 * b01 + sum21; + sum22 = a20 * b02 + sum22; + double a01 = a[offseta + (i + 1) + (row + 0) * lda]; + double a11 = a[offseta + (i + 1) + (row + 1) * lda]; + double a21 = a[offseta + (i + 1) + (row + 2) * lda]; + double b10 = b[offsetb + (col + 0) + (i + 1) * ldb]; + double b11 = b[offsetb + (col + 1) + (i + 1) * ldb]; + double b12 = b[offsetb + (col + 2) + (i + 1) * ldb]; + sum00 = a01 * b10 + sum00; + sum01 = a01 * b11 + sum01; + sum02 = a01 * b12 + sum02; + sum10 = a11 * b10 + sum10; + sum11 = a11 * b11 + sum11; + sum12 = a11 * b12 + sum12; + sum20 = a21 * b10 + sum20; + sum21 = a21 * b11 + sum21; + sum22 = a21 * b12 + sum22; + } + for (; i < k; i += 1) { + double a00 = a[offseta + (i + 0) + (row + 0) * lda]; + double a10 = a[offseta + (i + 0) + (row + 1) * lda]; + double a20 = a[offseta + (i + 0) + (row + 2) * lda]; + double b00 = b[offsetb + (col + 0) + (i + 0) * ldb]; + double b01 = b[offsetb + (col + 1) + (i + 0) * ldb]; + double b02 = b[offsetb + (col + 2) + (i + 0) * ldb]; + sum00 = a00 * b00 + sum00; + sum01 = a00 * b01 + sum01; + sum02 = a00 * b02 + sum02; + sum10 = a10 * b00 + sum10; + sum11 = a10 * b01 + sum11; + sum12 = a10 * b02 + sum12; + sum20 = a20 * b00 + sum20; + sum21 = a20 * b01 + sum21; + sum22 = a20 * b02 + sum22; + } + if (beta != 0.0) { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc]; + c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc]; + c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc]; + 
c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc]; + c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + beta * c[offsetc + (row + 1) + (col + 1) * ldc]; + c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + beta * c[offsetc + (row + 1) + (col + 2) * ldc]; + c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc]; + c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + beta * c[offsetc + (row + 2) + (col + 1) * ldc]; + c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + beta * c[offsetc + (row + 2) + (col + 2) * ldc]; + } else { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00; + c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01; + c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02; + c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10; + c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11; + c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12; + c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20; + c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21; + c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22; + } + } + for (; row < m; row += 1) { + int i = 0; + double sum00 = 0.0; + double sum01 = 0.0; + double sum02 = 0.0; + for (; i < loopBound(k, Ti); i += Ti) { + double a00 = a[offseta + (i + 0) + (row + 0) * lda]; + double b00 = b[offsetb + (col + 0) + (i + 0) * ldb]; + double b01 = b[offsetb + (col + 1) + (i + 0) * ldb]; + double b02 = b[offsetb + (col + 2) + (i + 0) * ldb]; + sum00 = a00 * b00 + sum00; + sum01 = a00 * b01 + sum01; + sum02 = a00 * b02 + sum02; + double a01 = a[offseta + (i + 1) + (row + 0) * lda]; + double b10 = b[offsetb + (col + 0) + (i + 1) * ldb]; + double b11 = b[offsetb + (col + 1) + (i + 1) * ldb]; + double b12 = b[offsetb + (col + 2) + (i + 1) * ldb]; + sum00 = a01 * b10 + sum00; + sum01 = a01 * b11 + sum01; + sum02 = a01 * b12 + sum02; + } + for (; i < 
k; i += 1) { + double a00 = a[offseta + (i + 0) + (row + 0) * lda]; + double b00 = b[offsetb + (col + 0) + (i + 0) * ldb]; + double b01 = b[offsetb + (col + 1) + (i + 0) * ldb]; + double b02 = b[offsetb + (col + 2) + (i + 0) * ldb]; + sum00 = a00 * b00 + sum00; + sum01 = a00 * b01 + sum01; + sum02 = a00 * b02 + sum02; + } + if (beta != 0.0) { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc]; + c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc]; + c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc]; + } else { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00; + c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01; + c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02; + } + } + } + for (; col < n; col += 1) { + int row = 0; + for (; row < loopBound(m, Trow); row += Trow) { + int i = 0; + double sum00 = 0.0; + double sum10 = 0.0; + double sum20 = 0.0; + for (; i < loopBound(k, Ti); i += Ti) { + double a00 = a[offseta + (i + 0) + (row + 0) * lda]; + double a10 = a[offseta + (i + 0) + (row + 1) * lda]; + double a20 = a[offseta + (i + 0) + (row + 2) * lda]; + double b00 = b[offsetb + (col + 0) + (i + 0) * ldb]; + sum00 = a00 * b00 + sum00; + sum10 = a10 * b00 + sum10; + sum20 = a20 * b00 + sum20; + double a01 = a[offseta + (i + 1) + (row + 0) * lda]; + double a11 = a[offseta + (i + 1) + (row + 1) * lda]; + double a21 = a[offseta + (i + 1) + (row + 2) * lda]; + double b10 = b[offsetb + (col + 0) + (i + 1) * ldb]; + sum00 = a01 * b10 + sum00; + sum10 = a11 * b10 + sum10; + sum20 = a21 * b10 + sum20; + } + for (; i < k; i += 1) { + double a00 = a[offseta + (i + 0) + (row + 0) * lda]; + double a10 = a[offseta + (i + 0) + (row + 1) * lda]; + double a20 = a[offseta + (i + 0) + (row + 2) * lda]; + double b00 = b[offsetb + (col + 0) + (i + 0) * ldb]; + sum00 = a00 * b00 + sum00; + sum10 = 
a10 * b00 + sum10; + sum20 = a20 * b00 + sum20; + } + if (beta != 0.0) { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc]; + c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc]; + c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc]; + } else { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00; + c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10; + c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20; + } + } + for (; row < m; row += 1) { + int i = 0; + double sum00 = 0.0; + for (; i < loopBound(k, Ti); i += Ti) { + double a00 = a[offseta + (i + 0) + (row + 0) * lda]; + double b00 = b[offsetb + (col + 0) + (i + 0) * ldb]; + sum00 = a00 * b00 + sum00; + double a01 = a[offseta + (i + 1) + (row + 0) * lda]; + double b10 = b[offsetb + (col + 0) + (i + 1) * ldb]; + sum00 = a01 * b10 + sum00; + } + for (; i < k; i += 1) { + double a00 = a[offseta + (i + 0) + (row + 0) * lda]; + double b00 = b[offsetb + (col + 0) + (i + 0) * ldb]; + sum00 = a00 * b00 + sum00; + } + if (beta != 0.0) { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc]; + } else { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00; + } + } + } + } + + protected void sgemmK(String transa, String transb, int m, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) { + if (alpha == 0.0f) { + sgemmBeta(0, m, 0, n, beta, c, offsetc, ldc); + } else if (m * n * k < 100 * 100 * 100) { + // The matrices are small and it's faster to do the non-copying version + if (lsame("N", transa) && lsame("N", transb)) { + sgemmNN(m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc); + } else if (lsame("N", transa)) { + sgemmNT(m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, 
beta, c, offsetc, ldc); + } else if (lsame("N", transb)) { + sgemmTN(m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc); + } else { + sgemmTT(m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc); + } + } else { + final int Krow = (int)(Math.ceil((double)(Math.min(60, m)) / 3) * 3), + Kcol = (int)(Math.ceil((double)(Math.min(1000, n)) / 3) * 3), + Ki = (int)(Math.ceil((double)(Math.min(500, k)) / 4) * 4); + + assert Krow > 0; + assert Kcol > 0; + assert Ki > 0; + + float[] packeda = new float[Krow * Ki]; + float[] packedb = new float[Kcol * Ki]; + float[] packedc = new float[Kcol * Krow]; + + // c = beta * c + sgemmBeta(0, m, 0, n, beta, c, offsetc, ldc); + // c += alpha * a * b + for (int col = 0; col < n; col += Kcol) { + int cols = col, cole = Math.min(col + Kcol, n); + for (int i = 0; i < k; i += Ki) { + int is = i, ie = Math.min(i + Ki, k); + // pack b + if (lsame("N", transb)) { + sgecpyNN(ie - is, cole - cols, b, offsetb, ldb, is, cols, packedb, 0, Ki, 0, 0); + } else { + sgecpyTN(ie - is, cole - cols, b, offsetb, ldb, is, cols, packedb, 0, Ki, 0, 0); + } + // GEPP + for (int row = 0; row < m; row += Krow) { + int rows = row, rowe = Math.min(row + Krow, m); + // pack A + if (lsame("N", transa)) { + sgecpyNT(rowe - rows, ie - is, a, offseta, lda, rows, is, packeda, 0, Ki, 0, 0); + } else { + sgecpyTT(rowe - rows, ie - is, a, offseta, lda, rows, is, packeda, 0, Ki, 0, 0); + } + // pack C + sgecpyNN(rowe - rows, cole - cols, c, offsetc, ldc, rows, cols, packedc, 0, Krow, 0, 0); + // GEBP + sgebpTN(Krow, 0, rowe - rows, Kcol, 0, cole - cols, Ki, 0, ie - is, + alpha, packeda, 0, Ki, packedb, 0, Ki, beta, packedc, 0, Krow); + // unpack C + sgecpyNN(rowe - rows, cole - cols, packedc, 0, Krow, 0, 0, c, offsetc, ldc, rows, cols); + } + } + } + } + } + + protected void sgemmBeta(int rows, int rowe, int cols, int cole, float beta, float[] c, int offsetc, int ldc) { + if (beta != 1.0f) { + int col = cols; + for (; col < 
loopAlign(cols, cole, 4); col += 1) { + int row = rows; + for (; row < rowe; row += 1) { + if (beta != 0.0f) { + c[offsetc + row + (col + 0) * ldc] = beta * c[offsetc + row + (col + 0) * ldc]; + } else { + c[offsetc + row + (col + 0) * ldc] = 0.0f; + } + } + } + for (; col < loopBound(cole, 4); col += 4) { + int row = rows; + for (; row < rowe; row += 1) { + if (beta != 0.0f) { + c[offsetc + row + (col + 0) * ldc] = beta * c[offsetc + row + (col + 0) * ldc]; + c[offsetc + row + (col + 1) * ldc] = beta * c[offsetc + row + (col + 1) * ldc]; + c[offsetc + row + (col + 2) * ldc] = beta * c[offsetc + row + (col + 2) * ldc]; + c[offsetc + row + (col + 3) * ldc] = beta * c[offsetc + row + (col + 3) * ldc]; + } else { + c[offsetc + row + (col + 0) * ldc] = 0.0f; + c[offsetc + row + (col + 1) * ldc] = 0.0f; + c[offsetc + row + (col + 2) * ldc] = 0.0f; + c[offsetc + row + (col + 3) * ldc] = 0.0f; + } + } + } + for (; col < cole; col += 1) { + int row = rows; + for (; row < rowe; row += 1) { + if (beta != 0.0f) { + c[offsetc + row + (col + 0) * ldc] = beta * c[offsetc + row + (col + 0) * ldc]; + } else { + c[offsetc + row + (col + 0) * ldc] = 0.0f; + } + } + } + } + } + + protected void sgecpyNN(int m, int n, float[] src, int offsetsrc, int ldsrc, int rowssrc, int colssrc, float[] dst, int offsetdst, int lddst, int rowsdst, int colsdst) { + int col = 0; + for (; col < loopBound(n, 4); col += 4) { + System.arraycopy(src, offsetsrc + rowssrc + (colssrc + col + 0) * ldsrc, dst, offsetdst + rowsdst + (colsdst + col + 0) * lddst, m); + System.arraycopy(src, offsetsrc + rowssrc + (colssrc + col + 1) * ldsrc, dst, offsetdst + rowsdst + (colsdst + col + 1) * lddst, m); + System.arraycopy(src, offsetsrc + rowssrc + (colssrc + col + 2) * ldsrc, dst, offsetdst + rowsdst + (colsdst + col + 2) * lddst, m); + System.arraycopy(src, offsetsrc + rowssrc + (colssrc + col + 3) * ldsrc, dst, offsetdst + rowsdst + (colsdst + col + 3) * lddst, m); + } + for (; col < n; col += 1) { + 
System.arraycopy(src, offsetsrc + rowssrc + (colssrc + col + 0) * ldsrc, dst, offsetdst + rowsdst + (colsdst + col + 0) * lddst, m); + } + } + + protected void sgecpyNT(int m, int n, float[] src, int offsetsrc, int ldsrc, int rowssrc, int colssrc, float[] dst, int offsetdst, int lddst, int rowsdst, int colsdst) { + int col = 0; + for (; col < loopBound(n, 3); col += 3) { + int row = 0; + for (; row < loopBound(m, 3); row += 3) { + dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 0) * ldsrc]; + dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 1) * lddst] = src[offsetsrc + (rowssrc + row + 1) + (colssrc + col + 0) * ldsrc]; + dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 2) * lddst] = src[offsetsrc + (rowssrc + row + 2) + (colssrc + col + 0) * ldsrc]; + dst[offsetdst + (colsdst + col + 1) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 1) * ldsrc]; + dst[offsetdst + (colsdst + col + 1) + (rowsdst + row + 1) * lddst] = src[offsetsrc + (rowssrc + row + 1) + (colssrc + col + 1) * ldsrc]; + dst[offsetdst + (colsdst + col + 1) + (rowsdst + row + 2) * lddst] = src[offsetsrc + (rowssrc + row + 2) + (colssrc + col + 1) * ldsrc]; + dst[offsetdst + (colsdst + col + 2) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 2) * ldsrc]; + dst[offsetdst + (colsdst + col + 2) + (rowsdst + row + 1) * lddst] = src[offsetsrc + (rowssrc + row + 1) + (colssrc + col + 2) * ldsrc]; + dst[offsetdst + (colsdst + col + 2) + (rowsdst + row + 2) * lddst] = src[offsetsrc + (rowssrc + row + 2) + (colssrc + col + 2) * ldsrc]; + } + for (; row < m; row += 1) { + dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 0) * ldsrc]; + dst[offsetdst + (colsdst + col + 1) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 1) * ldsrc]; + 
dst[offsetdst + (colsdst + col + 2) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 2) * ldsrc]; + } + } + for (; col < n; col += 1) { + int row = 0; + for (; row < loopBound(m, 3); row += 3) { + dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 0) * ldsrc]; + dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 1) * lddst] = src[offsetsrc + (rowssrc + row + 1) + (colssrc + col + 0) * ldsrc]; + dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 2) * lddst] = src[offsetsrc + (rowssrc + row + 2) + (colssrc + col + 0) * ldsrc]; + } + for (; row < m; row += 1) { + dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 0) * ldsrc]; + } + } + } + + protected void sgecpyTN(int m, int n, float[] src, int offsetsrc, int ldsrc, int rowssrc, int colssrc, float[] dst, int offsetdst, int lddst, int rowsdst, int colsdst) { + int row = 0; + for (; row < loopBound(m, 3); row += 3) { + int col = 0; + for (; col < loopBound(n, 3); col += 3) { + dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 0) * ldsrc]; + dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 1) * lddst] = src[offsetsrc + (colssrc + col + 1) + (rowssrc + row + 0) * ldsrc]; + dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 2) * lddst] = src[offsetsrc + (colssrc + col + 2) + (rowssrc + row + 0) * ldsrc]; + dst[offsetdst + (rowsdst + row + 1) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 1) * ldsrc]; + dst[offsetdst + (rowsdst + row + 1) + (colsdst + col + 1) * lddst] = src[offsetsrc + (colssrc + col + 1) + (rowssrc + row + 1) * ldsrc]; + dst[offsetdst + (rowsdst + row + 1) + (colsdst + col + 2) * lddst] = src[offsetsrc + (colssrc + col + 2) + (rowssrc + row + 1) * ldsrc]; + dst[offsetdst + (rowsdst + row + 2) 
+ (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 2) * ldsrc]; + dst[offsetdst + (rowsdst + row + 2) + (colsdst + col + 1) * lddst] = src[offsetsrc + (colssrc + col + 1) + (rowssrc + row + 2) * ldsrc]; + dst[offsetdst + (rowsdst + row + 2) + (colsdst + col + 2) * lddst] = src[offsetsrc + (colssrc + col + 2) + (rowssrc + row + 2) * ldsrc]; + } + for (; col < n; col += 1) { + dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 0) * ldsrc]; + dst[offsetdst + (rowsdst + row + 1) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 1) * ldsrc]; + dst[offsetdst + (rowsdst + row + 2) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 2) * ldsrc]; + } + } + for (; row < m; row += 1) { + int col = 0; + for (; col < loopBound(n, 3); col += 3) { + dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 0) * ldsrc]; + dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 1) * lddst] = src[offsetsrc + (colssrc + col + 1) + (rowssrc + row + 0) * ldsrc]; + dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 2) * lddst] = src[offsetsrc + (colssrc + col + 2) + (rowssrc + row + 0) * ldsrc]; + } + for (; col < n; col += 1) { + dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 0) * ldsrc]; + } + } + } + + protected void sgecpyTT(int m, int n, float[] src, int offsetsrc, int ldsrc, int rowssrc, int colssrc, float[] dst, int offsetdst, int lddst, int rowsdst, int colsdst) { + int row = 0; + for (; row < loopBound(m, 4); row += 4) { + System.arraycopy(src, offsetsrc + colssrc + (rowssrc + row + 0) * ldsrc, dst, offsetdst + colsdst + (rowsdst + row + 0) * lddst, n); + System.arraycopy(src, offsetsrc + colssrc + (rowssrc + row + 1) * ldsrc, dst, offsetdst + colsdst + 
(rowsdst + row + 1) * lddst, n); + System.arraycopy(src, offsetsrc + colssrc + (rowssrc + row + 2) * ldsrc, dst, offsetdst + colsdst + (rowsdst + row + 2) * lddst, n); + System.arraycopy(src, offsetsrc + colssrc + (rowssrc + row + 3) * ldsrc, dst, offsetdst + colsdst + (rowsdst + row + 3) * lddst, n); + } + for (; row < m; row += 1) { + System.arraycopy(src, offsetsrc + colssrc + (rowssrc + row + 0) * ldsrc, dst, offsetdst + colsdst + (rowsdst + row + 0) * lddst, n); + } + } + + protected void sgebpTN(int m, int rows, int rowe, int n, int cols, int cole, int k, int is, int ie, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) { + final int Tcol = 3, Trow = 3, Ti = 2; + + int col = cols; + for (; col < loopAlign(cols, cole, Tcol); col += 1) { + int row = rows; + for (; row < loopAlign(rows, rowe, Trow); row += 1) { + float sum00 = 0.0f; + for (int i = is; i < ie; i += 1) { + float a0 = a[offseta + i + (row + 0) * lda]; + float b0 = b[offsetb + i + (col + 0) * ldb]; + sum00 = a0 * b0 + sum00; + } + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc]; + } + for (; row < loopBound(rowe, Trow); row += Trow) { + float sum00 = 0.0f; + float sum10 = 0.0f; + float sum20 = 0.0f; + for (int i = is; i < ie; i += 1) { + float a0 = a[offseta + i + (row + 0) * lda]; + float a1 = a[offseta + i + (row + 1) * lda]; + float a2 = a[offseta + i + (row + 2) * lda]; + float b0 = b[offsetb + i + (col + 0) * ldb]; + sum00 = a0 * b0 + sum00; + sum10 = a1 * b0 + sum10; + sum20 = a2 * b0 + sum20; + } + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc]; + c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + c[offsetc + (row + 1) + (col + 0) * ldc]; + c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + c[offsetc + (row + 2) + (col + 0) * ldc]; + } + for (; row < rowe; row += 1) { + float sum00 = 0.0f; + for (int 
i = is; i < ie; i += 1) { + float a0 = a[offseta + i + (row + 0) * lda]; + float b0 = b[offsetb + i + (col + 0) * ldb]; + sum00 = a0 * b0 + sum00; + } + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc]; + } + } + for (; col < loopBound(cole, Tcol); col += Tcol) { + int row = rows; + for (; row < loopAlign(rows, rowe, Trow); row += 1) { + float sum00 = 0.0f; + float sum01 = 0.0f; + float sum02 = 0.0f; + float sum03 = 0.0f; + for (int i = is; i < ie; i += 1) { + float a0 = a[offseta + i + (row + 0) * lda]; + float b0 = b[offsetb + i + (col + 0) * ldb]; + float b1 = b[offsetb + i + (col + 1) * ldb]; + float b2 = b[offsetb + i + (col + 2) * ldb]; + sum00 = a0 * b0 + sum00; + sum01 = a0 * b1 + sum01; + sum02 = a0 * b2 + sum02; + } + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc]; + c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + c[offsetc + (row + 0) + (col + 1) * ldc]; + c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + c[offsetc + (row + 0) + (col + 2) * ldc]; + } + for (; row < loopBound(rowe, Trow); row += Trow) { + sgepdotTN(m, row, row + Trow, n, col, col + Tcol, k, is, ie, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc); + } + for (; row < rowe; row += 1) { + float sum00 = 0.0f; + float sum01 = 0.0f; + float sum02 = 0.0f; + for (int i = is; i < ie; i += 1) { + float a0 = a[offseta + i + (row + 0) * lda]; + float b0 = b[offsetb + i + (col + 0) * ldb]; + float b1 = b[offsetb + i + (col + 1) * ldb]; + float b2 = b[offsetb + i + (col + 2) * ldb]; + sum00 = a0 * b0 + sum00; + sum01 = a0 * b1 + sum01; + sum02 = a0 * b2 + sum02; + } + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc]; + c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + c[offsetc + (row + 0) + (col + 1) * ldc]; + c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + c[offsetc + (row + 0) + (col + 2) * 
ldc]; + } + } + for (; col < cole; col += 1) { + int row = rows; + for (; row < loopAlign(rows, rowe, Trow); row += 1) { + float sum00 = 0.0f; + for (int i = is; i < ie; i += 1) { + float a0 = a[offseta + i + (row + 0) * lda]; + float b0 = b[offsetb + i + (col + 0) * ldb]; + sum00 = a0 * b0 + sum00; + } + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc]; + } + for (; row < loopBound(rowe, Trow); row += Trow) { + float sum00 = 0.0f; + float sum10 = 0.0f; + float sum20 = 0.0f; + for (int i = is; i < ie; i += 1) { + float a0 = a[offseta + i + (row + 0) * lda]; + float a1 = a[offseta + i + (row + 1) * lda]; + float a2 = a[offseta + i + (row + 2) * lda]; + float b0 = b[offsetb + i + (col + 0) * ldb]; + sum00 = a0 * b0 + sum00; + sum10 = a1 * b0 + sum10; + sum20 = a2 * b0 + sum20; + } + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc]; + c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + c[offsetc + (row + 1) + (col + 0) * ldc]; + c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + c[offsetc + (row + 2) + (col + 0) * ldc]; + } + for (; row < rowe; row += 1) { + float sum00 = 0.0f; + for (int i = is; i < ie; i += 1) { + float a0 = a[offseta + i + (row + 0) * lda]; + float b0 = b[offsetb + i + (col + 0) * ldb]; + sum00 = a0 * b0 + sum00; + } + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc]; + } + } + } + + protected void sgepdotTN(int m, int rows, int rowe, int n, int cols, int cole, int k, int is, int ie, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) { + final int Ti = 2; + + assert rowe - rows == 3; + assert cole - cols == 3; + + int row = rows; + int col = cols; + int i = is; + float sum00 = 0.0f; + float sum01 = 0.0f; + float sum02 = 0.0f; + float sum10 = 0.0f; + float sum11 = 0.0f; + float sum12 = 0.0f; + float sum20 = 0.0f; 
+ float sum21 = 0.0f; + float sum22 = 0.0f; + for (; i < loopAlign(is, ie, Ti); i += 1) { + float a0 = a[offseta + i + (row + 0) * lda]; + float a1 = a[offseta + i + (row + 1) * lda]; + float a2 = a[offseta + i + (row + 2) * lda]; + float b0 = b[offsetb + i + (col + 0) * ldb]; + sum00 = a0 * b0 + sum00; + sum10 = a1 * b0 + sum10; + sum20 = a2 * b0 + sum20; + float b1 = b[offsetb + i + (col + 1) * ldb]; + sum01 = a0 * b1 + sum01; + sum11 = a1 * b1 + sum11; + sum21 = a2 * b1 + sum21; + float b2 = b[offsetb + i + (col + 2) * ldb]; + sum02 = a0 * b2 + sum02; + sum12 = a1 * b2 + sum12; + sum22 = a2 * b2 + sum22; + } + for (; i < loopBound(ie, Ti); i += Ti) { + float a00 = a[offseta + (i + 0) + (row + 0) * lda]; + float a01 = a[offseta + (i + 0) + (row + 1) * lda]; + float a02 = a[offseta + (i + 0) + (row + 2) * lda]; + float b00 = b[offsetb + (i + 0) + (col + 0) * ldb]; + sum00 = a00 * b00 + sum00; + sum10 = a01 * b00 + sum10; + sum20 = a02 * b00 + sum20; + float b01 = b[offsetb + (i + 0) + (col + 1) * ldb]; + sum01 = a00 * b01 + sum01; + sum11 = a01 * b01 + sum11; + sum21 = a02 * b01 + sum21; + float b02 = b[offsetb + (i + 0) + (col + 2) * ldb]; + sum02 = a00 * b02 + sum02; + sum12 = a01 * b02 + sum12; + sum22 = a02 * b02 + sum22; + float a10 = a[offseta + (i + 1) + (row + 0) * lda]; + float a11 = a[offseta + (i + 1) + (row + 1) * lda]; + float a12 = a[offseta + (i + 1) + (row + 2) * lda]; + float b10 = b[offsetb + (i + 1) + (col + 0) * ldb]; + sum00 = a10 * b10 + sum00; + sum10 = a11 * b10 + sum10; + sum20 = a12 * b10 + sum20; + float b11 = b[offsetb + (i + 1) + (col + 1) * ldb]; + sum01 = a10 * b11 + sum01; + sum11 = a11 * b11 + sum11; + sum21 = a12 * b11 + sum21; + float b12 = b[offsetb + (i + 1) + (col + 2) * ldb]; + sum02 = a10 * b12 + sum02; + sum12 = a11 * b12 + sum12; + sum22 = a12 * b12 + sum22; + } + for (; i < ie; i += 1) { + float a0 = a[offseta + i + (row + 0) * lda]; + float a1 = a[offseta + i + (row + 1) * lda]; + float a2 = a[offseta + i + (row + 2) * 
lda]; + float b0 = b[offsetb + i + (col + 0) * ldb]; + sum00 = a0 * b0 + sum00; + sum10 = a1 * b0 + sum10; + sum20 = a2 * b0 + sum20; + float b1 = b[offsetb + i + (col + 1) * ldb]; + sum01 = a0 * b1 + sum01; + sum11 = a1 * b1 + sum11; + sum21 = a2 * b1 + sum21; + float b2 = b[offsetb + i + (col + 2) * ldb]; + sum02 = a0 * b2 + sum02; + sum12 = a1 * b2 + sum12; + sum22 = a2 * b2 + sum22; + } + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc]; + c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + c[offsetc + (row + 0) + (col + 1) * ldc]; + c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + c[offsetc + (row + 0) + (col + 2) * ldc]; + c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + c[offsetc + (row + 1) + (col + 0) * ldc]; + c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + c[offsetc + (row + 1) + (col + 1) * ldc]; + c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + c[offsetc + (row + 1) + (col + 2) * ldc]; + c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + c[offsetc + (row + 2) + (col + 0) * ldc]; + c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + c[offsetc + (row + 2) + (col + 1) * ldc]; + c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + c[offsetc + (row + 2) + (col + 2) * ldc]; + } + + protected void sgemmNN(int m, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) { + final int Trow = 3, Tcol = 3, Ti = 2; + + int col = 0; + for (; col < loopBound(n, Tcol); col += Tcol) { + int row = 0; + for (; row < loopBound(m, Trow); row += Trow) { + int i = 0; + float sum00 = 0.0f; + float sum01 = 0.0f; + float sum02 = 0.0f; + float sum10 = 0.0f; + float sum11 = 0.0f; + float sum12 = 0.0f; + float sum20 = 0.0f; + float sum21 = 0.0f; + float sum22 = 0.0f; + for (; i < loopBound(k, Ti); i += Ti) { + float a00 = a[offseta + (row + 0) + (i + 0) * lda]; + float a10 = 
a[offseta + (row + 1) + (i + 0) * lda]; + float a20 = a[offseta + (row + 2) + (i + 0) * lda]; + float b00 = b[offsetb + (i + 0) + (col + 0) * ldb]; + float b01 = b[offsetb + (i + 0) + (col + 1) * ldb]; + float b02 = b[offsetb + (i + 0) + (col + 2) * ldb]; + sum00 = a00 * b00 + sum00; + sum01 = a00 * b01 + sum01; + sum02 = a00 * b02 + sum02; + sum10 = a10 * b00 + sum10; + sum11 = a10 * b01 + sum11; + sum12 = a10 * b02 + sum12; + sum20 = a20 * b00 + sum20; + sum21 = a20 * b01 + sum21; + sum22 = a20 * b02 + sum22; + float a01 = a[offseta + (row + 0) + (i + 1) * lda]; + float a11 = a[offseta + (row + 1) + (i + 1) * lda]; + float a21 = a[offseta + (row + 2) + (i + 1) * lda]; + float b10 = b[offsetb + (i + 1) + (col + 0) * ldb]; + float b11 = b[offsetb + (i + 1) + (col + 1) * ldb]; + float b12 = b[offsetb + (i + 1) + (col + 2) * ldb]; + sum00 = a01 * b10 + sum00; + sum01 = a01 * b11 + sum01; + sum02 = a01 * b12 + sum02; + sum10 = a11 * b10 + sum10; + sum11 = a11 * b11 + sum11; + sum12 = a11 * b12 + sum12; + sum20 = a21 * b10 + sum20; + sum21 = a21 * b11 + sum21; + sum22 = a21 * b12 + sum22; + } + for (; i < k; i += 1) { + float a00 = a[offseta + (row + 0) + (i + 0) * lda]; + float a10 = a[offseta + (row + 1) + (i + 0) * lda]; + float a20 = a[offseta + (row + 2) + (i + 0) * lda]; + float b00 = b[offsetb + (i + 0) + (col + 0) * ldb]; + float b01 = b[offsetb + (i + 0) + (col + 1) * ldb]; + float b02 = b[offsetb + (i + 0) + (col + 2) * ldb]; + sum00 = a00 * b00 + sum00; + sum01 = a00 * b01 + sum01; + sum02 = a00 * b02 + sum02; + sum10 = a10 * b00 + sum10; + sum11 = a10 * b01 + sum11; + sum12 = a10 * b02 + sum12; + sum20 = a20 * b00 + sum20; + sum21 = a20 * b01 + sum21; + sum22 = a20 * b02 + sum22; + } + if (beta != 0.0f) { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc]; + c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc]; + c[offsetc + (row + 0) + (col + 2) 
* ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc]; + c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc]; + c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + beta * c[offsetc + (row + 1) + (col + 1) * ldc]; + c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + beta * c[offsetc + (row + 1) + (col + 2) * ldc]; + c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc]; + c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + beta * c[offsetc + (row + 2) + (col + 1) * ldc]; + c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + beta * c[offsetc + (row + 2) + (col + 2) * ldc]; + } else { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00; + c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01; + c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02; + c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10; + c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11; + c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12; + c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20; + c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21; + c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22; + } + } + for (; row < m; row += 1) { + int i = 0; + float sum00 = 0.0f; + float sum01 = 0.0f; + float sum02 = 0.0f; + for (; i < loopBound(k, Ti); i += Ti) { + float a00 = a[offseta + (row + 0) + (i + 0) * lda]; + float b00 = b[offsetb + (i + 0) + (col + 0) * ldb]; + float b01 = b[offsetb + (i + 0) + (col + 1) * ldb]; + float b02 = b[offsetb + (i + 0) + (col + 2) * ldb]; + sum00 = a00 * b00 + sum00; + sum01 = a00 * b01 + sum01; + sum02 = a00 * b02 + sum02; + float a01 = a[offseta + (row + 0) + (i + 1) * lda]; + float b10 = b[offsetb + (i + 1) + (col + 0) * ldb]; + float b11 = b[offsetb + (i + 1) + (col + 1) * ldb]; + float b12 = b[offsetb + (i + 1) + (col + 2) * ldb]; + sum00 = a01 * b10 + sum00; + sum01 
= a01 * b11 + sum01; + sum02 = a01 * b12 + sum02; + } + for (; i < k; i += 1) { + float a00 = a[offseta + (row + 0) + (i + 0) * lda]; + float b00 = b[offsetb + (i + 0) + (col + 0) * ldb]; + float b01 = b[offsetb + (i + 0) + (col + 1) * ldb]; + float b02 = b[offsetb + (i + 0) + (col + 2) * ldb]; + sum00 = a00 * b00 + sum00; + sum01 = a00 * b01 + sum01; + sum02 = a00 * b02 + sum02; + } + if (beta != 0.0f) { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc]; + c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc]; + c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc]; + } else { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00; + c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01; + c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02; + } + } + } + for (; col < n; col += 1) { + int row = 0; + for (; row < loopBound(m, Trow); row += Trow) { + int i = 0; + float sum00 = 0.0f; + float sum10 = 0.0f; + float sum20 = 0.0f; + for (; i < loopBound(k, Ti); i += Ti) { + float a00 = a[offseta + (row + 0) + (i + 0) * lda]; + float a10 = a[offseta + (row + 1) + (i + 0) * lda]; + float a20 = a[offseta + (row + 2) + (i + 0) * lda]; + float b00 = b[offsetb + (i + 0) + (col + 0) * ldb]; + sum00 = a00 * b00 + sum00; + sum10 = a10 * b00 + sum10; + sum20 = a20 * b00 + sum20; + float a01 = a[offseta + (row + 0) + (i + 1) * lda]; + float a11 = a[offseta + (row + 1) + (i + 1) * lda]; + float a21 = a[offseta + (row + 2) + (i + 1) * lda]; + float b10 = b[offsetb + (i + 1) + (col + 0) * ldb]; + sum00 = a01 * b10 + sum00; + sum10 = a11 * b10 + sum10; + sum20 = a21 * b10 + sum20; + } + for (; i < k; i += 1) { + float a00 = a[offseta + (row + 0) + (i + 0) * lda]; + float a10 = a[offseta + (row + 1) + (i + 0) * lda]; + float a20 = a[offseta + (row + 2) + (i + 0) * lda]; + float b00 = b[offsetb + (i + 0) + (col + 
0) * ldb]; + sum00 = a00 * b00 + sum00; + sum10 = a10 * b00 + sum10; + sum20 = a20 * b00 + sum20; + } + if (beta != 0.0f) { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc]; + c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc]; + c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc]; + } else { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00; + c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10; + c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20; + } + } + for (; row < m; row += 1) { + int i = 0; + float sum00 = 0.0f; + for (; i < loopBound(k, Ti); i += Ti) { + float a00 = a[offseta + (row + 0) + (i + 0) * lda]; + float b00 = b[offsetb + (i + 0) + (col + 0) * ldb]; + sum00 = a00 * b00 + sum00; + float a01 = a[offseta + (row + 0) + (i + 1) * lda]; + float b10 = b[offsetb + (i + 1) + (col + 0) * ldb]; + sum00 = a01 * b10 + sum00; + } + for (; i < k; i += 1) { + float a00 = a[offseta + (row + 0) + (i + 0) * lda]; + float b00 = b[offsetb + (i + 0) + (col + 0) * ldb]; + sum00 = a00 * b00 + sum00; + } + if (beta != 0.0f) { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc]; + } else { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00; + } + } + } + }

  /**
   * Single-precision matrix multiply with A used as stored and B used transposed:
   * {@code C := alpha * A * B^T + beta * C}.
   *
   * <p>Addressing: A(r, i) is {@code a[offseta + r + i * lda]}, B(j, i) is
   * {@code b[offsetb + j + i * ldb]} and C(r, j) is {@code c[offsetc + r + j * ldc]}.
   * C is produced in 3x3 tiles with the k loop unrolled by two; narrower loops handle the
   * leftover rows, columns and k iterations. Every entry of C adds up its k products in
   * increasing k order.
   *
   * @param m number of rows of C (and of A)
   * @param n number of columns of C (rows of B as stored)
   * @param k length of the summed dimension
   * @param alpha scale applied to the product
   * @param a storage of A
   * @param offseta index of A(0, 0) inside {@code a}
   * @param lda distance in {@code a} between A(r, i) and A(r, i + 1)
   * @param b storage of B
   * @param offsetb index of B(0, 0) inside {@code b}
   * @param ldb distance in {@code b} between B(j, i) and B(j, i + 1)
   * @param beta scale applied to the previous C; when it is exactly zero C is written without being read
   * @param c storage of C, updated in place
   * @param offsetc index of C(0, 0) inside {@code c}
   * @param ldc distance in {@code c} between C(r, j) and C(r, j + 1)
   */
  protected void sgemmNT(int m, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
    final int Trow = 3, Tcol = 3, Ti = 2;
    // beta == 0 means "overwrite": the previous contents of C are then never read.
    final boolean overwrite = beta == 0.0f;

    int col = 0;
    // Full tiles: three columns of C at a time.
    for (; col < loopBound(n, Tcol); col += Tcol) {
      final int bBase = offsetb + col;     // B(col, 0)
      final int c0 = offsetc + col * ldc;  // C(0, col)
      final int c1 = c0 + ldc;             // C(0, col + 1)
      final int c2 = c1 + ldc;             // C(0, col + 2)
      int row = 0;
      for (; row < loopBound(m, Trow); row += Trow) {
        final int aBase = offseta + row;   // A(row, 0)
        float acc00 = 0.0f, acc01 = 0.0f, acc02 = 0.0f;
        float acc10 = 0.0f, acc11 = 0.0f, acc12 = 0.0f;
        float acc20 = 0.0f, acc21 = 0.0f, acc22 = 0.0f;
        int i = 0;
        for (; i < loopBound(k, Ti); i += Ti) {
          // k step i
          int ap = aBase + i * lda;
          int bp = bBase + i * ldb;
          float a0 = a[ap], a1 = a[ap + 1], a2 = a[ap + 2];
          float b0 = b[bp], b1 = b[bp + 1], b2 = b[bp + 2];
          acc00 += a0 * b0;
          acc01 += a0 * b1;
          acc02 += a0 * b2;
          acc10 += a1 * b0;
          acc11 += a1 * b1;
          acc12 += a1 * b2;
          acc20 += a2 * b0;
          acc21 += a2 * b1;
          acc22 += a2 * b2;
          // k step i + 1
          ap += lda;
          bp += ldb;
          a0 = a[ap];
          a1 = a[ap + 1];
          a2 = a[ap + 2];
          b0 = b[bp];
          b1 = b[bp + 1];
          b2 = b[bp + 2];
          acc00 += a0 * b0;
          acc01 += a0 * b1;
          acc02 += a0 * b2;
          acc10 += a1 * b0;
          acc11 += a1 * b1;
          acc12 += a1 * b2;
          acc20 += a2 * b0;
          acc21 += a2 * b1;
          acc22 += a2 * b2;
        }
        for (; i < k; i += 1) {
          final int ap = aBase + i * lda;
          final int bp = bBase + i * ldb;
          final float a0 = a[ap], a1 = a[ap + 1], a2 = a[ap + 2];
          final float b0 = b[bp], b1 = b[bp + 1], b2 = b[bp + 2];
          acc00 += a0 * b0;
          acc01 += a0 * b1;
          acc02 += a0 * b2;
          acc10 += a1 * b0;
          acc11 += a1 * b1;
          acc12 += a1 * b2;
          acc20 += a2 * b0;
          acc21 += a2 * b1;
          acc22 += a2 * b2;
        }
        if (overwrite) {
          c[c0 + row] = alpha * acc00;
          c[c1 + row] = alpha * acc01;
          c[c2 + row] = alpha * acc02;
          c[c0 + row + 1] = alpha * acc10;
          c[c1 + row + 1] = alpha * acc11;
          c[c2 + row + 1] = alpha * acc12;
          c[c0 + row + 2] = alpha * acc20;
          c[c1 + row + 2] = alpha * acc21;
          c[c2 + row + 2] = alpha * acc22;
        } else {
          c[c0 + row] = alpha * acc00 + beta * c[c0 + row];
          c[c1 + row] = alpha * acc01 + beta * c[c1 + row];
          c[c2 + row] = alpha * acc02 + beta * c[c2 + row];
          c[c0 + row + 1] = alpha * acc10 + beta * c[c0 + row + 1];
          c[c1 + row + 1] = alpha * acc11 + beta * c[c1 + row + 1];
          c[c2 + row + 1] = alpha * acc12 + beta * c[c2 + row + 1];
          c[c0 + row + 2] = alpha * acc20 + beta * c[c0 + row + 2];
          c[c1 + row + 2] = alpha * acc21 + beta * c[c1 + row + 2];
          c[c2 + row + 2] = alpha * acc22 + beta * c[c2 + row + 2];
        }
      }
      // Leftover rows of this three-column strip, one row at a time.
      for (; row < m; row += 1) {
        final int aBase = offseta + row;
        float acc0 = 0.0f, acc1 = 0.0f, acc2 = 0.0f;
        int i = 0;
        for (; i < loopBound(k, Ti); i += Ti) {
          int ap = aBase + i * lda;
          int bp = bBase + i * ldb;
          float av = a[ap];
          float b0 = b[bp], b1 = b[bp + 1], b2 = b[bp + 2];
          acc0 += av * b0;
          acc1 += av * b1;
          acc2 += av * b2;
          ap += lda;
          bp += ldb;
          av = a[ap];
          b0 = b[bp];
          b1 = b[bp + 1];
          b2 = b[bp + 2];
          acc0 += av * b0;
          acc1 += av * b1;
          acc2 += av * b2;
        }
        for (; i < k; i += 1) {
          final int bp = bBase + i * ldb;
          final float av = a[aBase + i * lda];
          final float b0 = b[bp], b1 = b[bp + 1], b2 = b[bp + 2];
          acc0 += av * b0;
          acc1 += av * b1;
          acc2 += av * b2;
        }
        if (overwrite) {
          c[c0 + row] = alpha * acc0;
          c[c1 + row] = alpha * acc1;
          c[c2 + row] = alpha * acc2;
        } else {
          c[c0 + row] = alpha * acc0 + beta * c[c0 + row];
          c[c1 + row] = alpha * acc1 + beta * c[c1 + row];
          c[c2 + row] = alpha * acc2 + beta * c[c2 + row];
        }
      }
    }
    // Leftover columns of C, one column at a time.
    for (; col < n; col += 1) {
      final int bBase = offsetb + col;
      final int c0 = offsetc + col * ldc;
      int row = 0;
      for (; row < loopBound(m, Trow); row += Trow) {
        final int aBase = offseta + row;
        float acc0 = 0.0f, acc1 = 0.0f, acc2 = 0.0f;
        int i = 0;
        for (; i < loopBound(k, Ti); i += Ti) {
          int ap = aBase + i * lda;
          int bp = bBase + i * ldb;
          float a0 = a[ap], a1 = a[ap + 1], a2 = a[ap + 2];
          float bv = b[bp];
          acc0 += a0 * bv;
          acc1 += a1 * bv;
          acc2 += a2 * bv;
          ap += lda;
          bp += ldb;
          a0 = a[ap];
          a1 = a[ap + 1];
          a2 = a[ap + 2];
          bv = b[bp];
          acc0 += a0 * bv;
          acc1 += a1 * bv;
          acc2 += a2 * bv;
        }
        for (; i < k; i += 1) {
          final int ap = aBase + i * lda;
          final float a0 = a[ap], a1 = a[ap + 1], a2 = a[ap + 2];
          final float bv = b[bBase + i * ldb];
          acc0 += a0 * bv;
          acc1 += a1 * bv;
          acc2 += a2 * bv;
        }
        if (overwrite) {
          c[c0 + row] = alpha * acc0;
          c[c0 + row + 1] = alpha * acc1;
          c[c0 + row + 2] = alpha * acc2;
        } else {
          c[c0 + row] = alpha * acc0 + beta * c[c0 + row];
          c[c0 + row + 1] = alpha * acc1 + beta * c[c0 + row + 1];
          c[c0 + row + 2] = alpha * acc2 + beta * c[c0 + row + 2];
        }
      }
      for (; row < m; row += 1) {
        final int aBase = offseta + row;
        float acc = 0.0f;
        int i = 0;
        for (; i < loopBound(k, Ti); i += Ti) {
          final int ap = aBase + i * lda;
          final int bp = bBase + i * ldb;
          acc += a[ap] * b[bp];
          acc += a[ap + lda] * b[bp + ldb];
        }
        for (; i < k; i += 1) {
          acc += a[aBase + i * lda] * b[bBase + i * ldb];
        }
        c[c0 + row] = overwrite ? alpha * acc : alpha * acc + beta * c[c0 + row];
      }
    }
  }

  /**
   * Single-precision matrix multiply with A used transposed and B used as stored:
   * {@code C := alpha * A^T * B + beta * C}.
   *
   * <p>Addressing: the A value pairing k index i with output row r is
   * {@code a[offseta + i + r * lda]}, the B value pairing k index i with output column j is
   * {@code b[offsetb + i + j * ldb]}, and C(r, j) is {@code c[offsetc + r + j * ldc]}.
   * Tiling, unrolling and accumulation order are the same as in {@code sgemmNT}: 3x3 tiles of C,
   * k unrolled by two, products summed in increasing k order.
   *
   * @param m number of rows of C
   * @param n number of columns of C
   * @param k length of the summed dimension
   * @param alpha scale applied to the product
   * @param a storage of A
   * @param offseta index of the first element of A inside {@code a}
   * @param lda distance in {@code a} between the A values of two consecutive output rows
   * @param b storage of B
   * @param offsetb index of the first element of B inside {@code b}
   * @param ldb distance in {@code b} between the B values of two consecutive output columns
   * @param beta scale applied to the previous C; when it is exactly zero C is written without being read
   * @param c storage of C, updated in place
   * @param offsetc index of C(0, 0) inside {@code c}
   * @param ldc distance in {@code c} between C(r, j) and C(r, j + 1)
   */
  protected void sgemmTN(int m, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
    final int Trow = 3, Tcol = 3, Ti = 2;
    // beta == 0 means "overwrite": the previous contents of C are then never read.
    final boolean overwrite = beta == 0.0f;

    int col = 0;
    // Full tiles: three columns of C at a time.
    for (; col < loopBound(n, Tcol); col += Tcol) {
      final int bc0 = offsetb + col * ldb;  // B values feeding column col
      final int bc1 = bc0 + ldb;            // ... column col + 1
      final int bc2 = bc1 + ldb;            // ... column col + 2
      final int c0 = offsetc + col * ldc;
      final int c1 = c0 + ldc;
      final int c2 = c1 + ldc;
      int row = 0;
      for (; row < loopBound(m, Trow); row += Trow) {
        final int ar0 = offseta + row * lda;  // A values feeding row `row`
        final int ar1 = ar0 + lda;            // ... row + 1
        final int ar2 = ar1 + lda;            // ... row + 2
        float acc00 = 0.0f, acc01 = 0.0f, acc02 = 0.0f;
        float acc10 = 0.0f, acc11 = 0.0f, acc12 = 0.0f;
        float acc20 = 0.0f, acc21 = 0.0f, acc22 = 0.0f;
        int i = 0;
        for (; i < loopBound(k, Ti); i += Ti) {
          // k step i
          float a0 = a[ar0 + i], a1 = a[ar1 + i], a2 = a[ar2 + i];
          float b0 = b[bc0 + i], b1 = b[bc1 + i], b2 = b[bc2 + i];
          acc00 += a0 * b0;
          acc01 += a0 * b1;
          acc02 += a0 * b2;
          acc10 += a1 * b0;
          acc11 += a1 * b1;
          acc12 += a1 * b2;
          acc20 += a2 * b0;
          acc21 += a2 * b1;
          acc22 += a2 * b2;
          // k step i + 1
          final int j = i + 1;
          a0 = a[ar0 + j];
          a1 = a[ar1 + j];
          a2 = a[ar2 + j];
          b0 = b[bc0 + j];
          b1 = b[bc1 + j];
          b2 = b[bc2 + j];
          acc00 += a0 * b0;
          acc01 += a0 * b1;
          acc02 += a0 * b2;
          acc10 += a1 * b0;
          acc11 += a1 * b1;
          acc12 += a1 * b2;
          acc20 += a2 * b0;
          acc21 += a2 * b1;
          acc22 += a2 * b2;
        }
        for (; i < k; i += 1) {
          final float a0 = a[ar0 + i], a1 = a[ar1 + i], a2 = a[ar2 + i];
          final float b0 = b[bc0 + i], b1 = b[bc1 + i], b2 = b[bc2 + i];
          acc00 += a0 * b0;
          acc01 += a0 * b1;
          acc02 += a0 * b2;
          acc10 += a1 * b0;
          acc11 += a1 * b1;
          acc12 += a1 * b2;
          acc20 += a2 * b0;
          acc21 += a2 * b1;
          acc22 += a2 * b2;
        }
        if (overwrite) {
          c[c0 + row] = alpha * acc00;
          c[c1 + row] = alpha * acc01;
          c[c2 + row] = alpha * acc02;
          c[c0 + row + 1] = alpha * acc10;
          c[c1 + row + 1] = alpha * acc11;
          c[c2 + row + 1] = alpha * acc12;
          c[c0 + row + 2] = alpha * acc20;
          c[c1 + row + 2] = alpha * acc21;
          c[c2 + row + 2] = alpha * acc22;
        } else {
          c[c0 + row] = alpha * acc00 + beta * c[c0 + row];
          c[c1 + row] = alpha * acc01 + beta * c[c1 + row];
          c[c2 + row] = alpha * acc02 + beta * c[c2 + row];
          c[c0 + row + 1] = alpha * acc10 + beta * c[c0 + row + 1];
          c[c1 + row + 1] = alpha * acc11 + beta * c[c1 + row + 1];
          c[c2 + row + 1] = alpha * acc12 + beta * c[c2 + row + 1];
          c[c0 + row + 2] = alpha * acc20 + beta * c[c0 + row + 2];
          c[c1 + row + 2] = alpha * acc21 + beta * c[c1 + row + 2];
          c[c2 + row + 2] = alpha * acc22 + beta * c[c2 + row + 2];
        }
      }
      // Leftover rows of this three-column strip, one row at a time.
      for (; row < m; row += 1) {
        final int ar = offseta + row * lda;
        float acc0 = 0.0f, acc1 = 0.0f, acc2 = 0.0f;
        int i = 0;
        for (; i < loopBound(k, Ti); i += Ti) {
          float av = a[ar + i];
          float b0 = b[bc0 + i], b1 = b[bc1 + i], b2 = b[bc2 + i];
          acc0 += av * b0;
          acc1 += av * b1;
          acc2 += av * b2;
          final int j = i + 1;
          av = a[ar + j];
          b0 = b[bc0 + j];
          b1 = b[bc1 + j];
          b2 = b[bc2 + j];
          acc0 += av * b0;
          acc1 += av * b1;
          acc2 += av * b2;
        }
        for (; i < k; i += 1) {
          final float av = a[ar + i];
          final float b0 = b[bc0 + i], b1 = b[bc1 + i], b2 = b[bc2 + i];
          acc0 += av * b0;
          acc1 += av * b1;
          acc2 += av * b2;
        }
        if (overwrite) {
          c[c0 + row] = alpha * acc0;
          c[c1 + row] = alpha * acc1;
          c[c2 + row] = alpha * acc2;
        } else {
          c[c0 + row] = alpha * acc0 + beta * c[c0 + row];
          c[c1 + row] = alpha * acc1 + beta * c[c1 + row];
          c[c2 + row] = alpha * acc2 + beta * c[c2 + row];
        }
      }
    }
    // Leftover columns of C, one column at a time.
    for (; col < n; col += 1) {
      final int bc = offsetb + col * ldb;
      final int c0 = offsetc + col * ldc;
      int row = 0;
      for (; row < loopBound(m, Trow); row += Trow) {
        final int ar0 = offseta + row * lda;
        final int ar1 = ar0 + lda;
        final int ar2 = ar1 + lda;
        float acc0 = 0.0f, acc1 = 0.0f, acc2 = 0.0f;
        int i = 0;
        for (; i < loopBound(k, Ti); i += Ti) {
          float a0 = a[ar0 + i], a1 = a[ar1 + i], a2 = a[ar2 + i];
          float bv = b[bc + i];
          acc0 += a0 * bv;
          acc1 += a1 * bv;
          acc2 += a2 * bv;
          final int j = i + 1;
          a0 = a[ar0 + j];
          a1 = a[ar1 + j];
          a2 = a[ar2 + j];
          bv = b[bc + j];
          acc0 += a0 * bv;
          acc1 += a1 * bv;
          acc2 += a2 * bv;
        }
        for (; i < k; i += 1) {
          final float a0 = a[ar0 + i], a1 = a[ar1 + i], a2 = a[ar2 + i];
          final float bv = b[bc + i];
          acc0 += a0 * bv;
          acc1 += a1 * bv;
          acc2 += a2 * bv;
        }
        if (overwrite) {
          c[c0 + row] = alpha * acc0;
          c[c0 + row + 1] = alpha * acc1;
          c[c0 + row + 2] = alpha * acc2;
        } else {
          c[c0 + row] = alpha * acc0 + beta * c[c0 + row];
          c[c0 + row + 1] = alpha * acc1 + beta * c[c0 + row + 1];
          c[c0 + row + 2] = alpha * acc2 + beta * c[c0 + row + 2];
        }
      }
      for (; row < m; row += 1) {
        final int ar = offseta + row * lda;
        float acc = 0.0f;
        int i = 0;
        for (; i < loopBound(k, Ti); i += Ti) {
          acc += a[ar + i] * b[bc + i];
          acc += a[ar + i + 1] * b[bc + i + 1];
        }
        for (; i < k; i += 1) {
          acc += a[ar + i] * b[bc + i];
        }
        c[c0 + row] = overwrite ? alpha * acc : alpha * acc + beta * c[c0 + row];
      }
    }
  }

  /**
   * Single-precision matrix multiply with both A and B used transposed:
   * {@code C := alpha * A^T * B^T + beta * C}.
   *
   * <p>Addressing: the A value pairing k index i with output row r is
   * {@code a[offseta + i + r * lda]}, the B value pairing k index i with output column j is
   * {@code b[offsetb + j + i * ldb]}, and C(r, j) is {@code c[offsetc + r + j * ldc]}.
   * Tiling, unrolling and accumulation order are the same as in {@code sgemmNT}: 3x3 tiles of C,
   * k unrolled by two, products summed in increasing k order.
   *
   * @param m number of rows of C
   * @param n number of columns of C
   * @param k length of the summed dimension
   * @param alpha scale applied to the product
   * @param a storage of A
   * @param offseta index of the first element of A inside {@code a}
   * @param lda distance in {@code a} between the A values of two consecutive output rows
   * @param b storage of B
   * @param offsetb index of the first element of B inside {@code b}
   * @param ldb distance in {@code b} between the B values of two consecutive k indices
   * @param beta scale applied to the previous C; when it is exactly zero C is written without being read
   * @param c storage of C, updated in place
   * @param offsetc index of C(0, 0) inside {@code c}
   * @param ldc distance in {@code c} between C(r, j) and C(r, j + 1)
   */
  protected void sgemmTT(int m, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
    final int Trow = 3, Tcol = 3, Ti = 2;
    // beta == 0 means "overwrite": the previous contents of C are then never read.
    final boolean overwrite = beta == 0.0f;

    int col = 0;
    // Full tiles: three columns of C at a time.
    for (; col < loopBound(n, Tcol); col += Tcol) {
      final int bBase = offsetb + col;     // B values for column col at k index 0
      final int c0 = offsetc + col * ldc;
      final int c1 = c0 + ldc;
      final int c2 = c1 + ldc;
      int row = 0;
      for (; row < loopBound(m, Trow); row += Trow) {
        final int ar0 = offseta + row * lda;  // A values feeding row `row`
        final int ar1 = ar0 + lda;            // ... row + 1
        final int ar2 = ar1 + lda;            // ... row + 2
        float acc00 = 0.0f, acc01 = 0.0f, acc02 = 0.0f;
        float acc10 = 0.0f, acc11 = 0.0f, acc12 = 0.0f;
        float acc20 = 0.0f, acc21 = 0.0f, acc22 = 0.0f;
        int i = 0;
        for (; i < loopBound(k, Ti); i += Ti) {
          // k step i
          int bp = bBase + i * ldb;
          float a0 = a[ar0 + i], a1 = a[ar1 + i], a2 = a[ar2 + i];
          float b0 = b[bp], b1 = b[bp + 1], b2 = b[bp + 2];
          acc00 += a0 * b0;
          acc01 += a0 * b1;
          acc02 += a0 * b2;
          acc10 += a1 * b0;
          acc11 += a1 * b1;
          acc12 += a1 * b2;
          acc20 += a2 * b0;
          acc21 += a2 * b1;
          acc22 += a2 * b2;
          // k step i + 1
          final int j = i + 1;
          bp += ldb;
          a0 = a[ar0 + j];
          a1 = a[ar1 + j];
          a2 = a[ar2 + j];
          b0 = b[bp];
          b1 = b[bp + 1];
          b2 = b[bp + 2];
          acc00 += a0 * b0;
          acc01 += a0 * b1;
          acc02 += a0 * b2;
          acc10 += a1 * b0;
          acc11 += a1 * b1;
          acc12 += a1 * b2;
          acc20 += a2 * b0;
          acc21 += a2 * b1;
          acc22 += a2 * b2;
        }
        for (; i < k; i += 1) {
          final int bp = bBase + i * ldb;
          final float a0 = a[ar0 + i], a1 = a[ar1 + i], a2 = a[ar2 + i];
          final float b0 = b[bp], b1 = b[bp + 1], b2 = b[bp + 2];
          acc00 += a0 * b0;
          acc01 += a0 * b1;
          acc02 += a0 * b2;
          acc10 += a1 * b0;
          acc11 += a1 * b1;
          acc12 += a1 * b2;
          acc20 += a2 * b0;
          acc21 += a2 * b1;
          acc22 += a2 * b2;
        }
        if (overwrite) {
          c[c0 + row] = alpha * acc00;
          c[c1 + row] = alpha * acc01;
          c[c2 + row] = alpha * acc02;
          c[c0 + row + 1] = alpha * acc10;
          c[c1 + row + 1] = alpha * acc11;
          c[c2 + row + 1] = alpha * acc12;
          c[c0 + row + 2] = alpha * acc20;
          c[c1 + row + 2] = alpha * acc21;
          c[c2 + row + 2] = alpha * acc22;
        } else {
          c[c0 + row] = alpha * acc00 + beta * c[c0 + row];
          c[c1 + row] = alpha * acc01 + beta * c[c1 + row];
          c[c2 + row] = alpha * acc02 + beta * c[c2 + row];
          c[c0 + row + 1] = alpha * acc10 + beta * c[c0 + row + 1];
          c[c1 + row + 1] = alpha * acc11 + beta * c[c1 + row + 1];
          c[c2 + row + 1] = alpha * acc12 + beta * c[c2 + row + 1];
          c[c0 + row + 2] = alpha * acc20 + beta * c[c0 + row + 2];
          c[c1 + row + 2] = alpha * acc21 + beta * c[c1 + row + 2];
          c[c2 + row + 2] = alpha * acc22 + beta * c[c2 + row + 2];
        }
      }
      // Leftover rows of this three-column strip, one row at a time.
      for (; row < m; row += 1) {
        final int ar = offseta + row * lda;
        float acc0 = 0.0f, acc1 = 0.0f, acc2 = 0.0f;
        int i = 0;
        for (; i < loopBound(k, Ti); i += Ti) {
          int bp = bBase + i * ldb;
          float av = a[ar + i];
          float b0 = b[bp], b1 = b[bp + 1], b2 = b[bp + 2];
          acc0 += av * b0;
          acc1 += av * b1;
          acc2 += av * b2;
          bp += ldb;
          av = a[ar + i + 1];
          b0 = b[bp];
          b1 = b[bp + 1];
          b2 = b[bp + 2];
          acc0 += av * b0;
          acc1 += av * b1;
          acc2 += av * b2;
        }
        for (; i < k; i += 1) {
          final int bp = bBase + i * ldb;
          final float av = a[ar + i];
          final float b0 = b[bp], b1 = b[bp + 1], b2 = b[bp + 2];
          acc0 += av * b0;
          acc1 += av * b1;
          acc2 += av * b2;
        }
        if (overwrite) {
          c[c0 + row] = alpha * acc0;
          c[c1 + row] = alpha * acc1;
          c[c2 + row] = alpha * acc2;
        } else {
          c[c0 + row] = alpha * acc0 + beta * c[c0 + row];
          c[c1 + row] = alpha * acc1 + beta * c[c1 + row];
          c[c2 + row] = alpha * acc2 + beta * c[c2 + row];
        }
      }
    }
    // Leftover columns of C, one column at a time.
    for (; col < n; col += 1) {
      final int bBase = offsetb + col;
      final int c0 = offsetc + col * ldc;
      int row = 0;
      for (; row < loopBound(m, Trow); row += Trow) {
        final int ar0 = offseta + row * lda;
        final int ar1 = ar0 + lda;
        final int ar2 = ar1 + lda;
        float acc0 = 0.0f, acc1 = 0.0f, acc2 = 0.0f;
        int i = 0;
        for (; i < loopBound(k, Ti); i += Ti) {
          int bp = bBase + i * ldb;
          float a0 = a[ar0 + i], a1 = a[ar1 + i], a2 = a[ar2 + i];
          float bv = b[bp];
          acc0 += a0 * bv;
          acc1 += a1 * bv;
          acc2 += a2 * bv;
          final int j = i + 1;
          bp += ldb;
          a0 = a[ar0 + j];
          a1 = a[ar1 + j];
          a2 = a[ar2 + j];
          bv = b[bp];
          acc0 += a0 * bv;
          acc1 += a1 * bv;
          acc2 += a2 * bv;
        }
        for (; i < k; i += 1) {
          final float a0 = a[ar0 + i], a1 = a[ar1 + i], a2 = a[ar2 + i];
          final float bv = b[bBase + i * ldb];
          acc0 += a0 * bv;
          acc1 += a1 * bv;
          acc2 += a2 * bv;
        }
        if (overwrite) {
          c[c0 + row] = alpha * acc0;
          c[c0 + row + 1] = alpha * acc1;
          c[c0 + row + 2] = alpha * acc2;
        } else {
          c[c0 + row] = alpha * acc0 + beta * c[c0 + row];
          c[c0 + row + 1] = alpha * acc1 + beta * c[c0 + row + 1];
          c[c0 + row + 2] = alpha * acc2 + beta * c[c0 + row + 2];
        }
      }
      for (; row < m; row += 1) {
        final int ar = offseta + row * lda;
        float acc = 0.0f;
        int i = 0;
        for (; i < loopBound(k, Ti); i += Ti) {
          final int bp = bBase + i * ldb;
          acc += a[ar + i] * b[bp];
          acc += a[ar + i + 1] * b[bp + ldb];
        }
        for (; i < k; i += 1) {
          acc += a[ar + i] * b[bBase + i * ldb];
        }
        c[c0 + row] = overwrite ? alpha * acc : alpha * acc + beta * c[c0 + row];
      }
    }
  }

  protected void dgemvK(String trans, int m, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) { + if (alpha == 0.0) { + int len = lsame("N", trans) ? m : n; + for (int i = 0, iy = incy < 0 ? 
(len - 1) * -incy : 0; i < len; i += 1, iy += incy) { + if (beta != 0.0) { + y[offsety + iy] = beta * y[offsety + iy]; + } else { + y[offsety + iy] = 0.0; + } + } + } else if (lsame("N", trans)) { + dgemvN(m, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy); + } else if (lsame("T", trans) || lsame("C", trans)) { + dgemvT(m, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy); + } + } + + protected void dgemvN(int m, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) { + if (beta != 1.0) { + int row = 0, iy = incy < 0 ? (m - 1) * -incy : 0; + for (; row < m; row += 1, iy += incy) { + if (beta != 0.0) { + y[offsety + iy] = beta * y[offsety + iy]; + } else { + y[offsety + iy] = 0.0; + } + } + } + int col = 0, ix = incx < 0 ? (n - 1) * -incx : 0; + for (; col < loopBound(n, 4); col += 4, ix += incx * 4) { + int row = 0, iy = incy < 0 ? (m - 1) * -incy : 0; + double alphax0 = alpha * x[offsetx + ix + incx * 0]; + double alphax1 = alpha * x[offsetx + ix + incx * 1]; + double alphax2 = alpha * x[offsetx + ix + incx * 2]; + double alphax3 = alpha * x[offsetx + ix + incx * 3]; + for (; row < m; row += 1, iy += incy) { + y[offsety + iy] += alphax0 * a[offseta + row + (col + 0) * lda] + + alphax1 * a[offseta + row + (col + 1) * lda] + + alphax2 * a[offseta + row + (col + 2) * lda] + + alphax3 * a[offseta + row + (col + 3) * lda]; + } + } + for (; col < n; col += 1, ix += incx) { + int row = 0, iy = incy < 0 ? (m - 1) * -incy : 0; + double alphax = alpha * x[offsetx + ix]; + for (; row < m; row += 1, iy += incy) { + y[offsety + iy] += alphax * a[offseta + row + col * lda]; + } + } + } + + protected void dgemvT(int m, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) { + int col = 0, iy = incy < 0 ? 
(n - 1) * -incy : 0; + for (; col < loopBound(n, 4); col += 4, iy += incy * 4) { + int row = 0, ix = incx < 0 ? (m - 1) * -incx : 0; + double sum0 = 0.0; + double sum1 = 0.0; + double sum2 = 0.0; + double sum3 = 0.0; + for (; row < m; row += 1, ix += incx) { + double xix = x[offsetx + ix]; + sum0 += xix * a[offseta + row + (col + 0) * lda]; + sum1 += xix * a[offseta + row + (col + 1) * lda]; + sum2 += xix * a[offseta + row + (col + 2) * lda]; + sum3 += xix * a[offseta + row + (col + 3) * lda]; + } + if (beta != 0.0) { + y[offsety + iy + incy * 0] = alpha * sum0 + beta * y[offsety + iy + incy * 0]; + y[offsety + iy + incy * 1] = alpha * sum1 + beta * y[offsety + iy + incy * 1]; + y[offsety + iy + incy * 2] = alpha * sum2 + beta * y[offsety + iy + incy * 2]; + y[offsety + iy + incy * 3] = alpha * sum3 + beta * y[offsety + iy + incy * 3]; + } else { + y[offsety + iy + incy * 0] = alpha * sum0; + y[offsety + iy + incy * 1] = alpha * sum1; + y[offsety + iy + incy * 2] = alpha * sum2; + y[offsety + iy + incy * 3] = alpha * sum3; + } + } + for (; col < n; col += 1, iy += incy) { + int row = 0, ix = incx < 0 ? (m - 1) * -incx : 0; + double sum = 0.0; + for (; row < m; row += 1, ix += incx) { + sum += x[offsetx + ix] * a[offseta + row + col * lda]; + } + if (beta != 0.0) { + y[offsety + iy] = alpha * sum + beta * y[offsety + iy]; + } else { + y[offsety + iy] = alpha * sum; + } + } + } + + protected void sgemvK(String trans, int m, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) { + if (alpha == 0.0f) { + int len = lsame("N", trans) ? m : n; + for (int i = 0, iy = incy < 0 ? 
(len - 1) * -incy : 0; i < len; i += 1, iy += incy) { + if (beta != 0.0f) { + y[offsety + iy] = beta * y[offsety + iy]; + } else { + y[offsety + iy] = 0.0f; + } + } + } else if (lsame("N", trans)) { + sgemvN(m, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy); + } else if (lsame("T", trans) || lsame("C", trans)) { + sgemvT(m, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy); + } + } + + protected void sgemvN(int m, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) { + // y = beta * y + for (int row = 0, iy = incy < 0 ? (m - 1) * -incy : 0; row < m; row += 1, iy += incy) { + if (beta != 0.0f) { + y[offsety + iy] = beta * y[offsety + iy]; + } else { + y[offsety + iy] = 0.0f; + } + } + // y += alpha * A * x + int col = 0, ix = incx < 0 ? (n - 1) * -incx : 0; + for (; col < loopBound(n, 8); col += 8, ix += incx * 8) { + float alphax0 = alpha * x[offsetx + ix + incx * 0]; + float alphax1 = alpha * x[offsetx + ix + incx * 1]; + float alphax2 = alpha * x[offsetx + ix + incx * 2]; + float alphax3 = alpha * x[offsetx + ix + incx * 3]; + float alphax4 = alpha * x[offsetx + ix + incx * 4]; + float alphax5 = alpha * x[offsetx + ix + incx * 5]; + float alphax6 = alpha * x[offsetx + ix + incx * 6]; + float alphax7 = alpha * x[offsetx + ix + incx * 7]; + for (int row = 0, iy = incy < 0 ? 
(m - 1) * -incy : 0; row < m; row += 1, iy += incy) { + y[offsety + iy] += alphax0 * a[offseta + row + (col + 0) * lda] + + alphax1 * a[offseta + row + (col + 1) * lda] + + alphax2 * a[offseta + row + (col + 2) * lda] + + alphax3 * a[offseta + row + (col + 3) * lda] + + alphax4 * a[offseta + row + (col + 4) * lda] + + alphax5 * a[offseta + row + (col + 5) * lda] + + alphax6 * a[offseta + row + (col + 6) * lda] + + alphax7 * a[offseta + row + (col + 7) * lda]; + } + } + for (; col < n; col += 1, ix += incx) { + float alphax = alpha * x[offsetx + ix]; + for (int row = 0, iy = incy < 0 ? (m - 1) * -incy : 0; row < m; row += 1, iy += incy) { + y[offsety + iy] += alphax * a[offseta + row + col * lda]; + } + } + } + + protected void sgemvT(int m, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) { + int col = 0, iy = incy < 0 ? (n - 1) * -incy : 0; + for (; col < loopBound(n, 8); col += 8, iy += incy * 8) { + float sum0 = 0.0f; + float sum1 = 0.0f; + float sum2 = 0.0f; + float sum3 = 0.0f; + float sum4 = 0.0f; + float sum5 = 0.0f; + float sum6 = 0.0f; + float sum7 = 0.0f; + for (int row = 0, ix = incx < 0 ? 
(m - 1) * -incx : 0; row < m; row += 1, ix += incx) { + sum0 += x[offsetx + ix] * a[offseta + row + (col + 0) * lda]; + sum1 += x[offsetx + ix] * a[offseta + row + (col + 1) * lda]; + sum2 += x[offsetx + ix] * a[offseta + row + (col + 2) * lda]; + sum3 += x[offsetx + ix] * a[offseta + row + (col + 3) * lda]; + sum4 += x[offsetx + ix] * a[offseta + row + (col + 4) * lda]; + sum5 += x[offsetx + ix] * a[offseta + row + (col + 5) * lda]; + sum6 += x[offsetx + ix] * a[offseta + row + (col + 6) * lda]; + sum7 += x[offsetx + ix] * a[offseta + row + (col + 7) * lda]; + } + if (beta != 0.0f) { + y[offsety + iy + incy * 0] = alpha * sum0 + beta * y[offsety + iy + incy * 0]; + y[offsety + iy + incy * 1] = alpha * sum1 + beta * y[offsety + iy + incy * 1]; + y[offsety + iy + incy * 2] = alpha * sum2 + beta * y[offsety + iy + incy * 2]; + y[offsety + iy + incy * 3] = alpha * sum3 + beta * y[offsety + iy + incy * 3]; + y[offsety + iy + incy * 4] = alpha * sum4 + beta * y[offsety + iy + incy * 4]; + y[offsety + iy + incy * 5] = alpha * sum5 + beta * y[offsety + iy + incy * 5]; + y[offsety + iy + incy * 6] = alpha * sum6 + beta * y[offsety + iy + incy * 6]; + y[offsety + iy + incy * 7] = alpha * sum7 + beta * y[offsety + iy + incy * 7]; + } else { + y[offsety + iy + incy * 0] = alpha * sum0; + y[offsety + iy + incy * 1] = alpha * sum1; + y[offsety + iy + incy * 2] = alpha * sum2; + y[offsety + iy + incy * 3] = alpha * sum3; + y[offsety + iy + incy * 4] = alpha * sum4; + y[offsety + iy + incy * 5] = alpha * sum5; + y[offsety + iy + incy * 6] = alpha * sum6; + y[offsety + iy + incy * 7] = alpha * sum7; + } + } + for (; col < n; col += 1, iy += incy) { + float sum = 0.0f; + for (int row = 0, ix = incx < 0 ? 
(m - 1) * -incx : 0; row < m; row += 1, ix += incx) { + sum += x[offsetx + ix] * a[offseta + row + col * lda]; + } + if (beta != 0.0f) { + y[offsety + iy] = alpha * sum + beta * y[offsety + iy]; + } else { + y[offsety + iy] = alpha * sum; + } + } + } + + /** Double-precision rank-1 update kernel: adds alpha * x[row] * y[col] to a[offseta + row + col * lda] for every row below m and column below n, with columns unrolled four at a time up to loopBound(n, 4). x is walked with stride incx over m entries and y with stride incy over n entries; a negative stride starts at the vector's last slot, so the x cursor starts at (m - 1) * -incx (it previously used n, which addressed the wrong entries whenever m differed from n). */ protected void dgerK(int m, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta, int lda) { + int col = 0, iy = incy < 0 ? (n - 1) * -incy : 0; + for (; col < loopBound(n, 4); col += 4, iy += incy * 4) { + double alphayiy0 = alpha * y[offsety + iy + incy * 0]; + double alphayiy1 = alpha * y[offsety + iy + incy * 1]; + double alphayiy2 = alpha * y[offsety + iy + incy * 2]; + double alphayiy3 = alpha * y[offsety + iy + incy * 3]; + int row = 0, jx = incx < 0 ? (m - 1) * -incx : 0; /* x has m entries */ + for (; row < m; row += 1, jx += incx) { + double xjx = x[offsetx + jx]; + a[offseta + row + (col + 0) * lda] += alphayiy0 * xjx; + a[offseta + row + (col + 1) * lda] += alphayiy1 * xjx; + a[offseta + row + (col + 2) * lda] += alphayiy2 * xjx; + a[offseta + row + (col + 3) * lda] += alphayiy3 * xjx; + } + } + for (; col < n; col += 1, iy += incy) { + double alphayiy = alpha * y[offsety + iy]; + int row = 0, jx = incx < 0 ? (m - 1) * -incx : 0; /* x has m entries */ + for (; row < m; row += 1, jx += incx) { + a[offseta + row + col * lda] += alphayiy * x[offsetx + jx]; + } + } + } + + /** Single-precision rank-1 update kernel: adds alpha * x[row] * y[col] to a[offseta + row + col * lda] for every row below m and column below n, with columns unrolled four at a time up to loopBound(n, 4). A negative incx starts the x cursor at (m - 1) * -incx because x holds m entries (it previously used n). */ protected void sgerK(int m, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta, int lda) { + int col = 0, iy = incy < 0 ? (n - 1) * -incy : 0; + for (; col < loopBound(n, 4); col += 4, iy += incy * 4) { + float alphayiy0 = alpha * y[offsety + iy + incy * 0]; + float alphayiy1 = alpha * y[offsety + iy + incy * 1]; + float alphayiy2 = alpha * y[offsety + iy + incy * 2]; + float alphayiy3 = alpha * y[offsety + iy + incy * 3]; + int row = 0, jx = incx < 0 ?
(m - 1) * -incx : 0; /* x has m entries */ + for (; row < m; row += 1, jx += incx) { + float xjx = x[offsetx + jx]; + a[offseta + row + (col + 0) * lda] += alphayiy0 * xjx; + a[offseta + row + (col + 1) * lda] += alphayiy1 * xjx; + a[offseta + row + (col + 2) * lda] += alphayiy2 * xjx; + a[offseta + row + (col + 3) * lda] += alphayiy3 * xjx; + } + } + for (; col < n; col += 1, iy += incy) { + float alphayiy = alpha * y[offsety + iy]; + int row = 0, jx = incx < 0 ? (m - 1) * -incx : 0; /* x has m entries */ + for (; row < m; row += 1, jx += incx) { + a[offseta + row + col * lda] += alphayiy * x[offsetx + jx]; + } + } + } + + /** Returns the square root of the plain (unscaled) sum of squares of the entries x[offsetx + ix], with ix advancing by incx while ix stays below n * incx; the first loopBound(n, 4) entries are accumulated in four partial sums. For incx of zero or less no loop body runs and 0.0 is returned. */ protected double dnrm2K(int n, double[] x, int offsetx, int incx) { + int ix = 0; + double sum0 = 0.0; + double sum1 = 0.0; + double sum2 = 0.0; + double sum3 = 0.0; + if (incx == 1) { + for (; ix < loopBound(n, 4); ix += 4) { + double x0 = x[offsetx + ix + 0]; + double x1 = x[offsetx + ix + 1]; + double x2 = x[offsetx + ix + 2]; + double x3 = x[offsetx + ix + 3]; + sum0 += x0 * x0; + sum1 += x1 * x1; + sum2 += x2 * x2; + sum3 += x3 * x3; + } + } else { + for (; ix < loopBound(n, 4) * incx; ix += 4 * incx) { + double x0 = x[offsetx + ix + (0 * incx)]; + double x1 = x[offsetx + ix + (1 * incx)]; + double x2 = x[offsetx + ix + (2 * incx)]; + double x3 = x[offsetx + ix + (3 * incx)]; + sum0 += x0 * x0; + sum1 += x1 * x1; + sum2 += x2 * x2; + sum3 += x3 * x3; + } + } + double sum = sum0 + sum1 + sum2 + sum3; + for (; ix < n * incx; ix += incx) { + double x0 = x[offsetx + ix + 0]; + sum += x0 * x0; + } + return Math.sqrt(sum); + } + + /** Single-precision counterpart of dnrm2K: accumulates the unscaled sum of squares in float partial sums and returns its square root cast to float; for incx of zero or less no loop body runs and 0.0f is returned. */ protected float snrm2K(int n, float[] x, int offsetx, int incx) { + int ix = 0; + float sum0 = 0.0f; + float sum1 = 0.0f; + float sum2 = 0.0f; + float sum3 = 0.0f; + if (incx == 1) { + for (; ix < loopBound(n, 4); ix += 4) { + float x0 = x[offsetx + ix + 0]; + float x1 = x[offsetx + ix + 1]; + float x2 = x[offsetx + ix + 2]; + float x3 = x[offsetx + ix + 3]; + sum0 += x0 * x0; + sum1 += x1 * x1; + sum2 += x2 * x2; + sum3 += x3 * x3; + } + } else { + for (; ix <
loopBound(n, 4) * incx; ix += 4 * incx) { + float x0 = x[offsetx + ix + (0 * incx)]; + float x1 = x[offsetx + ix + (1 * incx)]; + float x2 = x[offsetx + ix + (2 * incx)]; + float x3 = x[offsetx + ix + (3 * incx)]; + sum0 += x0 * x0; + sum1 += x1 * x1; + sum2 += x2 * x2; + sum3 += x3 * x3; + } + } + float sum = sum0 + sum1 + sum2 + sum3; + for (; ix < n * incx; ix += incx) { + float x0 = x[offsetx + ix + 0]; + sum += x0 * x0; + } + return (float)Math.sqrt(sum); + } + + /** Applies a plane rotation in place to n strided pairs: each x entry becomes c * x + s * y and each y entry becomes c * y - s * x, both computed from the pair's previous values. Unit strides take a dedicated loop; otherwise a negative stride starts at (n - 1) * -inc and walks down to 0. */ protected void drotK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double c, double s) { + if (incx == 1 && incy == 1) { + for (int ix = 0, iy = 0; ix < n && iy < n; ix += 1, iy += 1) { + double x0 = x[offsetx + ix]; + double y0 = y[offsety + iy]; + x[offsetx + ix] = c * x0 + s * y0; + y[offsety + iy] = c * y0 - s * x0; + } + } else { + for (int ix = incx < 0 ? (n - 1) * -incx : 0, + iy = incy < 0 ? (n - 1) * -incy : 0; + (incx < 0 ? ix >= 0 : ix < n * incx) + && (incy < 0 ? iy >= 0 : iy < n * incy); + ix += incx, iy += incy) { + double x0 = x[offsetx + ix]; + double y0 = y[offsety + iy]; + x[offsetx + ix] = c * x0 + s * y0; + y[offsety + iy] = c * y0 - s * x0; + } + } + } + + /** Single-precision counterpart of drotK: rotates n strided pairs in place, x becoming c * x + s * y and y becoming c * y - s * x. */ protected void srotK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float c, float s) { + if (incx == 1 && incy == 1) { + for (int ix = 0, iy = 0; ix < n && iy < n; ix += 1, iy += 1) { + float x0 = x[offsetx + ix]; + float y0 = y[offsety + iy]; + x[offsetx + ix] = c * x0 + s * y0; + y[offsety + iy] = c * y0 - s * x0; + } + } else { + for (int ix = incx < 0 ? (n - 1) * -incx : 0, + iy = incy < 0 ? (n - 1) * -incy : 0; + (incx < 0 ? ix >= 0 : ix < n * incx) + && (incy < 0 ?
iy >= 0 : iy < n * incy); + ix += incx, iy += incy) { + float x0 = x[offsetx + ix]; + float y0 = y[offsety + iy]; + x[offsetx + ix] = c * x0 + s * y0; + y[offsety + iy] = c * y0 - s * x0; + } + } + } + + /** Delegates to org.netlib.blas.Drotm.drotm with the arguments unchanged. */ protected void drotmK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] param, int offsetparam) { + org.netlib.blas.Drotm.drotm(n, x, offsetx, incx, y, offsety, incy, param, offsetparam); + } + + /** Delegates to org.netlib.blas.Srotm.srotm with the arguments unchanged. */ protected void srotmK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] param, int offsetparam) { + org.netlib.blas.Srotm.srotm(n, x, offsetx, incx, y, offsety, incy, param, offsetparam); + } + + /** Delegates to org.netlib.blas.Drotmg.drotmg with the arguments unchanged. */ protected void drotmgK(org.netlib.util.doubleW dd1, org.netlib.util.doubleW dd2, org.netlib.util.doubleW dx1, double dy1, double[] param, int offsetparam) { + org.netlib.blas.Drotmg.drotmg(dd1, dd2, dx1, dy1, param, offsetparam); + } + + /** Delegates to org.netlib.blas.Srotmg.srotmg with the arguments unchanged. */ protected void srotmgK(org.netlib.util.floatW sd1, org.netlib.util.floatW sd2, org.netlib.util.floatW sx1, float sy1, float[] param, int offsetparam) { + org.netlib.blas.Srotmg.srotmg(sd1, sd2, sx1, sy1, param, offsetparam); + } + + /** Delegates to org.netlib.blas.Dsbmv.dsbmv with the arguments unchanged. */ protected void dsbmvK(String uplo, int n, int k, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) { + org.netlib.blas.Dsbmv.dsbmv(uplo, n, k, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy); + } + + /** Delegates to org.netlib.blas.Ssbmv.ssbmv with the arguments unchanged. */ protected void ssbmvK(String uplo, int n, int k, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) { + org.netlib.blas.Ssbmv.ssbmv(uplo, n, k, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy); + } + + /** Multiplies n strided entries of x by alpha in place; a negative incx starts at (n - 1) * -incx and walks down to 0, and an incx of zero touches nothing. */ protected void dscalK(int n, double alpha, double[] x, int offsetx, int incx) { + if (incx == 1) { + for (int ix = 0; ix < n; ix += 1) { + x[offsetx + ix] *= alpha; + } + } else { + for (int ix = incx < 0 ? (n - 1) * -incx : 0; incx < 0 ?
ix >= 0 : ix < n * incx; ix += incx) { + x[offsetx + ix] *= alpha; + } + } + } + + /** Multiplies n strided entries of x by alpha in place (single precision); a negative incx starts at (n - 1) * -incx and walks down to 0, and an incx of zero touches nothing. */ protected void sscalK(int n, float alpha, float[] x, int offsetx, int incx) { + if (incx == 1) { + for (int ix = 0; ix < n; ix += 1) { + x[offsetx + ix] *= alpha; + } + } else { + for (int ix = incx < 0 ? (n - 1) * -incx : 0; incx < 0 ? ix >= 0 : ix < n * incx; ix += incx) { + x[offsetx + ix] *= alpha; + } + } + } + + /** Packed symmetric matrix-vector dispatcher: when alpha is zero it only rescales the n strided entries of y by beta (storing 0.0 when beta is zero); otherwise it forwards to dspmvU when lsame("U", uplo) holds or to dspmvL when lsame("L", uplo) holds, and does nothing for any other uplo. */ protected void dspmvK(String uplo, int n, double alpha, double[] a, int offseta, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) { + if (alpha == 0.0) { + for (int i = 0, iy = incy < 0 ? (n - 1) * -incy : 0; i < n; i += 1, iy += incy) { + if (beta != 0.0) { + y[offsety + iy] = beta * y[offsety + iy]; + } else { + y[offsety + iy] = 0.0; + } + } + } else if (lsame("U", uplo)) { + dspmvU(n, alpha, a, offseta, x, offsetx, incx, beta, y, offsety, incy); + } else if (lsame("L", uplo)) { + dspmvL(n, alpha, a, offseta, x, offsetx, incx, beta, y, offsety, incy); + } + } + + /** Upper-packed kernel for y = alpha * A * x + beta * y: entry (row, col) with row not above col is read from a[offseta + row + col * (col + 1) / 2]. Columns are taken four at a time up to loopBound(n, 4), then singly; for each column the rows before it add alpha * x[col] * a into the already-finished y[row] and feed the column's own dot product, which is stored as alpha * sum plus beta * y (beta * y is dropped when beta is zero). The row cursors jx and jy start at element 0, which for a negative stride is slot (n - 1) * -inc, the same slot ix and iy use; the earlier (col - 1) * -inc start addressed the wrong entries and a negative index when col was 0. */ protected void dspmvU(int n, double alpha, double[] a, int offseta, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) { + int col = 0, ix = incx < 0 ? (n - 1) * -incx : 0, iy = incy < 0 ? (n - 1) * -incy : 0; + for (; col < loopBound(n, 4); col += 4, ix += incx * 4, iy += incy * 4) { + double alphaxix0 = alpha * x[offsetx + ix + incx * 0]; + double alphaxix1 = alpha * x[offsetx + ix + incx * 1]; + double alphaxix2 = alpha * x[offsetx + ix + incx * 2]; + double alphaxix3 = alpha * x[offsetx + ix + incx * 3]; + double sumiy0 = 0.0; + double sumiy1 = 0.0; + double sumiy2 = 0.0; + double sumiy3 = 0.0; + int row = 0, jx = incx < 0 ? (n - 1) * -incx : 0, jy = incy < 0 ?
(n - 1) * -incy : 0; /* element 0 sits in the last slot when the stride is negative */ + for (; row < col; row += 1, jx += incx, jy += incy) { + double a0 = a[offseta + row + (col + 0) * ((col + 0) + 1) / 2]; + double a1 = a[offseta + row + (col + 1) * ((col + 1) + 1) / 2]; + double a2 = a[offseta + row + (col + 2) * ((col + 2) + 1) / 2]; + double a3 = a[offseta + row + (col + 3) * ((col + 3) + 1) / 2]; + y[offsety + jy] += alphaxix0 * a0 + + alphaxix1 * a1 + + alphaxix2 * a2 + + alphaxix3 * a3; + double xjx = x[offsetx + jx]; + sumiy0 += xjx * a0; + sumiy1 += xjx * a1; + sumiy2 += xjx * a2; + sumiy3 += xjx * a3; + } + double a00 = a[offseta + (row + 0) + (col + 0) * ((col + 0) + 1) / 2]; + double a01 = a[offseta + (row + 0) + (col + 1) * ((col + 1) + 1) / 2]; + double a02 = a[offseta + (row + 0) + (col + 2) * ((col + 2) + 1) / 2]; + double a03 = a[offseta + (row + 0) + (col + 3) * ((col + 3) + 1) / 2]; + double a11 = a[offseta + (row + 1) + (col + 1) * ((col + 1) + 1) / 2]; + double a12 = a[offseta + (row + 1) + (col + 2) * ((col + 2) + 1) / 2]; + double a13 = a[offseta + (row + 1) + (col + 3) * ((col + 3) + 1) / 2]; + double a22 = a[offseta + (row + 2) + (col + 2) * ((col + 2) + 1) / 2]; + double a23 = a[offseta + (row + 2) + (col + 3) * ((col + 3) + 1) / 2]; + double a33 = a[offseta + (row + 3) + (col + 3) * ((col + 3) + 1) / 2]; + double xjx0 = x[offsetx + jx + incx * 0]; + double xjx1 = x[offsetx + jx + incx * 1]; + double xjx2 = x[offsetx + jx + incx * 2]; + double xjx3 = x[offsetx + jx + incx * 3]; + sumiy0 += xjx0 * a00 + + xjx1 * a01 + + xjx2 * a02 + + xjx3 * a03; + sumiy1 += xjx0 * a01 + + xjx1 * a11 + + xjx2 * a12 + + xjx3 * a13; + sumiy2 += xjx0 * a02 + + xjx1 * a12 + + xjx2 * a22 + + xjx3 * a23; + sumiy3 += xjx0 * a03 + + xjx1 * a13 + + xjx2 * a23 + + xjx3 * a33; + if (beta != 0.0) { + y[offsety + iy + incy * 0] = alpha * sumiy0 + beta * y[offsety + iy + incy * 0]; + y[offsety + iy + incy * 1] = alpha * sumiy1 + beta * y[offsety + iy + incy * 1]; + y[offsety + iy + incy * 2] = alpha * sumiy2 + beta * y[offsety + iy
+ incy * 2]; + y[offsety + iy + incy * 3] = alpha * sumiy3 + beta * y[offsety + iy + incy * 3]; + } else { + y[offsety + iy + incy * 0] = alpha * sumiy0; + y[offsety + iy + incy * 1] = alpha * sumiy1; + y[offsety + iy + incy * 2] = alpha * sumiy2; + y[offsety + iy + incy * 3] = alpha * sumiy3; + } + } + for (; col < n; col += 1, ix += incx, iy += incy) { + double alphaxix = alpha * x[offsetx + ix]; + double sumiy = 0.0; + int row = 0, jx = incx < 0 ? (n - 1) * -incx : 0, jy = incy < 0 ? (n - 1) * -incy : 0; /* element 0 sits in the last slot when the stride is negative */ + for (; row < col; row += 1, jx += incx, jy += incy) { + y[offsety + jy] += alphaxix * a[offseta + row + col * (col + 1) / 2]; + sumiy += x[offsetx + jx] * a[offseta + row + col * (col + 1) / 2]; + } + sumiy += x[offsetx + jx] * a[offseta + row + col * (col + 1) / 2]; + if (beta != 0.0) { + y[offsety + iy] = alpha * sumiy + beta * y[offsety + iy]; + } else { + y[offsety + iy] = alpha * sumiy; + } + } + } + + /** Lower-packed kernel for y = alpha * A * x + beta * y: first rescales y by beta (skipped when beta is 1.0, zero-filled when beta is 0.0), then for each column reads entry (row, col) with row not below col from a[offseta + row + col * (2 * n - col - 1) / 2], adds alpha * x[col] * a into y[row] for the rows after the diagonal block, and adds alpha times the column's dot product into y[col]; columns are unrolled four at a time up to loopBound(n, 4). */ protected void dspmvL(int n, double alpha, double[] a, int offseta, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) { + // y = beta * y + if (beta != 1.0) { + for (int i = 0, iy = incy < 0 ? (n - 1) * -incy : 0; i < n; i += 1, iy += incy) { + if (beta != 0.0) { + y[offsety + iy] = beta * y[offsety + iy]; + } else { + y[offsety + iy] = 0.0; + } + } + } + // y += alpha * A * x + int col = 0, ix = incx < 0 ? (n - 1) * -incx : 0, iy = incy < 0 ?
(n - 1) * -incy : 0; + for (; col < loopBound(n, 4); col += 4, ix += incx * 4, iy += incy * 4) { + double alphaxix0 = alpha * x[offsetx + ix + incx * 0]; + double alphaxix1 = alpha * x[offsetx + ix + incx * 1]; + double alphaxix2 = alpha * x[offsetx + ix + incx * 2]; + double alphaxix3 = alpha * x[offsetx + ix + incx * 3]; + double sumiy0 = 0.0; + double sumiy1 = 0.0; + double sumiy2 = 0.0; + double sumiy3 = 0.0; + double a00 = a[offseta + /*row=*/(col + 0) + (col + 0) * (2 * n - (col + 0) - 1) / 2]; + double a10 = a[offseta + /*row=*/(col + 1) + (col + 0) * (2 * n - (col + 0) - 1) / 2]; + double a11 = a[offseta + /*row=*/(col + 1) + (col + 1) * (2 * n - (col + 1) - 1) / 2]; + double a20 = a[offseta + /*row=*/(col + 2) + (col + 0) * (2 * n - (col + 0) - 1) / 2]; + double a21 = a[offseta + /*row=*/(col + 2) + (col + 1) * (2 * n - (col + 1) - 1) / 2]; + double a22 = a[offseta + /*row=*/(col + 2) + (col + 2) * (2 * n - (col + 2) - 1) / 2]; + double a30 = a[offseta + /*row=*/(col + 3) + (col + 0) * (2 * n - (col + 0) - 1) / 2]; + double a31 = a[offseta + /*row=*/(col + 3) + (col + 1) * (2 * n - (col + 1) - 1) / 2]; + double a32 = a[offseta + /*row=*/(col + 3) + (col + 2) * (2 * n - (col + 2) - 1) / 2]; + double a33 = a[offseta + /*row=*/(col + 3) + (col + 3) * (2 * n - (col + 3) - 1) / 2]; + double x0 = x[offsetx + (incx < 0 ? (n - (col + 0) - 1) * -incx : (col + 0) * incx)]; + double x1 = x[offsetx + (incx < 0 ? (n - (col + 1) - 1) * -incx : (col + 1) * incx)]; + double x2 = x[offsetx + (incx < 0 ? (n - (col + 2) - 1) * -incx : (col + 2) * incx)]; + double x3 = x[offsetx + (incx < 0 ? (n - (col + 3) - 1) * -incx : (col + 3) * incx)]; + sumiy0 += x0 * a00 + + x1 * a10 + + x2 * a20 + + x3 * a30; + sumiy1 += x0 * a10 + + x1 * a11 + + x2 * a21 + + x3 * a31; + sumiy2 += x0 * a20 + + x1 * a21 + + x2 * a22 + + x3 * a32; + sumiy3 += x0 * a30 + + x1 * a31 + + x2 * a32 + + x3 * a33; + int row = col + 4, jx = incx < 0 ? 
(n - (col + 4) - 1) * -incx : (col + 4) * incx, jy = incy < 0 ? (n - (col + 4) - 1) * -incy : (col + 4) * incy; + for (; row < n; row += 1, jx += incx, jy += incy) { + double a0 = a[offseta + row + (col + 0) * (2 * n - (col + 0) - 1) / 2]; + double a1 = a[offseta + row + (col + 1) * (2 * n - (col + 1) - 1) / 2]; + double a2 = a[offseta + row + (col + 2) * (2 * n - (col + 2) - 1) / 2]; + double a3 = a[offseta + row + (col + 3) * (2 * n - (col + 3) - 1) / 2]; + y[offsety + jy] += alphaxix0 * a0 + + alphaxix1 * a1 + + alphaxix2 * a2 + + alphaxix3 * a3; + double xjx = x[offsetx + jx]; + sumiy0 += xjx * a0; + sumiy1 += xjx * a1; + sumiy2 += xjx * a2; + sumiy3 += xjx * a3; + } + y[offsety + iy + incy * 0] += alpha * sumiy0; + y[offsety + iy + incy * 1] += alpha * sumiy1; + y[offsety + iy + incy * 2] += alpha * sumiy2; + y[offsety + iy + incy * 3] += alpha * sumiy3; + } + for (; col < n; col += 1, ix += incx, iy += incy) { + double alphaxix = alpha * x[offsetx + ix]; + double sumiy = 0.0; + sumiy += x[offsetx + (incx < 0 ? (n - col - 1) * -incx : col * incx)] * a[offseta + /*row=*/col + col * (2 * n - col - 1) / 2]; + int row = col + 1, jx = incx < 0 ? (n - (col + 1) - 1) * -incx : (col + 1) * incx, jy = incy < 0 ? (n - (col + 1) - 1) * -incy : (col + 1) * incy; + for (; row < n; row += 1, jx += incx, jy += incy) { + y[offsety + jy] += alphaxix * a[offseta + row + col * (2 * n - col - 1) / 2]; + sumiy += x[offsetx + jx] * a[offseta + row + col * (2 * n - col - 1) / 2]; + } + y[offsety + iy] += alpha * sumiy; + } + } + + /** Single-precision packed symmetric matrix-vector dispatcher: when alpha is zero it only rescales the n strided entries of y by beta (storing 0.0f when beta is zero); otherwise it forwards to sspmvU when lsame("U", uplo) holds or to sspmvL when lsame("L", uplo) holds, and does nothing for any other uplo. */ protected void sspmvK(String uplo, int n, float alpha, float[] a, int offseta, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) { + if (alpha == 0.0f) { + for (int i = 0, iy = incy < 0 ?
(n - 1) * -incy : 0; i < n; i += 1, iy += incy) { + if (beta != 0.0f) { + y[offsety + iy] = beta * y[offsety + iy]; + } else { + y[offsety + iy] = 0.0f; + } + } + } else if (lsame("U", uplo)) { + sspmvU(n, alpha, a, offseta, x, offsetx, incx, beta, y, offsety, incy); + } else if (lsame("L", uplo)) { + sspmvL(n, alpha, a, offseta, x, offsetx, incx, beta, y, offsety, incy); + } + } + + /** Single-precision upper-packed kernel for y = alpha * A * x + beta * y: entry (row, col) with row not above col is read from a[offseta + row + col * (col + 1) / 2]. Columns are taken four at a time up to loopBound(n, 4), then singly; for each column the rows before it add alpha * x[col] * a into the already-finished y[row] and feed the column's own dot product, which is stored as alpha * sum plus beta * y (beta * y is dropped when beta is zero). The row cursors jx and jy start at element 0, which for a negative stride is slot (n - 1) * -inc, the same slot ix and iy use; the earlier (col - 1) * -inc start addressed the wrong entries and a negative index when col was 0. */ protected void sspmvU(int n, float alpha, float[] a, int offseta, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) { + int col = 0, ix = incx < 0 ? (n - 1) * -incx : 0, iy = incy < 0 ? (n - 1) * -incy : 0; + for (; col < loopBound(n, 4); col += 4, ix += incx * 4, iy += incy * 4) { + float alphaxix0 = alpha * x[offsetx + ix + incx * 0]; + float alphaxix1 = alpha * x[offsetx + ix + incx * 1]; + float alphaxix2 = alpha * x[offsetx + ix + incx * 2]; + float alphaxix3 = alpha * x[offsetx + ix + incx * 3]; + float sumiy0 = 0.0f; + float sumiy1 = 0.0f; + float sumiy2 = 0.0f; + float sumiy3 = 0.0f; + int row = 0, jx = incx < 0 ? (n - 1) * -incx : 0, jy = incy < 0 ?
(n - 1) * -incy : 0; /* element 0 sits in the last slot when the stride is negative */ + for (; row < col; row += 1, jx += incx, jy += incy) { + float a0 = a[offseta + row + (col + 0) * ((col + 0) + 1) / 2]; + float a1 = a[offseta + row + (col + 1) * ((col + 1) + 1) / 2]; + float a2 = a[offseta + row + (col + 2) * ((col + 2) + 1) / 2]; + float a3 = a[offseta + row + (col + 3) * ((col + 3) + 1) / 2]; + y[offsety + jy] += alphaxix0 * a0 + + alphaxix1 * a1 + + alphaxix2 * a2 + + alphaxix3 * a3; + float xjx = x[offsetx + jx]; + sumiy0 += xjx * a0; + sumiy1 += xjx * a1; + sumiy2 += xjx * a2; + sumiy3 += xjx * a3; + } + float a00 = a[offseta + (row + 0) + (col + 0) * ((col + 0) + 1) / 2]; + float a01 = a[offseta + (row + 0) + (col + 1) * ((col + 1) + 1) / 2]; + float a02 = a[offseta + (row + 0) + (col + 2) * ((col + 2) + 1) / 2]; + float a03 = a[offseta + (row + 0) + (col + 3) * ((col + 3) + 1) / 2]; + float a11 = a[offseta + (row + 1) + (col + 1) * ((col + 1) + 1) / 2]; + float a12 = a[offseta + (row + 1) + (col + 2) * ((col + 2) + 1) / 2]; + float a13 = a[offseta + (row + 1) + (col + 3) * ((col + 3) + 1) / 2]; + float a22 = a[offseta + (row + 2) + (col + 2) * ((col + 2) + 1) / 2]; + float a23 = a[offseta + (row + 2) + (col + 3) * ((col + 3) + 1) / 2]; + float a33 = a[offseta + (row + 3) + (col + 3) * ((col + 3) + 1) / 2]; + float xjx0 = x[offsetx + jx + incx * 0]; + float xjx1 = x[offsetx + jx + incx * 1]; + float xjx2 = x[offsetx + jx + incx * 2]; + float xjx3 = x[offsetx + jx + incx * 3]; + sumiy0 += xjx0 * a00 + + xjx1 * a01 + + xjx2 * a02 + + xjx3 * a03; + sumiy1 += xjx0 * a01 + + xjx1 * a11 + + xjx2 * a12 + + xjx3 * a13; + sumiy2 += xjx0 * a02 + + xjx1 * a12 + + xjx2 * a22 + + xjx3 * a23; + sumiy3 += xjx0 * a03 + + xjx1 * a13 + + xjx2 * a23 + + xjx3 * a33; + if (beta != 0.0f) { + y[offsety + iy + incy * 0] = alpha * sumiy0 + beta * y[offsety + iy + incy * 0]; + y[offsety + iy + incy * 1] = alpha * sumiy1 + beta * y[offsety + iy + incy * 1]; + y[offsety + iy + incy * 2] = alpha * sumiy2 + beta * y[offsety + iy + incy * 2]; +
y[offsety + iy + incy * 3] = alpha * sumiy3 + beta * y[offsety + iy + incy * 3]; + } else { + y[offsety + iy + incy * 0] = alpha * sumiy0; + y[offsety + iy + incy * 1] = alpha * sumiy1; + y[offsety + iy + incy * 2] = alpha * sumiy2; + y[offsety + iy + incy * 3] = alpha * sumiy3; + } + } + for (; col < n; col += 1, ix += incx, iy += incy) { + float alphaxix = alpha * x[offsetx + ix]; + float sumiy = 0.0f; + int row = 0, jx = incx < 0 ? (n - 1) * -incx : 0, jy = incy < 0 ? (n - 1) * -incy : 0; /* element 0 sits in the last slot when the stride is negative */ + for (; row < col; row += 1, jx += incx, jy += incy) { + y[offsety + jy] += alphaxix * a[offseta + row + col * (col + 1) / 2]; + sumiy += x[offsetx + jx] * a[offseta + row + col * (col + 1) / 2]; + } + sumiy += x[offsetx + jx] * a[offseta + row + col * (col + 1) / 2]; + if (beta != 0.0f) { + y[offsety + iy] = alpha * sumiy + beta * y[offsety + iy]; + } else { + y[offsety + iy] = alpha * sumiy; + } + } + } + + /** Single-precision lower-packed kernel for y = alpha * A * x + beta * y: first rescales y by beta (skipped when beta is 1.0f, zero-filled when beta is 0.0f), then for each column reads entry (row, col) with row not below col from a[offseta + row + col * (2 * n - col - 1) / 2], adds alpha * x[col] * a into y[row] for the rows after the diagonal block, and adds alpha times the column's dot product into y[col]; columns are unrolled four at a time up to loopBound(n, 4). */ protected void sspmvL(int n, float alpha, float[] a, int offseta, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) { + // y = beta * y + if (beta != 1.0f) { + for (int i = 0, iy = incy < 0 ? (n - 1) * -incy : 0; i < n; i += 1, iy += incy) { + if (beta != 0.0f) { + y[offsety + iy] = beta * y[offsety + iy]; + } else { + y[offsety + iy] = 0.0f; + } + } + } + // y += alpha * A * x + int col = 0, ix = incx < 0 ? (n - 1) * -incx : 0, iy = incy < 0 ?
(n - 1) * -incy : 0; + for (; col < loopBound(n, 4); col += 4, ix += incx * 4, iy += incy * 4) { + float alphaxix0 = alpha * x[offsetx + ix + incx * 0]; + float alphaxix1 = alpha * x[offsetx + ix + incx * 1]; + float alphaxix2 = alpha * x[offsetx + ix + incx * 2]; + float alphaxix3 = alpha * x[offsetx + ix + incx * 3]; + float sumiy0 = 0.0f; + float sumiy1 = 0.0f; + float sumiy2 = 0.0f; + float sumiy3 = 0.0f; + float a00 = a[offseta + /*row=*/(col + 0) + (col + 0) * (2 * n - (col + 0) - 1) / 2]; + float a10 = a[offseta + /*row=*/(col + 1) + (col + 0) * (2 * n - (col + 0) - 1) / 2]; + float a11 = a[offseta + /*row=*/(col + 1) + (col + 1) * (2 * n - (col + 1) - 1) / 2]; + float a20 = a[offseta + /*row=*/(col + 2) + (col + 0) * (2 * n - (col + 0) - 1) / 2]; + float a21 = a[offseta + /*row=*/(col + 2) + (col + 1) * (2 * n - (col + 1) - 1) / 2]; + float a22 = a[offseta + /*row=*/(col + 2) + (col + 2) * (2 * n - (col + 2) - 1) / 2]; + float a30 = a[offseta + /*row=*/(col + 3) + (col + 0) * (2 * n - (col + 0) - 1) / 2]; + float a31 = a[offseta + /*row=*/(col + 3) + (col + 1) * (2 * n - (col + 1) - 1) / 2]; + float a32 = a[offseta + /*row=*/(col + 3) + (col + 2) * (2 * n - (col + 2) - 1) / 2]; + float a33 = a[offseta + /*row=*/(col + 3) + (col + 3) * (2 * n - (col + 3) - 1) / 2]; + float x0 = x[offsetx + (incx < 0 ? (n - (col + 0) - 1) * -incx : (col + 0) * incx)]; + float x1 = x[offsetx + (incx < 0 ? (n - (col + 1) - 1) * -incx : (col + 1) * incx)]; + float x2 = x[offsetx + (incx < 0 ? (n - (col + 2) - 1) * -incx : (col + 2) * incx)]; + float x3 = x[offsetx + (incx < 0 ? (n - (col + 3) - 1) * -incx : (col + 3) * incx)]; + sumiy0 += x0 * a00 + + x1 * a10 + + x2 * a20 + + x3 * a30; + sumiy1 += x0 * a10 + + x1 * a11 + + x2 * a21 + + x3 * a31; + sumiy2 += x0 * a20 + + x1 * a21 + + x2 * a22 + + x3 * a32; + sumiy3 += x0 * a30 + + x1 * a31 + + x2 * a32 + + x3 * a33; + int row = col + 4, jx = incx < 0 ? (n - (col + 4) - 1) * -incx : (col + 4) * incx, jy = incy < 0 ? 
(n - (col + 4) - 1) * -incy : (col + 4) * incy; + for (; row < n; row += 1, jx += incx, jy += incy) { + float a0 = a[offseta + row + (col + 0) * (2 * n - (col + 0) - 1) / 2]; + float a1 = a[offseta + row + (col + 1) * (2 * n - (col + 1) - 1) / 2]; + float a2 = a[offseta + row + (col + 2) * (2 * n - (col + 2) - 1) / 2]; + float a3 = a[offseta + row + (col + 3) * (2 * n - (col + 3) - 1) / 2]; + y[offsety + jy] += alphaxix0 * a0 + + alphaxix1 * a1 + + alphaxix2 * a2 + + alphaxix3 * a3; + float xjx = x[offsetx + jx]; + sumiy0 += xjx * a0; + sumiy1 += xjx * a1; + sumiy2 += xjx * a2; + sumiy3 += xjx * a3; + } + y[offsety + iy + incy * 0] += alpha * sumiy0; + y[offsety + iy + incy * 1] += alpha * sumiy1; + y[offsety + iy + incy * 2] += alpha * sumiy2; + y[offsety + iy + incy * 3] += alpha * sumiy3; + } + for (; col < n; col += 1, ix += incx, iy += incy) { + float alphaxix = alpha * x[offsetx + ix]; + float sumiy = 0.0f; + sumiy += x[offsetx + (incx < 0 ? (n - col - 1) * -incx : col * incx)] * a[offseta + /*row=*/col + col * (2 * n - col - 1) / 2]; + int row = col + 1, jx = incx < 0 ? (n - (col + 1) - 1) * -incx : (col + 1) * incx, jy = incy < 0 ? 
(n - (col + 1) - 1) * -incy : (col + 1) * incy; + for (; row < n; row += 1, jx += incx, jy += incy) { + y[offsety + jy] += alphaxix * a[offseta + row + col * (2 * n - col - 1) / 2]; + sumiy += x[offsetx + jx] * a[offseta + row + col * (2 * n - col - 1) / 2]; + } + y[offsety + iy] += alpha * sumiy; + } + } + + /** Delegates to org.netlib.blas.Dspr.dspr with the arguments unchanged. */ protected void dsprK(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] a, int offseta) { + org.netlib.blas.Dspr.dspr(uplo, n, alpha, x, offsetx, incx, a, offseta); + } + + /** Delegates to org.netlib.blas.Sspr.sspr with the arguments unchanged. */ protected void ssprK(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] a, int offseta) { + org.netlib.blas.Sspr.sspr(uplo, n, alpha, x, offsetx, incx, a, offseta); + } + + /** Delegates to org.netlib.blas.Dspr2.dspr2 with the arguments unchanged. */ protected void dspr2K(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta) { + org.netlib.blas.Dspr2.dspr2(uplo, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta); + } + + /** Delegates to org.netlib.blas.Sspr2.sspr2 with the arguments unchanged. */ protected void sspr2K(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta) { + org.netlib.blas.Sspr2.sspr2(uplo, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta); + } + + /** Exchanges n strided entries of x with the matching strided entries of y in place. Unit strides take a dedicated loop; otherwise a negative stride starts at (n - 1) * -inc and walks down to 0. */ protected void dswapK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy) { + if (incx == 1 && incy == 1) { + for (int ix = 0, iy = 0; ix < n && iy < n; ix += 1, iy += 1) { + double tmp = y[offsety + iy]; + y[offsety + iy] = x[offsetx + ix]; + x[offsetx + ix] = tmp; + } + } else { + for (int ix = incx < 0 ? (n - 1) * -incx : 0, + iy = incy < 0 ? (n - 1) * -incy : 0; + (incx < 0 ? ix >= 0 : ix < n * incx) + && (incy < 0 ?
iy >= 0 : iy < n * incy); + ix += incx, iy += incy) { + double tmp = y[offsety + iy]; + y[offsety + iy] = x[offsetx + ix]; + x[offsetx + ix] = tmp; + } + } + } + + /** Single-precision counterpart of dswapK: exchanges n strided entries of x with the matching strided entries of y in place. */ protected void sswapK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) { + if (incx == 1 && incy == 1) { + for (int ix = 0, iy = 0; ix < n && iy < n; ix += 1, iy += 1) { + float tmp = y[offsety + iy]; + y[offsety + iy] = x[offsetx + ix]; + x[offsetx + ix] = tmp; + } + } else { + for (int ix = incx < 0 ? (n - 1) * -incx : 0, + iy = incy < 0 ? (n - 1) * -incy : 0; + (incx < 0 ? ix >= 0 : ix < n * incx) + && (incy < 0 ? iy >= 0 : iy < n * incy); + ix += incx, iy += incy) { + float tmp = y[offsety + iy]; + y[offsety + iy] = x[offsetx + ix]; + x[offsetx + ix] = tmp; + } + } + } + + /** Symmetric matrix-matrix dispatcher: when alpha is zero it only rescales the m-by-n block of c (entry at c[offsetc + row + col * ldc]) by beta, zero-filling it when beta is zero, with columns unrolled four at a time up to loopBound(n, 4); otherwise it picks dsymmLU, dsymmLL, dsymmRU or dsymmRL from the lsame tests on side and uplo, and does nothing when no combination matches. */ protected void dsymmK(String side, String uplo, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) { + if (alpha == 0.0) { + // C := beta*C + int col = 0; + for (; col < loopBound(n, 4); col += 4) { + int row = 0; + for (; row < m; row += 1) { + if (beta != 0.0) { + c[offsetc + row + (col + 0) * ldc] = beta * c[offsetc + row + (col + 0) * ldc]; + c[offsetc + row + (col + 1) * ldc] = beta * c[offsetc + row + (col + 1) * ldc]; + c[offsetc + row + (col + 2) * ldc] = beta * c[offsetc + row + (col + 2) * ldc]; + c[offsetc + row + (col + 3) * ldc] = beta * c[offsetc + row + (col + 3) * ldc]; + } else { + c[offsetc + row + (col + 0) * ldc] = 0.0; + c[offsetc + row + (col + 1) * ldc] = 0.0; + c[offsetc + row + (col + 2) * ldc] = 0.0; + c[offsetc + row + (col + 3) * ldc] = 0.0; + } + } + } + for (; col < n; col += 1) { + int row = 0; + for (; row < m; row += 1) { + if (beta != 0.0) { + c[offsetc + row + col * ldc] = beta * c[offsetc + row + col * ldc]; + } else { + c[offsetc + row + col * ldc] = 0.0; + } + } + } + } else if (lsame("L", side) && lsame("U", uplo)) { + dsymmLU(m, n, alpha, a, offseta, lda, b, offsetb, ldb,
beta, c, offsetc, ldc); + } else if (lsame("L", side) && lsame("L", uplo)) { + dsymmLL(m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc); + } else if (lsame("R", side) && lsame("U", uplo)) { + dsymmRU(m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc); + } else if (lsame("R", side) && lsame("L", uplo)) { + dsymmRL(m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc); + } + } + + protected void dsymmLU(int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) { + // C := alpha*A*B + beta*C + int col = 0; + for (; col < loopBound(n, 4); col += 4) { + int row = 0; + for (; row < loopBound(m, 4); row += 4) { + double sum00 = 0.0; + double sum10 = 0.0; + double sum20 = 0.0; + double sum30 = 0.0; + double sum01 = 0.0; + double sum11 = 0.0; + double sum21 = 0.0; + double sum31 = 0.0; + double sum02 = 0.0; + double sum12 = 0.0; + double sum22 = 0.0; + double sum32 = 0.0; + double sum03 = 0.0; + double sum13 = 0.0; + double sum23 = 0.0; + double sum33 = 0.0; + double alphab00 = alpha * b[offsetb + (row + 0) + (col + 0) * ldb]; + double alphab10 = alpha * b[offsetb + (row + 1) + (col + 0) * ldb]; + double alphab20 = alpha * b[offsetb + (row + 2) + (col + 0) * ldb]; + double alphab30 = alpha * b[offsetb + (row + 3) + (col + 0) * ldb]; + double alphab01 = alpha * b[offsetb + (row + 0) + (col + 1) * ldb]; + double alphab11 = alpha * b[offsetb + (row + 1) + (col + 1) * ldb]; + double alphab21 = alpha * b[offsetb + (row + 2) + (col + 1) * ldb]; + double alphab31 = alpha * b[offsetb + (row + 3) + (col + 1) * ldb]; + double alphab02 = alpha * b[offsetb + (row + 0) + (col + 2) * ldb]; + double alphab12 = alpha * b[offsetb + (row + 1) + (col + 2) * ldb]; + double alphab22 = alpha * b[offsetb + (row + 2) + (col + 2) * ldb]; + double alphab32 = alpha * b[offsetb + (row + 3) + (col + 2) * ldb]; + double alphab03 = alpha * b[offsetb + (row + 0) + 
(col + 3) * ldb]; + double alphab13 = alpha * b[offsetb + (row + 1) + (col + 3) * ldb]; + double alphab23 = alpha * b[offsetb + (row + 2) + (col + 3) * ldb]; + double alphab33 = alpha * b[offsetb + (row + 3) + (col + 3) * ldb]; + int i = 0; + for (; i < row; i += 1) { + double a0 = a[offseta + i + (row + 0) * lda]; + double a1 = a[offseta + i + (row + 1) * lda]; + double a2 = a[offseta + i + (row + 2) * lda]; + double a3 = a[offseta + i + (row + 3) * lda]; + c[offsetc + i + (col + 0) * ldc] += alphab00 * a0 + + alphab10 * a1 + + alphab20 * a2 + + alphab30 * a3; + c[offsetc + i + (col + 1) * ldc] += alphab01 * a0 + + alphab11 * a1 + + alphab21 * a2 + + alphab31 * a3; + c[offsetc + i + (col + 2) * ldc] += alphab02 * a0 + + alphab12 * a1 + + alphab22 * a2 + + alphab32 * a3; + c[offsetc + i + (col + 3) * ldc] += alphab03 * a0 + + alphab13 * a1 + + alphab23 * a2 + + alphab33 * a3; + double b0 = b[offsetb + i + (col + 0) * ldb]; + double b1 = b[offsetb + i + (col + 1) * ldb]; + double b2 = b[offsetb + i + (col + 2) * ldb]; + double b3 = b[offsetb + i + (col + 3) * ldb]; + sum00 += a0 * b0; + sum10 += a1 * b0; + sum20 += a2 * b0; + sum30 += a3 * b0; + sum01 += a0 * b1; + sum11 += a1 * b1; + sum21 += a2 * b1; + sum31 += a3 * b1; + sum02 += a0 * b2; + sum12 += a1 * b2; + sum22 += a2 * b2; + sum32 += a3 * b2; + sum03 += a0 * b3; + sum13 += a1 * b3; + sum23 += a2 * b3; + sum33 += a3 * b3; + } + double a00 = a[offseta + (i + 0) + (row + 0) * lda]; + double a01 = a[offseta + (i + 0) + (row + 1) * lda]; + double a02 = a[offseta + (i + 0) + (row + 2) * lda]; + double a03 = a[offseta + (i + 0) + (row + 3) * lda]; + double a11 = a[offseta + (i + 1) + (row + 1) * lda]; + double a12 = a[offseta + (i + 1) + (row + 2) * lda]; + double a13 = a[offseta + (i + 1) + (row + 3) * lda]; + double a22 = a[offseta + (i + 2) + (row + 2) * lda]; + double a23 = a[offseta + (i + 2) + (row + 3) * lda]; + double a33 = a[offseta + (i + 3) + (row + 3) * lda]; + double b00 = b[offsetb + (i + 0) + (col + 
0) * ldb]; + double b10 = b[offsetb + (i + 1) + (col + 0) * ldb]; + double b20 = b[offsetb + (i + 2) + (col + 0) * ldb]; + double b30 = b[offsetb + (i + 3) + (col + 0) * ldb]; + double b01 = b[offsetb + (i + 0) + (col + 1) * ldb]; + double b11 = b[offsetb + (i + 1) + (col + 1) * ldb]; + double b21 = b[offsetb + (i + 2) + (col + 1) * ldb]; + double b31 = b[offsetb + (i + 3) + (col + 1) * ldb]; + double b02 = b[offsetb + (i + 0) + (col + 2) * ldb]; + double b12 = b[offsetb + (i + 1) + (col + 2) * ldb]; + double b22 = b[offsetb + (i + 2) + (col + 2) * ldb]; + double b32 = b[offsetb + (i + 3) + (col + 2) * ldb]; + double b03 = b[offsetb + (i + 0) + (col + 3) * ldb]; + double b13 = b[offsetb + (i + 1) + (col + 3) * ldb]; + double b23 = b[offsetb + (i + 2) + (col + 3) * ldb]; + double b33 = b[offsetb + (i + 3) + (col + 3) * ldb]; + sum00 += a00 * b00 + a01 * b10 + a02 * b20 + a03 * b30; + sum10 += a01 * b00 + a11 * b10 + a12 * b20 + a13 * b30; + sum20 += a02 * b00 + a12 * b10 + a22 * b20 + a23 * b30; + sum30 += a03 * b00 + a13 * b10 + a23 * b20 + a33 * b30; + sum01 += a00 * b01 + a01 * b11 + a02 * b21 + a03 * b31; + sum11 += a01 * b01 + a11 * b11 + a12 * b21 + a13 * b31; + sum21 += a02 * b01 + a12 * b11 + a22 * b21 + a23 * b31; + sum31 += a03 * b01 + a13 * b11 + a23 * b21 + a33 * b31; + sum02 += a00 * b02 + a01 * b12 + a02 * b22 + a03 * b32; + sum12 += a01 * b02 + a11 * b12 + a12 * b22 + a13 * b32; + sum22 += a02 * b02 + a12 * b12 + a22 * b22 + a23 * b32; + sum32 += a03 * b02 + a13 * b12 + a23 * b22 + a33 * b32; + sum03 += a00 * b03 + a01 * b13 + a02 * b23 + a03 * b33; + sum13 += a01 * b03 + a11 * b13 + a12 * b23 + a13 * b33; + sum23 += a02 * b03 + a12 * b13 + a22 * b23 + a23 * b33; + sum33 += a03 * b03 + a13 * b13 + a23 * b23 + a33 * b33; + if (beta != 0.0) { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc]; + c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * 
ldc]; + c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc]; + c[offsetc + (row + 3) + (col + 0) * ldc] = alpha * sum30 + beta * c[offsetc + (row + 3) + (col + 0) * ldc]; + c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc]; + c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + beta * c[offsetc + (row + 1) + (col + 1) * ldc]; + c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + beta * c[offsetc + (row + 2) + (col + 1) * ldc]; + c[offsetc + (row + 3) + (col + 1) * ldc] = alpha * sum31 + beta * c[offsetc + (row + 3) + (col + 1) * ldc]; + c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc]; + c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + beta * c[offsetc + (row + 1) + (col + 2) * ldc]; + c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + beta * c[offsetc + (row + 2) + (col + 2) * ldc]; + c[offsetc + (row + 3) + (col + 2) * ldc] = alpha * sum32 + beta * c[offsetc + (row + 3) + (col + 2) * ldc]; + c[offsetc + (row + 0) + (col + 3) * ldc] = alpha * sum03 + beta * c[offsetc + (row + 0) + (col + 3) * ldc]; + c[offsetc + (row + 1) + (col + 3) * ldc] = alpha * sum13 + beta * c[offsetc + (row + 1) + (col + 3) * ldc]; + c[offsetc + (row + 2) + (col + 3) * ldc] = alpha * sum23 + beta * c[offsetc + (row + 2) + (col + 3) * ldc]; + c[offsetc + (row + 3) + (col + 3) * ldc] = alpha * sum33 + beta * c[offsetc + (row + 3) + (col + 3) * ldc]; + } else { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00; + c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10; + c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20; + c[offsetc + (row + 3) + (col + 0) * ldc] = alpha * sum30; + c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01; + c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11; + c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21; + c[offsetc + (row + 
3) + (col + 1) * ldc] = alpha * sum31; + c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02; + c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12; + c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22; + c[offsetc + (row + 3) + (col + 2) * ldc] = alpha * sum32; + c[offsetc + (row + 0) + (col + 3) * ldc] = alpha * sum03; + c[offsetc + (row + 1) + (col + 3) * ldc] = alpha * sum13; + c[offsetc + (row + 2) + (col + 3) * ldc] = alpha * sum23; + c[offsetc + (row + 3) + (col + 3) * ldc] = alpha * sum33; + } + } + for (; row < m; row += 1) { + double sum0 = 0.0; + double sum1 = 0.0; + double sum2 = 0.0; + double sum3 = 0.0; + double alphab0 = alpha * b[offsetb + row + (col + 0) * ldb]; + double alphab1 = alpha * b[offsetb + row + (col + 1) * ldb]; + double alphab2 = alpha * b[offsetb + row + (col + 2) * ldb]; + double alphab3 = alpha * b[offsetb + row + (col + 3) * ldb]; + int i = 0; + for (; i < row; i += 1) { + double a0 = a[offseta + i + row * lda]; + c[offsetc + i + (col + 0) * ldc] += alphab0 * a0; + c[offsetc + i + (col + 1) * ldc] += alphab1 * a0; + c[offsetc + i + (col + 2) * ldc] += alphab2 * a0; + c[offsetc + i + (col + 3) * ldc] += alphab3 * a0; + sum0 += b[offsetb + i + (col + 0) * ldb] * a0; + sum1 += b[offsetb + i + (col + 1) * ldb] * a0; + sum2 += b[offsetb + i + (col + 2) * ldb] * a0; + sum3 += b[offsetb + i + (col + 3) * ldb] * a0; + } + double a0 = a[offseta + i + row * lda]; + sum0 += b[offsetb + i + (col + 0) * ldb] * a0; + sum1 += b[offsetb + i + (col + 1) * ldb] * a0; + sum2 += b[offsetb + i + (col + 2) * ldb] * a0; + sum3 += b[offsetb + i + (col + 3) * ldb] * a0; + if (beta != 0.0) { + c[offsetc + row + (col + 0) * ldc] = alpha * sum0 + beta * c[offsetc + row + (col + 0) * ldc]; + c[offsetc + row + (col + 1) * ldc] = alpha * sum1 + beta * c[offsetc + row + (col + 1) * ldc]; + c[offsetc + row + (col + 2) * ldc] = alpha * sum2 + beta * c[offsetc + row + (col + 2) * ldc]; + c[offsetc + row + (col + 3) * ldc] = alpha * sum3 + beta * 
c[offsetc + row + (col + 3) * ldc]; + } else { + c[offsetc + row + (col + 0) * ldc] = alpha * sum0; + c[offsetc + row + (col + 1) * ldc] = alpha * sum1; + c[offsetc + row + (col + 2) * ldc] = alpha * sum2; + c[offsetc + row + (col + 3) * ldc] = alpha * sum3; + } + } + } + for (; col < n; col += 1) { + int row = 0; + for (; row < loopBound(m, 4); row += 4) { + double sum0 = 0.0; + double sum1 = 0.0; + double sum2 = 0.0; + double sum3 = 0.0; + double alphab0 = alpha * b[offsetb + (row + 0) + col * ldb]; + double alphab1 = alpha * b[offsetb + (row + 1) + col * ldb]; + double alphab2 = alpha * b[offsetb + (row + 2) + col * ldb]; + double alphab3 = alpha * b[offsetb + (row + 3) + col * ldb]; + int i = 0; + for (; i < row; i += 1) { + double a0 = a[offseta + i + (row + 0) * lda]; + double a1 = a[offseta + i + (row + 1) * lda]; + double a2 = a[offseta + i + (row + 2) * lda]; + double a3 = a[offseta + i + (row + 3) * lda]; + c[offsetc + i + col * ldc] += alphab0 * a0 + + alphab1 * a1 + + alphab2 * a2 + + alphab3 * a3; + double b0 = b[offsetb + i + col * ldb]; + sum0 += b0 * a0; + sum1 += b0 * a1; + sum2 += b0 * a2; + sum3 += b0 * a3; + } + double a00 = a[offseta + (i + 0) + (row + 0) * lda]; + double a01 = a[offseta + (i + 0) + (row + 1) * lda]; + double a02 = a[offseta + (i + 0) + (row + 2) * lda]; + double a03 = a[offseta + (i + 0) + (row + 3) * lda]; + double a11 = a[offseta + (i + 1) + (row + 1) * lda]; + double a12 = a[offseta + (i + 1) + (row + 2) * lda]; + double a13 = a[offseta + (i + 1) + (row + 3) * lda]; + double a22 = a[offseta + (i + 2) + (row + 2) * lda]; + double a23 = a[offseta + (i + 2) + (row + 3) * lda]; + double a33 = a[offseta + (i + 3) + (row + 3) * lda]; + double b0 = b[offsetb + (i + 0) + col * ldb]; + double b1 = b[offsetb + (i + 1) + col * ldb]; + double b2 = b[offsetb + (i + 2) + col * ldb]; + double b3 = b[offsetb + (i + 3) + col * ldb]; + sum0 += b0 * a00 + b1 * a01 + b2 * a02 + b3 * a03; + sum1 += b0 * a01 + b1 * a11 + b2 * a12 + b3 * a13; + 
sum2 += b0 * a02 + b1 * a12 + b2 * a22 + b3 * a23; + sum3 += b0 * a03 + b1 * a13 + b2 * a23 + b3 * a33; + if (beta != 0.0) { + c[offsetc + (row + 0) + col * ldc] = alpha * sum0 + beta * c[offsetc + (row + 0) + col * ldc]; + c[offsetc + (row + 1) + col * ldc] = alpha * sum1 + beta * c[offsetc + (row + 1) + col * ldc]; + c[offsetc + (row + 2) + col * ldc] = alpha * sum2 + beta * c[offsetc + (row + 2) + col * ldc]; + c[offsetc + (row + 3) + col * ldc] = alpha * sum3 + beta * c[offsetc + (row + 3) + col * ldc]; + } else { + c[offsetc + (row + 0) + col * ldc] = alpha * sum0; + c[offsetc + (row + 1) + col * ldc] = alpha * sum1; + c[offsetc + (row + 2) + col * ldc] = alpha * sum2; + c[offsetc + (row + 3) + col * ldc] = alpha * sum3; + } + } + for (; row < m; row += 1) { + double alphab = alpha * b[offsetb + row + col * ldb]; + double sum = 0.0; + int i = 0; + for (; i < row; i += 1) { + double aval = a[offseta + i + row * lda]; + c[offsetc + i + col * ldc] += alphab * aval; + sum += b[offsetb + i + col * ldb] * aval; + } + sum += b[offsetb + i + col * ldb] * a[offseta + i + row * lda]; + if (beta != 0.0) { + c[offsetc + row + col * ldc] = alpha * sum + beta * c[offsetc + row + col * ldc]; + } else { + c[offsetc + row + col * ldc] = alpha * sum; + } + } + } + } + + protected void dsymmLL(int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) { + final int Srow = 4; + // C := alpha*A*B + beta*C + int col = 0; + for (; col < loopBound(n, 4); col += 4) { + int row = m - 1; + for (; row >= loopBound(m - 1, Srow); row -= 1) { + double alphab0 = alpha * b[offsetb + row + (col + 0) * ldb]; + double alphab1 = alpha * b[offsetb + row + (col + 1) * ldb]; + double alphab2 = alpha * b[offsetb + row + (col + 2) * ldb]; + double alphab3 = alpha * b[offsetb + row + (col + 3) * ldb]; + double sum0 = 0.0; + double sum1 = 0.0; + double sum2 = 0.0; + double sum3 = 0.0; + sum0 += b[offsetb + row + (col + 
0) * ldb] * a[offseta + row + row * lda]; + sum1 += b[offsetb + row + (col + 1) * ldb] * a[offseta + row + row * lda]; + sum2 += b[offsetb + row + (col + 2) * ldb] * a[offseta + row + row * lda]; + sum3 += b[offsetb + row + (col + 3) * ldb] * a[offseta + row + row * lda]; + int i = row + 1; + for (; i < m; i += 1) { + double airow = a[offseta + i + row * lda]; + c[offsetc + i + (col + 0) * ldc] += alphab0 * airow; + c[offsetc + i + (col + 1) * ldc] += alphab1 * airow; + c[offsetc + i + (col + 2) * ldc] += alphab2 * airow; + c[offsetc + i + (col + 3) * ldc] += alphab3 * airow; + sum0 += b[offsetb + i + (col + 0) * ldb] * airow; + sum1 += b[offsetb + i + (col + 1) * ldb] * airow; + sum2 += b[offsetb + i + (col + 2) * ldb] * airow; + sum3 += b[offsetb + i + (col + 3) * ldb] * airow; + } + if (beta != 0.0) { + c[offsetc + row + (col + 0) * ldc] = alpha * sum0 + beta * c[offsetc + row + (col + 0) * ldc]; + c[offsetc + row + (col + 1) * ldc] = alpha * sum1 + beta * c[offsetc + row + (col + 1) * ldc]; + c[offsetc + row + (col + 2) * ldc] = alpha * sum2 + beta * c[offsetc + row + (col + 2) * ldc]; + c[offsetc + row + (col + 3) * ldc] = alpha * sum3 + beta * c[offsetc + row + (col + 3) * ldc]; + } else { + c[offsetc + row + (col + 0) * ldc] = alpha * sum0; + c[offsetc + row + (col + 1) * ldc] = alpha * sum1; + c[offsetc + row + (col + 2) * ldc] = alpha * sum2; + c[offsetc + row + (col + 3) * ldc] = alpha * sum3; + } + } + for (row -= Srow - 1; row >= 0; row -= Srow) { + double a00 = a[offseta + (row + 0) + (row + 0) * lda]; + double a10 = a[offseta + (row + 1) + (row + 0) * lda]; + double a11 = a[offseta + (row + 1) + (row + 1) * lda]; + double a20 = a[offseta + (row + 2) + (row + 0) * lda]; + double a21 = a[offseta + (row + 2) + (row + 1) * lda]; + double a22 = a[offseta + (row + 2) + (row + 2) * lda]; + double a30 = a[offseta + (row + 3) + (row + 0) * lda]; + double a31 = a[offseta + (row + 3) + (row + 1) * lda]; + double a32 = a[offseta + (row + 3) + (row + 2) * lda]; + 
double a33 = a[offseta + (row + 3) + (row + 3) * lda]; + double b00 = b[offsetb + (row + 0) + (col + 0) * ldb]; + double b10 = b[offsetb + (row + 1) + (col + 0) * ldb]; + double b20 = b[offsetb + (row + 2) + (col + 0) * ldb]; + double b30 = b[offsetb + (row + 3) + (col + 0) * ldb]; + double b01 = b[offsetb + (row + 0) + (col + 1) * ldb]; + double b11 = b[offsetb + (row + 1) + (col + 1) * ldb]; + double b21 = b[offsetb + (row + 2) + (col + 1) * ldb]; + double b31 = b[offsetb + (row + 3) + (col + 1) * ldb]; + double b02 = b[offsetb + (row + 0) + (col + 2) * ldb]; + double b12 = b[offsetb + (row + 1) + (col + 2) * ldb]; + double b22 = b[offsetb + (row + 2) + (col + 2) * ldb]; + double b32 = b[offsetb + (row + 3) + (col + 2) * ldb]; + double b03 = b[offsetb + (row + 0) + (col + 3) * ldb]; + double b13 = b[offsetb + (row + 1) + (col + 3) * ldb]; + double b23 = b[offsetb + (row + 2) + (col + 3) * ldb]; + double b33 = b[offsetb + (row + 3) + (col + 3) * ldb]; + double alphab00 = alpha * b00; + double alphab10 = alpha * b10; + double alphab20 = alpha * b20; + double alphab30 = alpha * b30; + double alphab01 = alpha * b01; + double alphab11 = alpha * b11; + double alphab21 = alpha * b21; + double alphab31 = alpha * b31; + double alphab02 = alpha * b02; + double alphab12 = alpha * b12; + double alphab22 = alpha * b22; + double alphab32 = alpha * b32; + double alphab03 = alpha * b03; + double alphab13 = alpha * b13; + double alphab23 = alpha * b23; + double alphab33 = alpha * b33; + double sum00 = 0.0; + double sum10 = 0.0; + double sum20 = 0.0; + double sum30 = 0.0; + double sum01 = 0.0; + double sum11 = 0.0; + double sum21 = 0.0; + double sum31 = 0.0; + double sum02 = 0.0; + double sum12 = 0.0; + double sum22 = 0.0; + double sum32 = 0.0; + double sum03 = 0.0; + double sum13 = 0.0; + double sum23 = 0.0; + double sum33 = 0.0; + sum00 += b00 * a00 + b10 * a10 + b20 * a20 + b30 * a30; + sum10 += b00 * a10 + b10 * a11 + b20 * a21 + b30 * a31; + sum20 += b00 * a20 + b10 * a21 + 
b20 * a22 + b30 * a32; + sum30 += b00 * a30 + b10 * a31 + b20 * a32 + b30 * a33; + sum01 += b01 * a00 + b11 * a10 + b21 * a20 + b31 * a30; + sum11 += b01 * a10 + b11 * a11 + b21 * a21 + b31 * a31; + sum21 += b01 * a20 + b11 * a21 + b21 * a22 + b31 * a32; + sum31 += b01 * a30 + b11 * a31 + b21 * a32 + b31 * a33; + sum02 += b02 * a00 + b12 * a10 + b22 * a20 + b32 * a30; + sum12 += b02 * a10 + b12 * a11 + b22 * a21 + b32 * a31; + sum22 += b02 * a20 + b12 * a21 + b22 * a22 + b32 * a32; + sum32 += b02 * a30 + b12 * a31 + b22 * a32 + b32 * a33; + sum03 += b03 * a00 + b13 * a10 + b23 * a20 + b33 * a30; + sum13 += b03 * a10 + b13 * a11 + b23 * a21 + b33 * a31; + sum23 += b03 * a20 + b13 * a21 + b23 * a22 + b33 * a32; + sum33 += b03 * a30 + b13 * a31 + b23 * a32 + b33 * a33; + int i = row + 4; + for (; i < m; i += 1) { + double a0 = a[offseta + i + (row + 0) * lda]; + double a1 = a[offseta + i + (row + 1) * lda]; + double a2 = a[offseta + i + (row + 2) * lda]; + double a3 = a[offseta + i + (row + 3) * lda]; + c[offsetc + i + (col + 0) * ldc] += alphab00 * a0 + + alphab10 * a1 + + alphab20 * a2 + + alphab30 * a3; + c[offsetc + i + (col + 1) * ldc] += alphab01 * a0 + + alphab11 * a1 + + alphab21 * a2 + + alphab31 * a3; + c[offsetc + i + (col + 2) * ldc] += alphab02 * a0 + + alphab12 * a1 + + alphab22 * a2 + + alphab32 * a3; + c[offsetc + i + (col + 3) * ldc] += alphab03 * a0 + + alphab13 * a1 + + alphab23 * a2 + + alphab33 * a3; + double b0 = b[offsetb + i + (col + 0) * ldb]; + double b1 = b[offsetb + i + (col + 1) * ldb]; + double b2 = b[offsetb + i + (col + 2) * ldb]; + double b3 = b[offsetb + i + (col + 3) * ldb]; + sum00 += b0 * a0; + sum10 += b0 * a1; + sum20 += b0 * a2; + sum30 += b0 * a3; + sum01 += b1 * a0; + sum11 += b1 * a1; + sum21 += b1 * a2; + sum31 += b1 * a3; + sum02 += b2 * a0; + sum12 += b2 * a1; + sum22 += b2 * a2; + sum32 += b2 * a3; + sum03 += b3 * a0; + sum13 += b3 * a1; + sum23 += b3 * a2; + sum33 += b3 * a3; + } + if (beta != 0.0) { + c[offsetc + (row + 
0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc]; + c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc]; + c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc]; + c[offsetc + (row + 3) + (col + 0) * ldc] = alpha * sum30 + beta * c[offsetc + (row + 3) + (col + 0) * ldc]; + c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc]; + c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + beta * c[offsetc + (row + 1) + (col + 1) * ldc]; + c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + beta * c[offsetc + (row + 2) + (col + 1) * ldc]; + c[offsetc + (row + 3) + (col + 1) * ldc] = alpha * sum31 + beta * c[offsetc + (row + 3) + (col + 1) * ldc]; + c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc]; + c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + beta * c[offsetc + (row + 1) + (col + 2) * ldc]; + c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + beta * c[offsetc + (row + 2) + (col + 2) * ldc]; + c[offsetc + (row + 3) + (col + 2) * ldc] = alpha * sum32 + beta * c[offsetc + (row + 3) + (col + 2) * ldc]; + c[offsetc + (row + 0) + (col + 3) * ldc] = alpha * sum03 + beta * c[offsetc + (row + 0) + (col + 3) * ldc]; + c[offsetc + (row + 1) + (col + 3) * ldc] = alpha * sum13 + beta * c[offsetc + (row + 1) + (col + 3) * ldc]; + c[offsetc + (row + 2) + (col + 3) * ldc] = alpha * sum23 + beta * c[offsetc + (row + 2) + (col + 3) * ldc]; + c[offsetc + (row + 3) + (col + 3) * ldc] = alpha * sum33 + beta * c[offsetc + (row + 3) + (col + 3) * ldc]; + } else { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00; + c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10; + c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20; + c[offsetc + (row + 3) + (col + 0) * ldc] = alpha * sum30; + 
c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01; + c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11; + c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21; + c[offsetc + (row + 3) + (col + 1) * ldc] = alpha * sum31; + c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02; + c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12; + c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22; + c[offsetc + (row + 3) + (col + 2) * ldc] = alpha * sum32; + c[offsetc + (row + 0) + (col + 3) * ldc] = alpha * sum03; + c[offsetc + (row + 1) + (col + 3) * ldc] = alpha * sum13; + c[offsetc + (row + 2) + (col + 3) * ldc] = alpha * sum23; + c[offsetc + (row + 3) + (col + 3) * ldc] = alpha * sum33; + } + } + } + for (; col < n; col += 1) { + int row = m - 1; + for (; row >= loopBound(m - 1, Srow); row -= 1) { + double alphab0 = alpha * b[offsetb + row + col * ldb]; + double sum0 = 0.0; + sum0 += b[offsetb + row + col * ldb] * a[offseta + row + row * lda]; + int i = row + 1; + for (; i < m; i += 1) { + double a0 = a[offseta + i + row * lda]; + c[offsetc + i + col * ldc] += alphab0 * a0; + sum0 += b[offsetb + i + col * ldb] * a0; + } + if (beta != 0.0) { + c[offsetc + row + col * ldc] = alpha * sum0 + beta * c[offsetc + row + col * ldc]; + } else { + c[offsetc + row + col * ldc] = alpha * sum0; + } + } + for (row -= Srow - 1; row >= 0; row -= Srow) { + double alphab0 = alpha * b[offsetb + (row + 0) + col * ldb]; + double alphab1 = alpha * b[offsetb + (row + 1) + col * ldb]; + double alphab2 = alpha * b[offsetb + (row + 2) + col * ldb]; + double alphab3 = alpha * b[offsetb + (row + 3) + col * ldb]; + double a00 = a[offseta + (row + 0) + (row + 0) * lda]; + double a10 = a[offseta + (row + 1) + (row + 0) * lda]; + double a11 = a[offseta + (row + 1) + (row + 1) * lda]; + double a20 = a[offseta + (row + 2) + (row + 0) * lda]; + double a21 = a[offseta + (row + 2) + (row + 1) * lda]; + double a22 = a[offseta + (row + 2) + (row + 2) * lda]; + double a30 = a[offseta 
+ (row + 3) + (row + 0) * lda]; + double a31 = a[offseta + (row + 3) + (row + 1) * lda]; + double a32 = a[offseta + (row + 3) + (row + 2) * lda]; + double a33 = a[offseta + (row + 3) + (row + 3) * lda]; + double b0 = b[offsetb + (row + 0) + col * ldb]; + double b1 = b[offsetb + (row + 1) + col * ldb]; + double b2 = b[offsetb + (row + 2) + col * ldb]; + double b3 = b[offsetb + (row + 3) + col * ldb]; + double sum0 = 0.0; + double sum1 = 0.0; + double sum2 = 0.0; + double sum3 = 0.0; + sum0 += b0 * a00 + b1 * a10 + b2 * a20 + b3 * a30; + sum1 += b0 * a10 + b1 * a11 + b2 * a21 + b3 * a31; + sum2 += b0 * a20 + b1 * a21 + b2 * a22 + b3 * a32; + sum3 += b0 * a30 + b1 * a31 + b2 * a32 + b3 * a33; + int i = row + 4; + for (; i < m; i += 1) { + double a0 = a[offseta + i + (row + 0) * lda]; + double a1 = a[offseta + i + (row + 1) * lda]; + double a2 = a[offseta + i + (row + 2) * lda]; + double a3 = a[offseta + i + (row + 3) * lda]; + c[offsetc + i + col * ldc] += alphab0 * a0 + + alphab1 * a1 + + alphab2 * a2 + + alphab3 * a3; + double bicol = b[offsetb + i + col * ldb]; + sum0 += bicol * a0; + sum1 += bicol * a1; + sum2 += bicol * a2; + sum3 += bicol * a3; + } + if (beta != 0.0) { + c[offsetc + (row + 0) + col * ldc] = alpha * sum0 + beta * c[offsetc + (row + 0) + col * ldc]; + c[offsetc + (row + 1) + col * ldc] = alpha * sum1 + beta * c[offsetc + (row + 1) + col * ldc]; + c[offsetc + (row + 2) + col * ldc] = alpha * sum2 + beta * c[offsetc + (row + 2) + col * ldc]; + c[offsetc + (row + 3) + col * ldc] = alpha * sum3 + beta * c[offsetc + (row + 3) + col * ldc]; + } else { + c[offsetc + (row + 0) + col * ldc] = alpha * sum0; + c[offsetc + (row + 1) + col * ldc] = alpha * sum1; + c[offsetc + (row + 2) + col * ldc] = alpha * sum2; + c[offsetc + (row + 3) + col * ldc] = alpha * sum3; + } + } + } + } + + protected void dsymmRU(int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) { + // C := 
alpha*B*A + beta*C + org.netlib.blas.Dsymm.dsymm("R", "U", m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc); + } + + protected void dsymmRL(int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) { + // C := alpha*B*A + beta*C + org.netlib.blas.Dsymm.dsymm("R", "L", m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc); + } + + protected void ssymmK(String side, String uplo, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) { + if (alpha == 0.0f) { + // C := beta*C + int col = 0; + for (; col < loopBound(n, 4); col += 4) { + int row = 0; + for (; row < m; row += 1) { + c[offsetc + row + (col + 0) * ldc] = beta * c[offsetc + row + (col + 0) * ldc]; + c[offsetc + row + (col + 1) * ldc] = beta * c[offsetc + row + (col + 1) * ldc]; + c[offsetc + row + (col + 2) * ldc] = beta * c[offsetc + row + (col + 2) * ldc]; + c[offsetc + row + (col + 3) * ldc] = beta * c[offsetc + row + (col + 3) * ldc]; + } + } + for (; col < n; col += 1) { + int row = 0; + for (; row < m; row += 1) { + c[offsetc + row + col * ldc] = beta * c[offsetc + row + col * ldc]; + } + } + } else if (lsame("L", side) && lsame("U", uplo)) { + ssymmLU(m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc); + } else if (lsame("L", side) && lsame("L", uplo)) { + ssymmLL(m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc); + } else if (lsame("R", side) && lsame("U", uplo)) { + ssymmRU(m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc); + } else if (lsame("R", side) && lsame("L", uplo)) { + ssymmRL(m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc); + } + } + + protected void ssymmLU(int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) { + // C := 
alpha*A*B + beta*C + int col = 0; + for (; col < loopBound(n, 4); col += 4) { + int row = 0; + for (; row < loopBound(m, 4); row += 4) { + float sum00 = 0.0f; + float sum10 = 0.0f; + float sum20 = 0.0f; + float sum30 = 0.0f; + float sum01 = 0.0f; + float sum11 = 0.0f; + float sum21 = 0.0f; + float sum31 = 0.0f; + float sum02 = 0.0f; + float sum12 = 0.0f; + float sum22 = 0.0f; + float sum32 = 0.0f; + float sum03 = 0.0f; + float sum13 = 0.0f; + float sum23 = 0.0f; + float sum33 = 0.0f; + float alphab00 = alpha * b[offsetb + (row + 0) + (col + 0) * ldb]; + float alphab10 = alpha * b[offsetb + (row + 1) + (col + 0) * ldb]; + float alphab20 = alpha * b[offsetb + (row + 2) + (col + 0) * ldb]; + float alphab30 = alpha * b[offsetb + (row + 3) + (col + 0) * ldb]; + float alphab01 = alpha * b[offsetb + (row + 0) + (col + 1) * ldb]; + float alphab11 = alpha * b[offsetb + (row + 1) + (col + 1) * ldb]; + float alphab21 = alpha * b[offsetb + (row + 2) + (col + 1) * ldb]; + float alphab31 = alpha * b[offsetb + (row + 3) + (col + 1) * ldb]; + float alphab02 = alpha * b[offsetb + (row + 0) + (col + 2) * ldb]; + float alphab12 = alpha * b[offsetb + (row + 1) + (col + 2) * ldb]; + float alphab22 = alpha * b[offsetb + (row + 2) + (col + 2) * ldb]; + float alphab32 = alpha * b[offsetb + (row + 3) + (col + 2) * ldb]; + float alphab03 = alpha * b[offsetb + (row + 0) + (col + 3) * ldb]; + float alphab13 = alpha * b[offsetb + (row + 1) + (col + 3) * ldb]; + float alphab23 = alpha * b[offsetb + (row + 2) + (col + 3) * ldb]; + float alphab33 = alpha * b[offsetb + (row + 3) + (col + 3) * ldb]; + int i = 0; + for (; i < row; i += 1) { + float a0 = a[offseta + i + (row + 0) * lda]; + float a1 = a[offseta + i + (row + 1) * lda]; + float a2 = a[offseta + i + (row + 2) * lda]; + float a3 = a[offseta + i + (row + 3) * lda]; + c[offsetc + i + (col + 0) * ldc] += alphab00 * a0 + + alphab10 * a1 + + alphab20 * a2 + + alphab30 * a3; + c[offsetc + i + (col + 1) * ldc] += alphab01 * a0 + + alphab11 * a1 
+ + alphab21 * a2 + + alphab31 * a3; + c[offsetc + i + (col + 2) * ldc] += alphab02 * a0 + + alphab12 * a1 + + alphab22 * a2 + + alphab32 * a3; + c[offsetc + i + (col + 3) * ldc] += alphab03 * a0 + + alphab13 * a1 + + alphab23 * a2 + + alphab33 * a3; + float b0 = b[offsetb + i + (col + 0) * ldb]; + float b1 = b[offsetb + i + (col + 1) * ldb]; + float b2 = b[offsetb + i + (col + 2) * ldb]; + float b3 = b[offsetb + i + (col + 3) * ldb]; + sum00 += a0 * b0; + sum10 += a1 * b0; + sum20 += a2 * b0; + sum30 += a3 * b0; + sum01 += a0 * b1; + sum11 += a1 * b1; + sum21 += a2 * b1; + sum31 += a3 * b1; + sum02 += a0 * b2; + sum12 += a1 * b2; + sum22 += a2 * b2; + sum32 += a3 * b2; + sum03 += a0 * b3; + sum13 += a1 * b3; + sum23 += a2 * b3; + sum33 += a3 * b3; + } + float a00 = a[offseta + (i + 0) + (row + 0) * lda]; + float a01 = a[offseta + (i + 0) + (row + 1) * lda]; + float a02 = a[offseta + (i + 0) + (row + 2) * lda]; + float a03 = a[offseta + (i + 0) + (row + 3) * lda]; + float a11 = a[offseta + (i + 1) + (row + 1) * lda]; + float a12 = a[offseta + (i + 1) + (row + 2) * lda]; + float a13 = a[offseta + (i + 1) + (row + 3) * lda]; + float a22 = a[offseta + (i + 2) + (row + 2) * lda]; + float a23 = a[offseta + (i + 2) + (row + 3) * lda]; + float a33 = a[offseta + (i + 3) + (row + 3) * lda]; + float b00 = b[offsetb + (i + 0) + (col + 0) * ldb]; + float b10 = b[offsetb + (i + 1) + (col + 0) * ldb]; + float b20 = b[offsetb + (i + 2) + (col + 0) * ldb]; + float b30 = b[offsetb + (i + 3) + (col + 0) * ldb]; + float b01 = b[offsetb + (i + 0) + (col + 1) * ldb]; + float b11 = b[offsetb + (i + 1) + (col + 1) * ldb]; + float b21 = b[offsetb + (i + 2) + (col + 1) * ldb]; + float b31 = b[offsetb + (i + 3) + (col + 1) * ldb]; + float b02 = b[offsetb + (i + 0) + (col + 2) * ldb]; + float b12 = b[offsetb + (i + 1) + (col + 2) * ldb]; + float b22 = b[offsetb + (i + 2) + (col + 2) * ldb]; + float b32 = b[offsetb + (i + 3) + (col + 2) * ldb]; + float b03 = b[offsetb + (i + 0) + (col + 3) * 
ldb]; + float b13 = b[offsetb + (i + 1) + (col + 3) * ldb]; + float b23 = b[offsetb + (i + 2) + (col + 3) * ldb]; + float b33 = b[offsetb + (i + 3) + (col + 3) * ldb]; + sum00 += a00 * b00 + a01 * b10 + a02 * b20 + a03 * b30; + sum10 += a01 * b00 + a11 * b10 + a12 * b20 + a13 * b30; + sum20 += a02 * b00 + a12 * b10 + a22 * b20 + a23 * b30; + sum30 += a03 * b00 + a13 * b10 + a23 * b20 + a33 * b30; + sum01 += a00 * b01 + a01 * b11 + a02 * b21 + a03 * b31; + sum11 += a01 * b01 + a11 * b11 + a12 * b21 + a13 * b31; + sum21 += a02 * b01 + a12 * b11 + a22 * b21 + a23 * b31; + sum31 += a03 * b01 + a13 * b11 + a23 * b21 + a33 * b31; + sum02 += a00 * b02 + a01 * b12 + a02 * b22 + a03 * b32; + sum12 += a01 * b02 + a11 * b12 + a12 * b22 + a13 * b32; + sum22 += a02 * b02 + a12 * b12 + a22 * b22 + a23 * b32; + sum32 += a03 * b02 + a13 * b12 + a23 * b22 + a33 * b32; + sum03 += a00 * b03 + a01 * b13 + a02 * b23 + a03 * b33; + sum13 += a01 * b03 + a11 * b13 + a12 * b23 + a13 * b33; + sum23 += a02 * b03 + a12 * b13 + a22 * b23 + a23 * b33; + sum33 += a03 * b03 + a13 * b13 + a23 * b23 + a33 * b33; + if (beta != 0.0f) { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc]; + c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc]; + c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc]; + c[offsetc + (row + 3) + (col + 0) * ldc] = alpha * sum30 + beta * c[offsetc + (row + 3) + (col + 0) * ldc]; + c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc]; + c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + beta * c[offsetc + (row + 1) + (col + 1) * ldc]; + c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + beta * c[offsetc + (row + 2) + (col + 1) * ldc]; + c[offsetc + (row + 3) + (col + 1) * ldc] = alpha * sum31 + beta * c[offsetc + (row + 3) + (col + 1) * ldc]; + 
c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc]; + c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + beta * c[offsetc + (row + 1) + (col + 2) * ldc]; + c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + beta * c[offsetc + (row + 2) + (col + 2) * ldc]; + c[offsetc + (row + 3) + (col + 2) * ldc] = alpha * sum32 + beta * c[offsetc + (row + 3) + (col + 2) * ldc]; + c[offsetc + (row + 0) + (col + 3) * ldc] = alpha * sum03 + beta * c[offsetc + (row + 0) + (col + 3) * ldc]; + c[offsetc + (row + 1) + (col + 3) * ldc] = alpha * sum13 + beta * c[offsetc + (row + 1) + (col + 3) * ldc]; + c[offsetc + (row + 2) + (col + 3) * ldc] = alpha * sum23 + beta * c[offsetc + (row + 2) + (col + 3) * ldc]; + c[offsetc + (row + 3) + (col + 3) * ldc] = alpha * sum33 + beta * c[offsetc + (row + 3) + (col + 3) * ldc]; + } else { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00; + c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10; + c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20; + c[offsetc + (row + 3) + (col + 0) * ldc] = alpha * sum30; + c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01; + c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11; + c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21; + c[offsetc + (row + 3) + (col + 1) * ldc] = alpha * sum31; + c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02; + c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12; + c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22; + c[offsetc + (row + 3) + (col + 2) * ldc] = alpha * sum32; + c[offsetc + (row + 0) + (col + 3) * ldc] = alpha * sum03; + c[offsetc + (row + 1) + (col + 3) * ldc] = alpha * sum13; + c[offsetc + (row + 2) + (col + 3) * ldc] = alpha * sum23; + c[offsetc + (row + 3) + (col + 3) * ldc] = alpha * sum33; + } + } + for (; row < m; row += 1) { + float sum0 = 0.0f; + float sum1 = 0.0f; + float sum2 = 0.0f; + float sum3 = 0.0f; + float alphab0 = alpha * 
b[offsetb + row + (col + 0) * ldb]; + float alphab1 = alpha * b[offsetb + row + (col + 1) * ldb]; + float alphab2 = alpha * b[offsetb + row + (col + 2) * ldb]; + float alphab3 = alpha * b[offsetb + row + (col + 3) * ldb]; + int i = 0; + for (; i < row; i += 1) { + float a0 = a[offseta + i + row * lda]; + c[offsetc + i + (col + 0) * ldc] += alphab0 * a0; + c[offsetc + i + (col + 1) * ldc] += alphab1 * a0; + c[offsetc + i + (col + 2) * ldc] += alphab2 * a0; + c[offsetc + i + (col + 3) * ldc] += alphab3 * a0; + sum0 += b[offsetb + i + (col + 0) * ldb] * a0; + sum1 += b[offsetb + i + (col + 1) * ldb] * a0; + sum2 += b[offsetb + i + (col + 2) * ldb] * a0; + sum3 += b[offsetb + i + (col + 3) * ldb] * a0; + } + float a0 = a[offseta + i + row * lda]; + sum0 += b[offsetb + i + (col + 0) * ldb] * a0; + sum1 += b[offsetb + i + (col + 1) * ldb] * a0; + sum2 += b[offsetb + i + (col + 2) * ldb] * a0; + sum3 += b[offsetb + i + (col + 3) * ldb] * a0; + if (beta != 0.0f) { + c[offsetc + row + (col + 0) * ldc] = alpha * sum0 + beta * c[offsetc + row + (col + 0) * ldc]; + c[offsetc + row + (col + 1) * ldc] = alpha * sum1 + beta * c[offsetc + row + (col + 1) * ldc]; + c[offsetc + row + (col + 2) * ldc] = alpha * sum2 + beta * c[offsetc + row + (col + 2) * ldc]; + c[offsetc + row + (col + 3) * ldc] = alpha * sum3 + beta * c[offsetc + row + (col + 3) * ldc]; + } else { + c[offsetc + row + (col + 0) * ldc] = alpha * sum0; + c[offsetc + row + (col + 1) * ldc] = alpha * sum1; + c[offsetc + row + (col + 2) * ldc] = alpha * sum2; + c[offsetc + row + (col + 3) * ldc] = alpha * sum3; + } + } + } + for (; col < n; col += 1) { + int row = 0; + for (; row < loopBound(m, 4); row += 4) { + float sum0 = 0.0f; + float sum1 = 0.0f; + float sum2 = 0.0f; + float sum3 = 0.0f; + float alphab0 = alpha * b[offsetb + (row + 0) + col * ldb]; + float alphab1 = alpha * b[offsetb + (row + 1) + col * ldb]; + float alphab2 = alpha * b[offsetb + (row + 2) + col * ldb]; + float alphab3 = alpha * b[offsetb + (row + 
3) + col * ldb]; + int i = 0; + for (; i < row; i += 1) { + float a0 = a[offseta + i + (row + 0) * lda]; + float a1 = a[offseta + i + (row + 1) * lda]; + float a2 = a[offseta + i + (row + 2) * lda]; + float a3 = a[offseta + i + (row + 3) * lda]; + c[offsetc + i + col * ldc] += alphab0 * a0 + + alphab1 * a1 + + alphab2 * a2 + + alphab3 * a3; + float b0 = b[offsetb + i + col * ldb]; + sum0 += b0 * a0; + sum1 += b0 * a1; + sum2 += b0 * a2; + sum3 += b0 * a3; + } + float a00 = a[offseta + (i + 0) + (row + 0) * lda]; + float a01 = a[offseta + (i + 0) + (row + 1) * lda]; + float a02 = a[offseta + (i + 0) + (row + 2) * lda]; + float a03 = a[offseta + (i + 0) + (row + 3) * lda]; + float a11 = a[offseta + (i + 1) + (row + 1) * lda]; + float a12 = a[offseta + (i + 1) + (row + 2) * lda]; + float a13 = a[offseta + (i + 1) + (row + 3) * lda]; + float a22 = a[offseta + (i + 2) + (row + 2) * lda]; + float a23 = a[offseta + (i + 2) + (row + 3) * lda]; + float a33 = a[offseta + (i + 3) + (row + 3) * lda]; + float b0 = b[offsetb + (i + 0) + col * ldb]; + float b1 = b[offsetb + (i + 1) + col * ldb]; + float b2 = b[offsetb + (i + 2) + col * ldb]; + float b3 = b[offsetb + (i + 3) + col * ldb]; + sum0 += b0 * a00 + b1 * a01 + b2 * a02 + b3 * a03; + sum1 += b0 * a01 + b1 * a11 + b2 * a12 + b3 * a13; + sum2 += b0 * a02 + b1 * a12 + b2 * a22 + b3 * a23; + sum3 += b0 * a03 + b1 * a13 + b2 * a23 + b3 * a33; + if (beta != 0.0f) { + c[offsetc + (row + 0) + col * ldc] = alpha * sum0 + beta * c[offsetc + (row + 0) + col * ldc]; + c[offsetc + (row + 1) + col * ldc] = alpha * sum1 + beta * c[offsetc + (row + 1) + col * ldc]; + c[offsetc + (row + 2) + col * ldc] = alpha * sum2 + beta * c[offsetc + (row + 2) + col * ldc]; + c[offsetc + (row + 3) + col * ldc] = alpha * sum3 + beta * c[offsetc + (row + 3) + col * ldc]; + } else { + c[offsetc + (row + 0) + col * ldc] = alpha * sum0; + c[offsetc + (row + 1) + col * ldc] = alpha * sum1; + c[offsetc + (row + 2) + col * ldc] = alpha * sum2; + c[offsetc + 
(row + 3) + col * ldc] = alpha * sum3; + } + } + for (; row < m; row += 1) { + float alphab = alpha * b[offsetb + row + col * ldb]; + float sum = 0.0f; + int i = 0; + for (; i < row; i += 1) { + float aval = a[offseta + i + row * lda]; + c[offsetc + i + col * ldc] += alphab * aval; + sum += b[offsetb + i + col * ldb] * aval; + } + sum += b[offsetb + i + col * ldb] * a[offseta + i + row * lda]; + if (beta != 0.0f) { + c[offsetc + row + col * ldc] = alpha * sum + beta * c[offsetc + row + col * ldc]; + } else { + c[offsetc + row + col * ldc] = alpha * sum; + } + } + } + } + + protected void ssymmLL(int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) { + final int Srow = 4; + // C := alpha*A*B + beta*C + int col = 0; + for (; col < loopBound(n, 4); col += 4) { + int row = m - 1; + for (; row >= loopBound(m - 1, Srow); row -= 1) { + float alphab0 = alpha * b[offsetb + row + (col + 0) * ldb]; + float alphab1 = alpha * b[offsetb + row + (col + 1) * ldb]; + float alphab2 = alpha * b[offsetb + row + (col + 2) * ldb]; + float alphab3 = alpha * b[offsetb + row + (col + 3) * ldb]; + float sum0 = 0.0f; + float sum1 = 0.0f; + float sum2 = 0.0f; + float sum3 = 0.0f; + sum0 += b[offsetb + row + (col + 0) * ldb] * a[offseta + row + row * lda]; + sum1 += b[offsetb + row + (col + 1) * ldb] * a[offseta + row + row * lda]; + sum2 += b[offsetb + row + (col + 2) * ldb] * a[offseta + row + row * lda]; + sum3 += b[offsetb + row + (col + 3) * ldb] * a[offseta + row + row * lda]; + int i = row + 1; + for (; i < m; i += 1) { + float airow = a[offseta + i + row * lda]; + c[offsetc + i + (col + 0) * ldc] += alphab0 * airow; + c[offsetc + i + (col + 1) * ldc] += alphab1 * airow; + c[offsetc + i + (col + 2) * ldc] += alphab2 * airow; + c[offsetc + i + (col + 3) * ldc] += alphab3 * airow; + sum0 += b[offsetb + i + (col + 0) * ldb] * airow; + sum1 += b[offsetb + i + (col + 1) * ldb] * airow; + sum2 += b[offsetb + 
i + (col + 2) * ldb] * airow; + sum3 += b[offsetb + i + (col + 3) * ldb] * airow; + } + if (beta != 0.0f) { + c[offsetc + row + (col + 0) * ldc] = alpha * sum0 + beta * c[offsetc + row + (col + 0) * ldc]; + c[offsetc + row + (col + 1) * ldc] = alpha * sum1 + beta * c[offsetc + row + (col + 1) * ldc]; + c[offsetc + row + (col + 2) * ldc] = alpha * sum2 + beta * c[offsetc + row + (col + 2) * ldc]; + c[offsetc + row + (col + 3) * ldc] = alpha * sum3 + beta * c[offsetc + row + (col + 3) * ldc]; + } else { + c[offsetc + row + (col + 0) * ldc] = alpha * sum0; + c[offsetc + row + (col + 1) * ldc] = alpha * sum1; + c[offsetc + row + (col + 2) * ldc] = alpha * sum2; + c[offsetc + row + (col + 3) * ldc] = alpha * sum3; + } + } + for (row -= Srow - 1; row >= 0; row -= Srow) { + float a00 = a[offseta + (row + 0) + (row + 0) * lda]; + float a10 = a[offseta + (row + 1) + (row + 0) * lda]; + float a11 = a[offseta + (row + 1) + (row + 1) * lda]; + float a20 = a[offseta + (row + 2) + (row + 0) * lda]; + float a21 = a[offseta + (row + 2) + (row + 1) * lda]; + float a22 = a[offseta + (row + 2) + (row + 2) * lda]; + float a30 = a[offseta + (row + 3) + (row + 0) * lda]; + float a31 = a[offseta + (row + 3) + (row + 1) * lda]; + float a32 = a[offseta + (row + 3) + (row + 2) * lda]; + float a33 = a[offseta + (row + 3) + (row + 3) * lda]; + float b00 = b[offsetb + (row + 0) + (col + 0) * ldb]; + float b10 = b[offsetb + (row + 1) + (col + 0) * ldb]; + float b20 = b[offsetb + (row + 2) + (col + 0) * ldb]; + float b30 = b[offsetb + (row + 3) + (col + 0) * ldb]; + float b01 = b[offsetb + (row + 0) + (col + 1) * ldb]; + float b11 = b[offsetb + (row + 1) + (col + 1) * ldb]; + float b21 = b[offsetb + (row + 2) + (col + 1) * ldb]; + float b31 = b[offsetb + (row + 3) + (col + 1) * ldb]; + float b02 = b[offsetb + (row + 0) + (col + 2) * ldb]; + float b12 = b[offsetb + (row + 1) + (col + 2) * ldb]; + float b22 = b[offsetb + (row + 2) + (col + 2) * ldb]; + float b32 = b[offsetb + (row + 3) + (col + 2) 
* ldb]; + float b03 = b[offsetb + (row + 0) + (col + 3) * ldb]; + float b13 = b[offsetb + (row + 1) + (col + 3) * ldb]; + float b23 = b[offsetb + (row + 2) + (col + 3) * ldb]; + float b33 = b[offsetb + (row + 3) + (col + 3) * ldb]; + float alphab00 = alpha * b00; + float alphab10 = alpha * b10; + float alphab20 = alpha * b20; + float alphab30 = alpha * b30; + float alphab01 = alpha * b01; + float alphab11 = alpha * b11; + float alphab21 = alpha * b21; + float alphab31 = alpha * b31; + float alphab02 = alpha * b02; + float alphab12 = alpha * b12; + float alphab22 = alpha * b22; + float alphab32 = alpha * b32; + float alphab03 = alpha * b03; + float alphab13 = alpha * b13; + float alphab23 = alpha * b23; + float alphab33 = alpha * b33; + float sum00 = 0.0f; + float sum10 = 0.0f; + float sum20 = 0.0f; + float sum30 = 0.0f; + float sum01 = 0.0f; + float sum11 = 0.0f; + float sum21 = 0.0f; + float sum31 = 0.0f; + float sum02 = 0.0f; + float sum12 = 0.0f; + float sum22 = 0.0f; + float sum32 = 0.0f; + float sum03 = 0.0f; + float sum13 = 0.0f; + float sum23 = 0.0f; + float sum33 = 0.0f; + sum00 += b00 * a00 + b10 * a10 + b20 * a20 + b30 * a30; + sum10 += b00 * a10 + b10 * a11 + b20 * a21 + b30 * a31; + sum20 += b00 * a20 + b10 * a21 + b20 * a22 + b30 * a32; + sum30 += b00 * a30 + b10 * a31 + b20 * a32 + b30 * a33; + sum01 += b01 * a00 + b11 * a10 + b21 * a20 + b31 * a30; + sum11 += b01 * a10 + b11 * a11 + b21 * a21 + b31 * a31; + sum21 += b01 * a20 + b11 * a21 + b21 * a22 + b31 * a32; + sum31 += b01 * a30 + b11 * a31 + b21 * a32 + b31 * a33; + sum02 += b02 * a00 + b12 * a10 + b22 * a20 + b32 * a30; + sum12 += b02 * a10 + b12 * a11 + b22 * a21 + b32 * a31; + sum22 += b02 * a20 + b12 * a21 + b22 * a22 + b32 * a32; + sum32 += b02 * a30 + b12 * a31 + b22 * a32 + b32 * a33; + sum03 += b03 * a00 + b13 * a10 + b23 * a20 + b33 * a30; + sum13 += b03 * a10 + b13 * a11 + b23 * a21 + b33 * a31; + sum23 += b03 * a20 + b13 * a21 + b23 * a22 + b33 * a32; + sum33 += b03 * a30 + b13 * a31 
+ b23 * a32 + b33 * a33; + int i = row + 4; + for (; i < m; i += 1) { + float a0 = a[offseta + i + (row + 0) * lda]; + float a1 = a[offseta + i + (row + 1) * lda]; + float a2 = a[offseta + i + (row + 2) * lda]; + float a3 = a[offseta + i + (row + 3) * lda]; + c[offsetc + i + (col + 0) * ldc] += alphab00 * a0 + + alphab10 * a1 + + alphab20 * a2 + + alphab30 * a3; + c[offsetc + i + (col + 1) * ldc] += alphab01 * a0 + + alphab11 * a1 + + alphab21 * a2 + + alphab31 * a3; + c[offsetc + i + (col + 2) * ldc] += alphab02 * a0 + + alphab12 * a1 + + alphab22 * a2 + + alphab32 * a3; + c[offsetc + i + (col + 3) * ldc] += alphab03 * a0 + + alphab13 * a1 + + alphab23 * a2 + + alphab33 * a3; + float b0 = b[offsetb + i + (col + 0) * ldb]; + float b1 = b[offsetb + i + (col + 1) * ldb]; + float b2 = b[offsetb + i + (col + 2) * ldb]; + float b3 = b[offsetb + i + (col + 3) * ldb]; + sum00 += b0 * a0; + sum10 += b0 * a1; + sum20 += b0 * a2; + sum30 += b0 * a3; + sum01 += b1 * a0; + sum11 += b1 * a1; + sum21 += b1 * a2; + sum31 += b1 * a3; + sum02 += b2 * a0; + sum12 += b2 * a1; + sum22 += b2 * a2; + sum32 += b2 * a3; + sum03 += b3 * a0; + sum13 += b3 * a1; + sum23 += b3 * a2; + sum33 += b3 * a3; + } + if (beta != 0.0f) { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc]; + c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc]; + c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc]; + c[offsetc + (row + 3) + (col + 0) * ldc] = alpha * sum30 + beta * c[offsetc + (row + 3) + (col + 0) * ldc]; + c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc]; + c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + beta * c[offsetc + (row + 1) + (col + 1) * ldc]; + c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + beta * c[offsetc + (row + 2) + (col + 1) * ldc]; + 
c[offsetc + (row + 3) + (col + 1) * ldc] = alpha * sum31 + beta * c[offsetc + (row + 3) + (col + 1) * ldc]; + c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc]; + c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + beta * c[offsetc + (row + 1) + (col + 2) * ldc]; + c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + beta * c[offsetc + (row + 2) + (col + 2) * ldc]; + c[offsetc + (row + 3) + (col + 2) * ldc] = alpha * sum32 + beta * c[offsetc + (row + 3) + (col + 2) * ldc]; + c[offsetc + (row + 0) + (col + 3) * ldc] = alpha * sum03 + beta * c[offsetc + (row + 0) + (col + 3) * ldc]; + c[offsetc + (row + 1) + (col + 3) * ldc] = alpha * sum13 + beta * c[offsetc + (row + 1) + (col + 3) * ldc]; + c[offsetc + (row + 2) + (col + 3) * ldc] = alpha * sum23 + beta * c[offsetc + (row + 2) + (col + 3) * ldc]; + c[offsetc + (row + 3) + (col + 3) * ldc] = alpha * sum33 + beta * c[offsetc + (row + 3) + (col + 3) * ldc]; + } else { + c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00; + c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10; + c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20; + c[offsetc + (row + 3) + (col + 0) * ldc] = alpha * sum30; + c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01; + c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11; + c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21; + c[offsetc + (row + 3) + (col + 1) * ldc] = alpha * sum31; + c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02; + c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12; + c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22; + c[offsetc + (row + 3) + (col + 2) * ldc] = alpha * sum32; + c[offsetc + (row + 0) + (col + 3) * ldc] = alpha * sum03; + c[offsetc + (row + 1) + (col + 3) * ldc] = alpha * sum13; + c[offsetc + (row + 2) + (col + 3) * ldc] = alpha * sum23; + c[offsetc + (row + 3) + (col + 3) * ldc] = alpha * sum33; + } + } + } + for (; col < n; col += 1) 
{ + int row = m - 1; + for (; row >= loopBound(m - 1, Srow); row -= 1) { + float alphab0 = alpha * b[offsetb + row + col * ldb]; + float sum0 = 0.0f; + sum0 += b[offsetb + row + col * ldb] * a[offseta + row + row * lda]; + int i = row + 1; + for (; i < m; i += 1) { + float a0 = a[offseta + i + row * lda]; + c[offsetc + i + col * ldc] += alphab0 * a0; + sum0 += b[offsetb + i + col * ldb] * a0; + } + if (beta != 0.0f) { + c[offsetc + row + col * ldc] = alpha * sum0 + beta * c[offsetc + row + col * ldc]; + } else { + c[offsetc + row + col * ldc] = alpha * sum0; + } + } + for (row -= Srow - 1; row >= 0; row -= Srow) { + float alphab0 = alpha * b[offsetb + (row + 0) + col * ldb]; + float alphab1 = alpha * b[offsetb + (row + 1) + col * ldb]; + float alphab2 = alpha * b[offsetb + (row + 2) + col * ldb]; + float alphab3 = alpha * b[offsetb + (row + 3) + col * ldb]; + float a00 = a[offseta + (row + 0) + (row + 0) * lda]; + float a10 = a[offseta + (row + 1) + (row + 0) * lda]; + float a11 = a[offseta + (row + 1) + (row + 1) * lda]; + float a20 = a[offseta + (row + 2) + (row + 0) * lda]; + float a21 = a[offseta + (row + 2) + (row + 1) * lda]; + float a22 = a[offseta + (row + 2) + (row + 2) * lda]; + float a30 = a[offseta + (row + 3) + (row + 0) * lda]; + float a31 = a[offseta + (row + 3) + (row + 1) * lda]; + float a32 = a[offseta + (row + 3) + (row + 2) * lda]; + float a33 = a[offseta + (row + 3) + (row + 3) * lda]; + float b0 = b[offsetb + (row + 0) + col * ldb]; + float b1 = b[offsetb + (row + 1) + col * ldb]; + float b2 = b[offsetb + (row + 2) + col * ldb]; + float b3 = b[offsetb + (row + 3) + col * ldb]; + float sum0 = 0.0f; + float sum1 = 0.0f; + float sum2 = 0.0f; + float sum3 = 0.0f; + sum0 += b0 * a00 + b1 * a10 + b2 * a20 + b3 * a30; + sum1 += b0 * a10 + b1 * a11 + b2 * a21 + b3 * a31; + sum2 += b0 * a20 + b1 * a21 + b2 * a22 + b3 * a32; + sum3 += b0 * a30 + b1 * a31 + b2 * a32 + b3 * a33; + int i = row + 4; + for (; i < m; i += 1) { + float a0 = a[offseta + i + 
(row + 0) * lda]; + float a1 = a[offseta + i + (row + 1) * lda]; + float a2 = a[offseta + i + (row + 2) * lda]; + float a3 = a[offseta + i + (row + 3) * lda]; + c[offsetc + i + col * ldc] += alphab0 * a0 + + alphab1 * a1 + + alphab2 * a2 + + alphab3 * a3; + float bicol = b[offsetb + i + col * ldb]; + sum0 += bicol * a0; + sum1 += bicol * a1; + sum2 += bicol * a2; + sum3 += bicol * a3; + } + if (beta != 0.0f) { + c[offsetc + (row + 0) + col * ldc] = alpha * sum0 + beta * c[offsetc + (row + 0) + col * ldc]; + c[offsetc + (row + 1) + col * ldc] = alpha * sum1 + beta * c[offsetc + (row + 1) + col * ldc]; + c[offsetc + (row + 2) + col * ldc] = alpha * sum2 + beta * c[offsetc + (row + 2) + col * ldc]; + c[offsetc + (row + 3) + col * ldc] = alpha * sum3 + beta * c[offsetc + (row + 3) + col * ldc]; + } else { + c[offsetc + (row + 0) + col * ldc] = alpha * sum0; + c[offsetc + (row + 1) + col * ldc] = alpha * sum1; + c[offsetc + (row + 2) + col * ldc] = alpha * sum2; + c[offsetc + (row + 3) + col * ldc] = alpha * sum3; + } + } + } + } + + protected void ssymmRU(int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) { + // C := alpha*B*A + beta*C + org.netlib.blas.Ssymm.ssymm("R", "U", m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc); + } + + protected void ssymmRL(int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) { + // C := alpha*B*A + beta*C + org.netlib.blas.Ssymm.ssymm("R", "L", m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc); + } + + protected void dsymvK(String uplo, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) { + if (alpha == 0.0) { + for (int i = 0, iy = incy < 0 ? 
(n - 1) * -incy : 0; i < n; i += 1, iy += incy) { + if (beta != 0.0) { + y[offsety + iy] = beta * y[offsety + iy]; + } else { + y[offsety + iy] = 0.0; + } + } + } else if (lsame("U", uplo)) { + dsymvU(n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy); + } else if (lsame("L", uplo)) { + dsymvL(n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy); + } + } + + /* Symmetric matrix-vector product from the upper triangle: y := alpha*A*x + beta*y, reading a[offseta + row + col * lda] only for row <= col. Columns are processed four at a time up to loopBound(n, 4), then one at a time. For a negative incx/incy the vector is walked backwards starting from (n - 1) * -inc. */ protected void dsymvU(int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) { + int col = 0, ix = incx < 0 ? (n - 1) * -incx : 0, iy = incy < 0 ? (n - 1) * -incy : 0; + for (; col < loopBound(n, 4); col += 4, ix += incx * 4, iy += incy * 4) { + double alphaxix0 = alpha * x[offsetx + ix + incx * 0]; + double alphaxix1 = alpha * x[offsetx + ix + incx * 1]; + double alphaxix2 = alpha * x[offsetx + ix + incx * 2]; + double alphaxix3 = alpha * x[offsetx + ix + incx * 3]; + double sumiy0 = 0.0; + double sumiy1 = 0.0; + double sumiy2 = 0.0; + double sumiy3 = 0.0; + /* row cursors start at element 0, which for a negative stride sits at (n - 1) * -inc (same origin as ix/iy); after col steps they land exactly on ix/iy for the diagonal block */ int row = 0, jx = incx < 0 ? (n - 1) * -incx : 0, jy = incy < 0 ? 
(n - 1) * -incy : 0; + for (; row < col; row += 1, jx += incx, jy += incy) { + double a0 = a[offseta + row + (col + 0) * lda]; + double a1 = a[offseta + row + (col + 1) * lda]; + double a2 = a[offseta + row + (col + 2) * lda]; + double a3 = a[offseta + row + (col + 3) * lda]; + y[offsety + jy] += alphaxix0 * a0 + alphaxix1 * a1 + alphaxix2 * a2 + alphaxix3 * a3; + double x0 = x[offsetx + jx]; + sumiy0 += x0 * a0; + sumiy1 += x0 * a1; + sumiy2 += x0 * a2; + sumiy3 += x0 * a3; + } + double a00 = a[offseta + (row + 0) + (col + 0) * lda]; + double a01 = a[offseta + (row + 0) + (col + 1) * lda]; + double a02 = a[offseta + (row + 0) + (col + 2) * lda]; + double a03 = a[offseta + (row + 0) + (col + 3) * lda]; + double a11 = a[offseta + (row + 1) + (col + 1) * lda]; + double a12 = a[offseta + (row + 1) + (col + 2) * lda]; + double a13 = a[offseta + (row + 1) + (col + 3) * lda]; + double a22 = a[offseta + (row + 2) + (col + 2) * lda]; + double a23 = a[offseta + (row + 2) + (col + 3) * lda]; + double a33 = a[offseta + (row + 3) + (col + 3) * lda]; + double xjx0 = x[offsetx + jx + incx * 0]; + double xjx1 = x[offsetx + jx + incx * 1]; + double xjx2 = x[offsetx + jx + incx * 2]; + double xjx3 = x[offsetx + jx + incx * 3]; + sumiy0 += xjx0 * a00 + + xjx1 * a01 + + xjx2 * a02 + + xjx3 * a03; + sumiy1 += xjx0 * a01 + + xjx1 * a11 + + xjx2 * a12 + + xjx3 * a13; + sumiy2 += xjx0 * a02 + + xjx1 * a12 + + xjx2 * a22 + + xjx3 * a23; + sumiy3 += xjx0 * a03 + + xjx1 * a13 + + xjx2 * a23 + + xjx3 * a33; + if (beta != 0.0) { + y[offsety + iy + incy * 0] = alpha * sumiy0 + beta * y[offsety + iy + incy * 0]; + y[offsety + iy + incy * 1] = alpha * sumiy1 + beta * y[offsety + iy + incy * 1]; + y[offsety + iy + incy * 2] = alpha * sumiy2 + beta * y[offsety + iy + incy * 2]; + y[offsety + iy + incy * 3] = alpha * sumiy3 + beta * y[offsety + iy + incy * 3]; + } else { + y[offsety + iy + incy * 0] = alpha * sumiy0; + y[offsety + iy + incy * 1] = alpha * sumiy1; + y[offsety + iy + incy * 2] = 
alpha * sumiy2; + y[offsety + iy + incy * 3] = alpha * sumiy3; + } + } + for (; col < n; col += 1, ix += incx, iy += incy) { + double alphaxix = alpha * x[offsetx + ix]; + double sumiy = 0.0; + /* same origin rule as the 4-column loop: element 0 is at (n - 1) * -inc when the stride is negative */ int row = 0, jx = incx < 0 ? (n - 1) * -incx : 0, jy = incy < 0 ? (n - 1) * -incy : 0; + for (; row < col; row += 1, jx += incx, jy += incy) { + double a0 = a[offseta + row + col * lda]; + y[offsety + jy] += alphaxix * a0; + sumiy += x[offsetx + jx] * a0; + } + sumiy += x[offsetx + jx] * a[offseta + row + col * lda]; + if (beta != 0.0) { + y[offsety + iy] = alpha * sumiy + beta * y[offsety + iy]; + } else { + y[offsety + iy] = alpha * sumiy; + } + } + } + + protected void dsymvL(int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) { + // y = beta * y + if (beta != 1.0) { + for (int i = 0, iy = incy < 0 ? (n - 1) * -incy : 0; i < n; i += 1, iy += incy) { + if (beta != 0.0) { + y[offsety + iy] = beta * y[offsety + iy]; + } else { + y[offsety + iy] = 0.0; + } + } + } + // y += alpha * A * x + int col = 0, ix = incx < 0 ? (n - 1) * -incx : 0, iy = incy < 0 ? 
(n - 1) * -incy : 0; + for (; col < loopBound(n, 4); col += 4, ix += incx * 4, iy += incy * 4) { + double alphaxix0 = alpha * x[offsetx + ix + incx * 0]; + double alphaxix1 = alpha * x[offsetx + ix + incx * 1]; + double alphaxix2 = alpha * x[offsetx + ix + incx * 2]; + double alphaxix3 = alpha * x[offsetx + ix + incx * 3]; + double sumiy0 = 0.0; + double sumiy1 = 0.0; + double sumiy2 = 0.0; + double sumiy3 = 0.0; + double a00 = a[offseta + /*row=*/(col + 0) + (col + 0) * lda]; + double a10 = a[offseta + /*row=*/(col + 1) + (col + 0) * lda]; + double a11 = a[offseta + /*row=*/(col + 1) + (col + 1) * lda]; + double a20 = a[offseta + /*row=*/(col + 2) + (col + 0) * lda]; + double a21 = a[offseta + /*row=*/(col + 2) + (col + 1) * lda]; + double a22 = a[offseta + /*row=*/(col + 2) + (col + 2) * lda]; + double a30 = a[offseta + /*row=*/(col + 3) + (col + 0) * lda]; + double a31 = a[offseta + /*row=*/(col + 3) + (col + 1) * lda]; + double a32 = a[offseta + /*row=*/(col + 3) + (col + 2) * lda]; + double a33 = a[offseta + /*row=*/(col + 3) + (col + 3) * lda]; + double x0 = x[offsetx + (incx < 0 ? (n - (col + 0) - 1) * -incx : (col + 0) * incx)]; + double x1 = x[offsetx + (incx < 0 ? (n - (col + 1) - 1) * -incx : (col + 1) * incx)]; + double x2 = x[offsetx + (incx < 0 ? (n - (col + 2) - 1) * -incx : (col + 2) * incx)]; + double x3 = x[offsetx + (incx < 0 ? (n - (col + 3) - 1) * -incx : (col + 3) * incx)]; + sumiy0 += x0 * a00 + + x1 * a10 + + x2 * a20 + + x3 * a30; + sumiy1 += x0 * a10 + + x1 * a11 + + x2 * a21 + + x3 * a31; + sumiy2 += x0 * a20 + + x1 * a21 + + x2 * a22 + + x3 * a32; + sumiy3 += x0 * a30 + + x1 * a31 + + x2 * a32 + + x3 * a33; + int row = col + 4, jx = incx < 0 ? (n - (col + 4) - 1) * -incx : (col + 4) * incx, jy = incy < 0 ? 
(n - (col + 4) - 1) * -incy : (col + 4) * incy; + for (; row < n; row += 1, jx += incx, jy += incy) { + double a0 = a[offseta + row + (col + 0) * lda]; + double a1 = a[offseta + row + (col + 1) * lda]; + double a2 = a[offseta + row + (col + 2) * lda]; + double a3 = a[offseta + row + (col + 3) * lda]; + y[offsety + jy] += alphaxix0 * a0 + + alphaxix1 * a1 + + alphaxix2 * a2 + + alphaxix3 * a3; + double xjx = x[offsetx + jx]; + sumiy0 += xjx * a0; + sumiy1 += xjx * a1; + sumiy2 += xjx * a2; + sumiy3 += xjx * a3; + } + y[offsety + iy + incy * 0] += alpha * sumiy0; + y[offsety + iy + incy * 1] += alpha * sumiy1; + y[offsety + iy + incy * 2] += alpha * sumiy2; + y[offsety + iy + incy * 3] += alpha * sumiy3; + } + for (; col < n; col += 1, ix += incx, iy += incy) { + double alphaxix = alpha * x[offsetx + ix]; + double sumiy = 0.0; + sumiy += x[offsetx + (incx < 0 ? (n - col - 1) * -incx : col * incx)] * a[offseta + /*row=*/col + col * lda]; + int row = col + 1, jx = incx < 0 ? (n - (col + 1) - 1) * -incx : (col + 1) * incx, jy = incy < 0 ? (n - (col + 1) - 1) * -incy : (col + 1) * incy; + for (; row < n; row += 1, jx += incx, jy += incy) { + y[offsety + jy] += alphaxix * a[offseta + row + col * lda]; + sumiy += x[offsetx + jx] * a[offseta + row + col * lda]; + } + y[offsety + iy] += alpha * sumiy; + } + } + + protected void ssymvK(String uplo, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) { + if (alpha == 0.0f) { + for (int i = 0, iy = incy < 0 ? 
(n - 1) * -incy : 0; i < n; i += 1, iy += incy) { + if (beta != 0.0f) { + y[offsety + iy] = beta * y[offsety + iy]; + } else { + y[offsety + iy] = 0.0f; + } + } + } else if (lsame("U", uplo)) { + ssymvU(n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy); + } else if (lsame("L", uplo)) { + ssymvL(n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy); + } + } + + /* Single-precision symmetric matrix-vector product from the upper triangle: y := alpha*A*x + beta*y, reading a[offseta + row + col * lda] only for row <= col. Columns are processed four at a time up to loopBound(n, 4), then one at a time. For a negative incx/incy the vector is walked backwards starting from (n - 1) * -inc. */ protected void ssymvU(int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) { + int col = 0, ix = incx < 0 ? (n - 1) * -incx : 0, iy = incy < 0 ? (n - 1) * -incy : 0; + for (; col < loopBound(n, 4); col += 4, ix += incx * 4, iy += incy * 4) { + float alphaxix0 = alpha * x[offsetx + ix + incx * 0]; + float alphaxix1 = alpha * x[offsetx + ix + incx * 1]; + float alphaxix2 = alpha * x[offsetx + ix + incx * 2]; + float alphaxix3 = alpha * x[offsetx + ix + incx * 3]; + float sumiy0 = 0.0f; + float sumiy1 = 0.0f; + float sumiy2 = 0.0f; + float sumiy3 = 0.0f; + /* row cursors start at element 0, which for a negative stride sits at (n - 1) * -inc (same origin as ix/iy); after col steps they land exactly on ix/iy for the diagonal block */ int row = 0, jx = incx < 0 ? (n - 1) * -incx : 0, jy = incy < 0 ? 
(n - 1) * -incy : 0; + for (; row < col; row += 1, jx += incx, jy += incy) { + y[offsety + jy] += alphaxix0 * a[offseta + row + (col + 0) * lda] + + alphaxix1 * a[offseta + row + (col + 1) * lda] + + alphaxix2 * a[offseta + row + (col + 2) * lda] + + alphaxix3 * a[offseta + row + (col + 3) * lda]; + float xjx = x[offsetx + jx]; + sumiy0 += xjx * a[offseta + row + (col + 0) * lda]; + sumiy1 += xjx * a[offseta + row + (col + 1) * lda]; + sumiy2 += xjx * a[offseta + row + (col + 2) * lda]; + sumiy3 += xjx * a[offseta + row + (col + 3) * lda]; + } + float a00 = a[offseta + (row + 0) + (col + 0) * lda]; + float a01 = a[offseta + (row + 0) + (col + 1) * lda]; + float a02 = a[offseta + (row + 0) + (col + 2) * lda]; + float a03 = a[offseta + (row + 0) + (col + 3) * lda]; + float a11 = a[offseta + (row + 1) + (col + 1) * lda]; + float a12 = a[offseta + (row + 1) + (col + 2) * lda]; + float a13 = a[offseta + (row + 1) + (col + 3) * lda]; + float a22 = a[offseta + (row + 2) + (col + 2) * lda]; + float a23 = a[offseta + (row + 2) + (col + 3) * lda]; + float a33 = a[offseta + (row + 3) + (col + 3) * lda]; + float xjx0 = x[offsetx + jx + incx * 0]; + float xjx1 = x[offsetx + jx + incx * 1]; + float xjx2 = x[offsetx + jx + incx * 2]; + float xjx3 = x[offsetx + jx + incx * 3]; + sumiy0 += xjx0 * a00 + + xjx1 * a01 + + xjx2 * a02 + + xjx3 * a03; + sumiy1 += xjx0 * a01 + + xjx1 * a11 + + xjx2 * a12 + + xjx3 * a13; + sumiy2 += xjx0 * a02 + + xjx1 * a12 + + xjx2 * a22 + + xjx3 * a23; + sumiy3 += xjx0 * a03 + + xjx1 * a13 + + xjx2 * a23 + + xjx3 * a33; + if (beta != 0.0f) { + y[offsety + iy + incy * 0] = alpha * sumiy0 + beta * y[offsety + iy + incy * 0]; + y[offsety + iy + incy * 1] = alpha * sumiy1 + beta * y[offsety + iy + incy * 1]; + y[offsety + iy + incy * 2] = alpha * sumiy2 + beta * y[offsety + iy + incy * 2]; + y[offsety + iy + incy * 3] = alpha * sumiy3 + beta * y[offsety + iy + incy * 3]; + } else { + y[offsety + iy + incy * 0] = alpha * sumiy0; + y[offsety + iy + incy * 
1] = alpha * sumiy1; + y[offsety + iy + incy * 2] = alpha * sumiy2; + y[offsety + iy + incy * 3] = alpha * sumiy3; + } + } + for (; col < n; col += 1, ix += incx, iy += incy) { + float alphaxix = alpha * x[offsetx + ix]; + float sumiy = 0.0f; + /* same origin rule as the 4-column loop: element 0 is at (n - 1) * -inc when the stride is negative */ int row = 0, jx = incx < 0 ? (n - 1) * -incx : 0, jy = incy < 0 ? (n - 1) * -incy : 0; + for (; row < col; row += 1, jx += incx, jy += incy) { + y[offsety + jy] += alphaxix * a[offseta + row + col * lda]; + sumiy += x[offsetx + jx] * a[offseta + row + col * lda]; + } + sumiy += x[offsetx + jx] * a[offseta + row + col * lda]; + if (beta != 0.0f) { + y[offsety + iy] = alpha * sumiy + beta * y[offsety + iy]; + } else { + y[offsety + iy] = alpha * sumiy; + } + } + } + + protected void ssymvL(int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) { + // y = beta * y + if (beta != 1.0f) { + for (int i = 0, iy = incy < 0 ? (n - 1) * -incy : 0; i < n; i += 1, iy += incy) { + if (beta != 0.0f) { + y[offsety + iy] = beta * y[offsety + iy]; + } else { + y[offsety + iy] = 0.0f; + } + } + } + // y += alpha * A * x + int col = 0, ix = incx < 0 ? (n - 1) * -incx : 0, iy = incy < 0 ? 
(n - 1) * -incy : 0; + for (; col < loopBound(n, 4); col += 4, ix += incx * 4, iy += incy * 4) { + float alphaxix0 = alpha * x[offsetx + ix + incx * 0]; + float alphaxix1 = alpha * x[offsetx + ix + incx * 1]; + float alphaxix2 = alpha * x[offsetx + ix + incx * 2]; + float alphaxix3 = alpha * x[offsetx + ix + incx * 3]; + float sumiy0 = 0.0f; + float sumiy1 = 0.0f; + float sumiy2 = 0.0f; + float sumiy3 = 0.0f; + float a00 = a[offseta + /*row=*/(col + 0) + (col + 0) * lda]; + float a10 = a[offseta + /*row=*/(col + 1) + (col + 0) * lda]; + float a11 = a[offseta + /*row=*/(col + 1) + (col + 1) * lda]; + float a20 = a[offseta + /*row=*/(col + 2) + (col + 0) * lda]; + float a21 = a[offseta + /*row=*/(col + 2) + (col + 1) * lda]; + float a22 = a[offseta + /*row=*/(col + 2) + (col + 2) * lda]; + float a30 = a[offseta + /*row=*/(col + 3) + (col + 0) * lda]; + float a31 = a[offseta + /*row=*/(col + 3) + (col + 1) * lda]; + float a32 = a[offseta + /*row=*/(col + 3) + (col + 2) * lda]; + float a33 = a[offseta + /*row=*/(col + 3) + (col + 3) * lda]; + float x0 = x[offsetx + (incx < 0 ? (n - (col + 0) - 1) * -incx : (col + 0) * incx)]; + float x1 = x[offsetx + (incx < 0 ? (n - (col + 1) - 1) * -incx : (col + 1) * incx)]; + float x2 = x[offsetx + (incx < 0 ? (n - (col + 2) - 1) * -incx : (col + 2) * incx)]; + float x3 = x[offsetx + (incx < 0 ? (n - (col + 3) - 1) * -incx : (col + 3) * incx)]; + sumiy0 += x0 * a00 + + x1 * a10 + + x2 * a20 + + x3 * a30; + sumiy1 += x0 * a10 + + x1 * a11 + + x2 * a21 + + x3 * a31; + sumiy2 += x0 * a20 + + x1 * a21 + + x2 * a22 + + x3 * a32; + sumiy3 += x0 * a30 + + x1 * a31 + + x2 * a32 + + x3 * a33; + int row = col + 4, jx = incx < 0 ? (n - (col + 4) - 1) * -incx : (col + 4) * incx, jy = incy < 0 ? 
(n - (col + 4) - 1) * -incy : (col + 4) * incy; + for (; row < n; row += 1, jx += incx, jy += incy) { + float a0 = a[offseta + row + (col + 0) * lda]; + float a1 = a[offseta + row + (col + 1) * lda]; + float a2 = a[offseta + row + (col + 2) * lda]; + float a3 = a[offseta + row + (col + 3) * lda]; + y[offsety + jy] += alphaxix0 * a0 + + alphaxix1 * a1 + + alphaxix2 * a2 + + alphaxix3 * a3; + float xjx = x[offsetx + jx]; + sumiy0 += xjx * a0; + sumiy1 += xjx * a1; + sumiy2 += xjx * a2; + sumiy3 += xjx * a3; + } + y[offsety + iy + incy * 0] += alpha * sumiy0; + y[offsety + iy + incy * 1] += alpha * sumiy1; + y[offsety + iy + incy * 2] += alpha * sumiy2; + y[offsety + iy + incy * 3] += alpha * sumiy3; + } + for (; col < n; col += 1, ix += incx, iy += incy) { + float alphaxix = alpha * x[offsetx + ix]; + float sumiy = 0.0f; + sumiy += x[offsetx + (incx < 0 ? (n - col - 1) * -incx : col * incx)] * a[offseta + /*row=*/col + col * lda]; + int row = col + 1, jx = incx < 0 ? (n - (col + 1) - 1) * -incx : (col + 1) * incx, jy = incy < 0 ? 
(n - (col + 1) - 1) * -incy : (col + 1) * incy; + for (; row < n; row += 1, jx += incx, jy += incy) { + y[offsety + jy] += alphaxix * a[offseta + row + col * lda]; + sumiy += x[offsetx + jx] * a[offseta + row + col * lda]; + } + y[offsety + iy] += alpha * sumiy; + } + } + + protected void dsyrK(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] a, int offseta, int lda) { + org.netlib.blas.Dsyr.dsyr(uplo, n, alpha, x, offsetx, incx, a, offseta, lda); + } + + protected void ssyrK(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] a, int offseta, int lda) { + org.netlib.blas.Ssyr.ssyr(uplo, n, alpha, x, offsetx, incx, a, offseta, lda); + } + + protected void dsyr2K(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta, int lda) { + org.netlib.blas.Dsyr2.dsyr2(uplo, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta, lda); + } + + protected void ssyr2K(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta, int lda) { + org.netlib.blas.Ssyr2.ssyr2(uplo, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta, lda); + } + + protected void dsyr2kK(String uplo, String trans, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) { + org.netlib.blas.Dsyr2k.dsyr2k(uplo, trans, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc); + } + + protected void ssyr2kK(String uplo, String trans, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) { + org.netlib.blas.Ssyr2k.ssyr2k(uplo, trans, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc); + } + + protected void dsyrkK(String uplo, String trans, int n, int k, double alpha, double[] a, int offseta, int lda, double beta, 
double[] c, int offsetc, int ldc) { + org.netlib.blas.Dsyrk.dsyrk(uplo, trans, n, k, alpha, a, offseta, lda, beta, c, offsetc, ldc); + } + + protected void ssyrkK(String uplo, String trans, int n, int k, float alpha, float[] a, int offseta, int lda, float beta, float[] c, int offsetc, int ldc) { + org.netlib.blas.Ssyrk.ssyrk(uplo, trans, n, k, alpha, a, offseta, lda, beta, c, offsetc, ldc); + } + + protected void dtbmvK(String uplo, String trans, String diag, int n, int k, double[] a, int offseta, int lda, double[] x, int offsetx, int incx) { + org.netlib.blas.Dtbmv.dtbmv(uplo, trans, diag, n, k, a, offseta, lda, x, offsetx, incx); + } + + protected void stbmvK(String uplo, String trans, String diag, int n, int k, float[] a, int offseta, int lda, float[] x, int offsetx, int incx) { + org.netlib.blas.Stbmv.stbmv(uplo, trans, diag, n, k, a, offseta, lda, x, offsetx, incx); + } + + protected void dtbsvK(String uplo, String trans, String diag, int n, int k, double[] a, int offseta, int lda, double[] x, int offsetx, int incx) { + org.netlib.blas.Dtbsv.dtbsv(uplo, trans, diag, n, k, a, offseta, lda, x, offsetx, incx); + } + + protected void stbsvK(String uplo, String trans, String diag, int n, int k, float[] a, int offseta, int lda, float[] x, int offsetx, int incx) { + org.netlib.blas.Stbsv.stbsv(uplo, trans, diag, n, k, a, offseta, lda, x, offsetx, incx); + } + + protected void dtpmvK(String uplo, String trans, String diag, int n, double[] a, int offseta, double[] x, int offsetx, int incx) { + org.netlib.blas.Dtpmv.dtpmv(uplo, trans, diag, n, a, offseta, x, offsetx, incx); + } + + protected void stpmvK(String uplo, String trans, String diag, int n, float[] a, int offseta, float[] x, int offsetx, int incx) { + org.netlib.blas.Stpmv.stpmv(uplo, trans, diag, n, a, offseta, x, offsetx, incx); + } + + protected void dtpsvK(String uplo, String trans, String diag, int n, double[] a, int offseta, double[] x, int offsetx, int incx) { + org.netlib.blas.Dtpsv.dtpsv(uplo, trans, 
diag, n, a, offseta, x, offsetx, incx); + } + + protected void stpsvK(String uplo, String trans, String diag, int n, float[] a, int offseta, float[] x, int offsetx, int incx) { + org.netlib.blas.Stpsv.stpsv(uplo, trans, diag, n, a, offseta, x, offsetx, incx); + } + + protected void dtrmmK(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb) { + org.netlib.blas.Dtrmm.dtrmm(side, uplo, transa, diag, m, n, alpha, a, offseta, lda, b, offsetb, ldb); + } + + protected void strmmK(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb) { + org.netlib.blas.Strmm.strmm(side, uplo, transa, diag, m, n, alpha, a, offseta, lda, b, offsetb, ldb); + } + + protected void dtrmvK(String uplo, String trans, String diag, int n, double[] a, int offseta, int lda, double[] x, int offsetx, int incx) { + org.netlib.blas.Dtrmv.dtrmv(uplo, trans, diag, n, a, offseta, lda, x, offsetx, incx); + } + + protected void strmvK(String uplo, String trans, String diag, int n, float[] a, int offseta, int lda, float[] x, int offsetx, int incx) { + org.netlib.blas.Strmv.strmv(uplo, trans, diag, n, a, offseta, lda, x, offsetx, incx); + } + + protected void dtrsmK(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb) { + org.netlib.blas.Dtrsm.dtrsm(side, uplo, transa, diag, m, n, alpha, a, offseta, lda, b, offsetb, ldb); + } + + protected void strsmK(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb) { + org.netlib.blas.Strsm.strsm(side, uplo, transa, diag, m, n, alpha, a, offseta, lda, b, offsetb, ldb); + } + + protected void dtrsvK(String uplo, String trans, String diag, int n, double[] a, int offseta, int lda, double[] x, int 
offsetx, int incx) { + org.netlib.blas.Dtrsv.dtrsv(uplo, trans, diag, n, a, offseta, lda, x, offsetx, incx); + } + + protected void strsvK(String uplo, String trans, String diag, int n, float[] a, int offseta, int lda, float[] x, int offsetx, int incx) { + org.netlib.blas.Strsv.strsv(uplo, trans, diag, n, a, offseta, lda, x, offsetx, incx); + } + + protected int idamaxK(int n, double[] x, int offsetx, int incx) { + return org.netlib.blas.Idamax.idamax(n, x, offsetx, incx); + } + + protected int isamaxK(int n, float[] x, int offsetx, int incx) { + return org.netlib.blas.Isamax.isamax(n, x, offsetx, incx); + } +} diff --git a/ml-core/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala b/ml-core/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala index c2f88222bf378f3ce23560d9a794fcffc22a6906..b9ea30bc3b471468fb30196b7e04f5eddb1bd8ac 100644 --- a/ml-core/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala +++ b/ml-core/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala @@ -1,5 +1,5 @@ /* -* Copyright (C) 2021. Huawei Technologies Co., Ltd. +* Copyright (C) 2021-2022. Huawei Technologies Co., Ltd. * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
@@ -70,14 +70,19 @@ private[spark] object BaggedPoint { subsamplingRate: Double, numSubsamples: Int, withReplacement: Boolean, - seed: Long = Utils.random.nextLong()): RDD[BaggedPoint[Datum]] = { - if (withReplacement) { - convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed) + seed: Long = Utils.random.nextLong(), + oneFeaturePerTree: Boolean = false): RDD[BaggedPoint[Datum]] = { + if (oneFeaturePerTree) { + convertToBaggedRDDWithoutSampling(input) } else { - if (numSubsamples == 1 && subsamplingRate == 1.0) { - convertToBaggedRDDWithoutSampling(input) + if (withReplacement) { + convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed) } else { - convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed) + if (numSubsamples == 1 && subsamplingRate == 1.0) { + convertToBaggedRDDWithoutSampling(input) + } else { + convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed) + } } } } diff --git a/ml-core/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/ml-core/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala index dba8978acc3b6654166e588f8186a267bade2ebe..bca6d2611b08cfdb52b24adbd9ae667b1bf26662 100644 --- a/ml-core/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala +++ b/ml-core/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala @@ -1,5 +1,5 @@ /* -* Copyright (C) 2021. Huawei Technologies Co., Ltd. +* Copyright (C) 2021-2022. Huawei Technologies Co., Ltd. * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
@@ -59,7 +59,8 @@ private[spark] class DecisionTreeMetadata( val minInstancesPerNode: Int, val minInfoGain: Double, val numTrees: Int, - val numFeaturesPerNode: Int) extends Serializable { + val numFeaturesPerNode: Int, + val oneFeaturePerTree: Boolean = false) extends Serializable { def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex) @@ -212,10 +213,17 @@ private[spark] object DecisionTreeMetadata extends Logging { } } + val (newNumTrees, oneFeaturePerTree) = if (numTrees > 0) { + (numTrees, false) + } else { + (numFeatures, true) + + } new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max, strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins, strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth, - strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode) + strategy.minInstancesPerNode, strategy.minInfoGain, newNumTrees, numFeaturesPerNode, + oneFeaturePerTree) } /** diff --git a/ml-core/src/main/scala/org/apache/spark/mllib/feature/VocabWord.scala b/ml-core/src/main/scala/org/apache/spark/mllib/feature/VocabWord.scala new file mode 100644 index 0000000000000000000000000000000000000000..4d73e956c03d73a73122c16a8ef1f09dcc153fbe --- /dev/null +++ b/ml-core/src/main/scala/org/apache/spark/mllib/feature/VocabWord.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.feature + +/** + * Entry in vocabulary + */ +private case class VocabWord( + var word: String, + var cn: Int, + var point: Array[Int], + var code: Array[Int], + var codeLen: Int +) diff --git a/ml-kernel-client-core/pom.xml b/ml-kernel-client-core/pom.xml index 561c79ecea28bae8c442b3275b70d8ac29561520..135e38565cb466eea2bf57691522206a23b1805b 100644 --- a/ml-kernel-client-core/pom.xml +++ b/ml-kernel-client-core/pom.xml @@ -2,12 +2,12 @@ org.apache.spark boostkit-ml - 2.1.0 + 2.2.0 4.0.0 boostkit-ml-kernel-client-core_2.11 - 2.1.0 + 2.2.0 ${project.artifactId} Spark ml core diff --git a/ml-kernel-client/pom.xml b/ml-kernel-client/pom.xml index 2be4e389d0f44c6f43394581e00c787b0663625f..b9101ff3db38155a6279839bf214a49ee847ca1c 100644 --- a/ml-kernel-client/pom.xml +++ b/ml-kernel-client/pom.xml @@ -2,12 +2,12 @@ org.apache.spark boostkit-ml - 2.1.0 + 2.2.0 4.0.0 boostkit-ml-kernel-client_2.11 - 2.1.0 + 2.2.0 ${project.artifactId} Spark ml core diff --git a/ml-kernel-client/src/main/scala/org/apache/spark/ml/feature/DecisionTreeBucketizer.scala b/ml-kernel-client/src/main/scala/org/apache/spark/ml/feature/DecisionTreeBucketizer.scala new file mode 100644 index 0000000000000000000000000000000000000000..ef23ebb90929d13721cf4466adfbb1f9398c6b45 --- /dev/null +++ b/ml-kernel-client/src/main/scala/org/apache/spark/ml/feature/DecisionTreeBucketizer.scala @@ -0,0 +1,262 @@ +// scalastyle:off header.matches +/* +* Copyright (C) 2022. Huawei Technologies Co., Ltd. 
+* This program is distributed in the hope that it will be useful, +* but WITHOUT ANY WARRANTY; without even the implied warranty of +* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +* */ +/* + * This file to You under the Apache License, Version 2.0; + * you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.Since +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.classification.DecisionTreeClassificationModel +import org.apache.spark.ml.param.{Param, ParamMap, Params} +import org.apache.spark.ml.param.shared.HasWeightCol +import org.apache.spark.ml.tree._ +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.tree.configuration.{Strategy => OldStrategy} +import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.types.StructType + + +private[ml] trait DecisionTreeBucketizerParams extends Params + with DecisionTreeClassifierParams with HasWeightCol { + + /** + * Param for bucketedFeatures column name. + * @group param + */ + final val bucketedFeaturesCol: Param[String] = + new Param[String](this, "bucketedFeaturesCol", "bucketedFeatures column name") + + setDefault(bucketedFeaturesCol, "bucketedFeatures") + + /** @group getParam */ + final def getBucketedFeaturesCol: String = $(bucketedFeaturesCol) +} + +/** + * Decision tree bucketing algorithm for data discretization. 
+ */ +@Since("1.4.0") +class DecisionTreeBucketizer @Since("1.4.0") ( + @Since("1.4.0") override val uid: String) + extends Estimator[DecisionTreeBucketModel] + with DecisionTreeBucketizerParams with DecisionTreeClassifierParams with DefaultParamsWritable { + + @Since("1.4.0") + def this() = this(Identifiable.randomUID("dtb")) + + def setLabelCol(value: String): this.type = null + + def setFeaturesCol(value: String): this.type = null + + def setBucketedFeaturesCol(value: String): this.type = null + + // Override parameter setters from parent trait for Java API compatibility. + /** @group setParam */ + @Since("1.4.0") + override def setMaxDepth(value: Int): this.type = null + + /** @group setParam */ + @Since("1.4.0") + override def setMaxBins(value: Int): this.type = null + + /** @group setParam */ + @Since("1.4.0") + override def setMinInstancesPerNode(value: Int): this.type = null + + /** @group setParam */ + @Since("1.4.0") + override def setMinInfoGain(value: Double): this.type = null + + /** @group expertSetParam */ + @Since("1.4.0") + override def setMaxMemoryInMB(value: Int): this.type = null + + /** @group expertSetParam */ + @Since("1.4.0") + override def setCacheNodeIds(value: Boolean): this.type = null + + /** + * Specifies how often to checkpoint the cached node IDs. + * E.g. 10 means that the cache will get checkpointed every 10 iterations. + * This is only used if cacheNodeIds is true and if the checkpoint directory is set in + * [[org.apache.spark.SparkContext]]. + * Must be at least 1. 
+ * (default = 10) + * @group setParam + */ + @Since("1.4.0") + override def setCheckpointInterval(value: Int): this.type = null + + /** @group setParam */ + @Since("1.4.0") + override def setImpurity(value: String): this.type = null + + /** @group setParam */ + @Since("1.6.0") + override def setSeed(value: Long): this.type = null + + /** @group setParam */ + @Since("3.0.0") + def setWeightCol(value: String): this.type = set(weightCol, value) + + /** + * Extract [[labelCol]] and [[featuresCol]] from the given dataset, + * and put it in an RDD with strong types. + * + * @param dataset DataFrame with columns for labels ([[org.apache.spark.sql.types.NumericType]]) + * and features (`Vector`). + * @param numClasses Number of classes label can take. Labels must be integers in the range + * [0, numClasses). + * @note Throws `SparkException` if any label is a non-integer or is negative + */ + private[ml] def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = + null + + /** + * Get the number of classes. This looks in column metadata first, and if that is missing, + * then this assumes classes are indexed 0,1,...,numClasses-1 and computes numClasses + * by finding the maximum label value. + * + * Label validation (ensuring all labels are integers >= 0) needs to be handled elsewhere, + * such as in `extractLabeledPoints()`. + * + * @param dataset Dataset which contains a column [[labelCol]] + * @param maxNumClasses Maximum number of classes allowed when inferred from data. If numClasses + * is specified in the metadata, then maxNumClasses is ignored. 
+ * @return number of classes + * @throws IllegalArgumentException if metadata does not specify numClasses, and the + * actual numClasses exceeds maxNumClasses + */ + private[ml] def getNumClasses(dataset: Dataset[_], maxNumClasses: Int = 100): Int = 0 + + override def transformSchema(schema: StructType): StructType = null + + override def fit(dataset: Dataset[_]): DecisionTreeBucketModel = null + + private[ml] def train(dataset: Dataset[_]): DecisionTreeBucketModel = null + + /** (private[ml]) Train decision trees on an RDD */ + private[ml] def train(data: RDD[LabeledPoint], + oldStrategy: OldStrategy): DecisionTreeBucketModel = null + + /** (private[ml]) Create a Strategy instance to use with the old API. */ + private[ml] def getOldStrategy( + categoricalFeatures: Map[Int, Int], + numClasses: Int): OldStrategy = null + + @Since("1.4.1") + override def copy(extra: ParamMap): DecisionTreeBucketizer = null +} + +@Since("1.4.0") +object DecisionTreeBucketizer extends DefaultParamsReadable[DecisionTreeBucketizer] { + /** Accessor for supported impurities: entropy, gini */ + @Since("1.4.0") + final val supportedImpurities: Array[String] = null + + @Since("2.0.0") + override def load(path: String): DecisionTreeBucketizer = null +} + +/** + * Decision tree bucket model for data discretization. + * @param _trees Decision trees of all features. + */ +@Since("1.4.0") +class DecisionTreeBucketModel private[ml] ( + @Since("1.5.0") override val uid: String, + private val _trees: Array[DecisionTreeClassificationModel], + val numFeatures: Int, + val numClasses: Int) + extends Model[DecisionTreeBucketModel] + with DecisionTreeBucketizerParams with DecisionTreeClassifierParams + with TreeEnsembleModel[DecisionTreeClassificationModel] + with MLWritable with Serializable { + + require(_trees.nonEmpty, "DecisionTreeBucketModel requires at least 1 tree.") + + /** + * Construct a decision tree bucket model, with all trees weighted equally. 
+ * + * @param trees Component trees + */ + private[ml] def this( + trees: Array[DecisionTreeClassificationModel], + numFeatures: Int, + numClasses: Int) = + this(Identifiable.randomUID("dtb"), trees, numFeatures, numClasses) + + def getNumTrees: Int = 0 + + @Since("1.4.0") + override def trees: Array[DecisionTreeClassificationModel] = null + + // Note: We may add support for weights (based on tree performance) later on. + private lazy val _treeWeights: Array[Double] = null + + @Since("1.4.0") + override def treeWeights: Array[Double] = null + + override def transformSchema(schema: StructType): StructType = null + + override def transform(dataset: Dataset[_]): DataFrame = null + + @Since("1.4.0") + override def copy(extra: ParamMap): DecisionTreeBucketModel = null + + @Since("1.4.0") + override def toString: String = null + + /** (private[ml]) Convert to a model in the old API */ + def toOld: OldRandomForestModel = null + + @Since("2.0.0") + override def write: MLWriter = null +} + +@Since("2.0.0") +object DecisionTreeBucketModel extends MLReadable[DecisionTreeBucketModel] { + + @Since("2.0.0") + override def read: MLReader[DecisionTreeBucketModel] = null + + @Since("2.0.0") + override def load(path: String): DecisionTreeBucketModel = null + + private[DecisionTreeBucketModel] + class DecisionTreeBucketModelWriter(instance: DecisionTreeBucketModel) + extends MLWriter { + + override protected def saveImpl(path: String): Unit = {} + } + + private class DecisionTreeBucketModelReader + extends MLReader[DecisionTreeBucketModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[DecisionTreeBucketModel].getName + private val treeClassName = classOf[DecisionTreeClassificationModel].getName + + override def load(path: String): DecisionTreeBucketModel = null + } + + /** Convert a model from the old API */ + private[ml] def fromOld( + oldModel: OldRandomForestModel, + parent: DecisionTreeBucketizer, + categoricalFeatures: Map[Int, Int], 
+ numClasses: Int, + numFeatures: Int = -1): DecisionTreeBucketModel = null +} diff --git a/ml-kernel-client/src/main/scala/org/apache/spark/ml/tree/impl/RFUtils.scala b/ml-kernel-client/src/main/scala/org/apache/spark/ml/tree/impl/RFUtils.scala index ed245f91c61fac72bc2ad273f2026c18af656171..1fce387cfca7c70c6e8747e95f0313ef7ddf4799 100644 --- a/ml-kernel-client/src/main/scala/org/apache/spark/ml/tree/impl/RFUtils.scala +++ b/ml-kernel-client/src/main/scala/org/apache/spark/ml/tree/impl/RFUtils.scala @@ -1,10 +1,11 @@ // scalastyle:off header.matches /* -* Copyright (C) 2021. Huawei Technologies Co., Ltd. +* Copyright (C) 2021-2022. Huawei Technologies Co., Ltd. * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. * */ +// scalastyle:off header.matches /* * This file to You under the Apache License, Version 2.0; * you may not use this file except in compliance with @@ -61,6 +62,11 @@ object RFUtils extends Logging { true } + def isValidSample(baggedPoint: BaggedPoint[TreePointX], + groupInfo: GroupInfo, treeIndex: Int, id: Short): Boolean = { + true + } + def isValidNodeInfo(nodeInfo: NodeIndexInfo, agg: Array[DTStatsAggregator]): Boolean = { true } diff --git a/ml-kernel-client/src/main/scala/org/apache/spark/mllib/feature/Word2VecSGHS.scala b/ml-kernel-client/src/main/scala/org/apache/spark/mllib/feature/Word2VecSGHS.scala new file mode 100644 index 0000000000000000000000000000000000000000..cd75be29cdaedca06c9f9dfb70f410b79ca079d6 --- /dev/null +++ b/ml-kernel-client/src/main/scala/org/apache/spark/mllib/feature/Word2VecSGHS.scala @@ -0,0 +1,42 @@ +// scalastyle:off +/* +* Copyright (C) 2022. Huawei Technologies Co., Ltd. +* This program is distributed in the hope that it will be useful, +* but WITHOUT ANY WARRANTY; without even the implied warranty of +* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
+* */ +/* + * This file to You under the Apache License, Version 2.0; + * you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + */ +package org.apache.spark.mllib.feature + +import scala.collection.mutable + +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD + +class Word2VecSGHS( + val minCount: Int, + val window: Int, + val vectorSize: Int, + val vocabSize: Int, + val trainWordsCount: Long, + val learningRate: Double, + val numIterations: Int, + val seed: Long, + val maxSentenceLength: Int, + val regularization: Float, + val repetition: Int) extends Serializable with Logging { + + def fit[S <: Iterable[String]]( + dataset: RDD[S], + bcExpTable: Broadcast[Array[Float]], + bcVocab: Broadcast[Array[VocabWord]], + bcVocabHash: Broadcast[mutable.HashMap[String, Int]]): Array[Float] = { + null + } +} diff --git a/ml-xgboost/jvm-packages/boostkit-xgboost4j-example/pom.xml b/ml-xgboost/jvm-packages/boostkit-xgboost4j-example/pom.xml index 163ad377f3cf9868232f57530906972c92daadb7..562a3abf6dda9311ae3d6d0f071d4d48aa308478 100644 --- a/ml-xgboost/jvm-packages/boostkit-xgboost4j-example/pom.xml +++ b/ml-xgboost/jvm-packages/boostkit-xgboost4j-example/pom.xml @@ -6,10 +6,10 @@ ml.dmlc boostkit-xgboost-jvm_2.11 - 2.1.0 + 2.2.0 boostkit-xgboost4j-example_2.11 - 2.1.0 + 2.2.0 jar diff --git a/ml-xgboost/jvm-packages/boostkit-xgboost4j-flink/pom.xml b/ml-xgboost/jvm-packages/boostkit-xgboost4j-flink/pom.xml index 24347b5151a2eb2252877ff5bf2784538f636e91..de7326d2daa7831b8cb8cda8d4bc0c8e5f1bdb66 100644 --- a/ml-xgboost/jvm-packages/boostkit-xgboost4j-flink/pom.xml +++ b/ml-xgboost/jvm-packages/boostkit-xgboost4j-flink/pom.xml @@ -6,10 +6,10 @@ ml.dmlc boostkit-xgboost-jvm_2.11 - 2.1.0 + 2.2.0 boostkit-xgboost4j-flink_2.11 - 2.1.0 + 2.2.0 diff --git 
a/ml-xgboost/jvm-packages/boostkit-xgboost4j-spark-client/pom.xml b/ml-xgboost/jvm-packages/boostkit-xgboost4j-spark-client/pom.xml index bc2905cfb6f5df6af4023fdef3708474ad38e65f..7d114947dcce2a08bee515a839d7fd39a4d5a64b 100644 --- a/ml-xgboost/jvm-packages/boostkit-xgboost4j-spark-client/pom.xml +++ b/ml-xgboost/jvm-packages/boostkit-xgboost4j-spark-client/pom.xml @@ -6,10 +6,10 @@ ml.dmlc boostkit-xgboost-jvm_2.11 - 2.1.0 + 2.2.0 boostkit-xgboost4j-spark-client_2.11 - 2.1.0 + 2.2.0 src/main/scala diff --git a/ml-xgboost/jvm-packages/boostkit-xgboost4j-spark-kernel-client/pom.xml b/ml-xgboost/jvm-packages/boostkit-xgboost4j-spark-kernel-client/pom.xml index 8aed1cce23fc6359c6f2c9442b9a178dc5ffbb8d..789b1bbf403fd00cf51b3421cb0913ba0a638a5f 100644 --- a/ml-xgboost/jvm-packages/boostkit-xgboost4j-spark-kernel-client/pom.xml +++ b/ml-xgboost/jvm-packages/boostkit-xgboost4j-spark-kernel-client/pom.xml @@ -6,10 +6,10 @@ ml.dmlc boostkit-xgboost-jvm_2.11 - 2.1.0 + 2.2.0 boostkit-xgboost4j-spark-kernel-client_2.11 - 2.1.0 + 2.2.0 src/main/scala diff --git a/ml-xgboost/jvm-packages/boostkit-xgboost4j-spark/pom.xml b/ml-xgboost/jvm-packages/boostkit-xgboost4j-spark/pom.xml index d1b3ea3b0d5b7b73154b7069bc38234004b0011e..97e6ccd7c2a2dedb4692220a5c54008b076d2a09 100644 --- a/ml-xgboost/jvm-packages/boostkit-xgboost4j-spark/pom.xml +++ b/ml-xgboost/jvm-packages/boostkit-xgboost4j-spark/pom.xml @@ -6,10 +6,10 @@ ml.dmlc boostkit-xgboost-jvm_2.11 - 2.1.0 + 2.2.0 boostkit-xgboost4j-spark2.4.6_2.11 - 2.1.0 + 2.2.0 diff --git a/ml-xgboost/jvm-packages/boostkit-xgboost4j/pom.xml b/ml-xgboost/jvm-packages/boostkit-xgboost4j/pom.xml index 68472ff69f3c099692dc7c146d8b3e7248b8c924..4b05dbc1add5643bbcc329f3cf7d784359981a86 100644 --- a/ml-xgboost/jvm-packages/boostkit-xgboost4j/pom.xml +++ b/ml-xgboost/jvm-packages/boostkit-xgboost4j/pom.xml @@ -6,10 +6,10 @@ ml.dmlc boostkit-xgboost-jvm_2.11 - 2.1.0 + 2.2.0 boostkit-xgboost4j_2.11 - 2.1.0 + 2.2.0 jar diff --git 
a/ml-xgboost/jvm-packages/pom.xml b/ml-xgboost/jvm-packages/pom.xml index 20ac113d3acc77f34a13650319dbed5de768e131..b6f30b8bf9875190c75eb393a43014a9505e51b4 100644 --- a/ml-xgboost/jvm-packages/pom.xml +++ b/ml-xgboost/jvm-packages/pom.xml @@ -6,7 +6,7 @@ ml.dmlc boostkit-xgboost-jvm_2.11 - 2.1.0 + 2.2.0 pom Boostkit XGBoost JVM Package Boostkit JVM Package for XGBoost diff --git a/pom.xml b/pom.xml index 67ad24cc404ff4c118965559f444a2ee1266ac06..66ed59ce2da7cf0b7ce15d4ddb1beb6376cf5e13 100644 --- a/pom.xml +++ b/pom.xml @@ -2,7 +2,7 @@ 4.0.0 org.apache.spark boostkit-ml - 2.1.0 + 2.2.0 ${project.artifactId} Spark ml algo 2020