1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51
/**
 * Gradient-Boosted Trees estimator for classification.
 *
 * Training accepts only labels 0 and 1 (binary classification) and delegates to
 * `GradientBoostedTrees.run`, producing a [[GBTClassificationModel]].
 *
 * @param uid unique identifier of this estimator; the fitted model reuses it
 */
class GBTClassifier @Since("1.4.0") (
    @Since("1.4.0") override val uid: String)
  extends Predictor[Vector, GBTClassifier, GBTClassificationModel]
  with GBTClassifierParams with DefaultParamsWritable with Logging {

  // Param setters: each records `value` via `set` and returns this estimator for chaining.

  override def setMaxDepth(value: Int): this.type = set(maxDepth, value)

  override def setMaxBins(value: Int): this.type = set(maxBins, value)

  override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)

  override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)

  override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)

  override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)

  override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)

  /**
   * Does not apply `value`: this setter only logs a warning and returns this estimator
   * unchanged.
   */
  override def setImpurity(value: String): this.type = {
    logWarning("GBTClassifier.setImpurity should NOT be used")
    this
  }

  override def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value)

  override def setSeed(value: Long): this.type = set(seed, value)

  override def setMaxIter(value: Int): this.type = set(maxIter, value)

  override def setStepSize(value: Double): this.type = set(stepSize, value)

  def setLossType(value: String): this.type = set(lossType, value)

  /**
   * Fits a [[GBTClassificationModel]] on `dataset`.
   *
   * Selects the label and features columns, converts each row to a `LabeledPoint`, and
   * runs `GradientBoostedTrees.run` with the boosting strategy built from this estimator's
   * params and the configured seed. Rows whose label is not 0 or 1 fail the `require`
   * inside the RDD map, i.e. when the data is evaluated.
   *
   * @param dataset input data containing the label and features columns
   * @return the fitted model, carrying this estimator's `uid`
   * @throws UnsupportedOperationException if `dataset` contains no rows
   */
  override protected def train(dataset: Dataset[_]): GBTClassificationModel = {
    val categoricalFeatures: Map[Int, Int] =
      MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))

    // Only binary classification is supported: reject any label outside {0, 1}.
    val oldDataset: RDD[LabeledPoint] =
      dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
        case Row(label: Double, features: Vector) =>
          require(label == 0 || label == 1, s"GBTClassifier was given" +
            s" dataset with invalid label $label. Labels must be in {0,1}; note that" +
            s" GBTClassifier currently only supports binary classification.")
          LabeledPoint(label, features)
      }

    // Same evaluation as `oldDataset.first()` (which is `take(1)` plus a match), but an
    // empty dataset gets a descriptive message instead of the generic "empty collection".
    val numFeatures = oldDataset.take(1) match {
      case Array(firstPoint) => firstPoint.features.size
      case _ =>
        throw new UnsupportedOperationException(
          "GBTClassifier was given an empty dataset; at least one row is required to train.")
    }

    val boostingStrategy =
      super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)

    val numClasses = 2 // labels are restricted to {0, 1} above
    val instr = Instrumentation.create(this, oldDataset)
    instr.logParams(params: _*)
    instr.logNumFeatures(numFeatures)
    instr.logNumClasses(numClasses)

    val (baseLearners, learnerWeights) =
      GradientBoostedTrees.run(oldDataset, boostingStrategy, $(seed))
    val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures)
    instr.logSuccess(m)
    m
  }

  /** Creates a copy of this estimator with the `extra` params applied, via `defaultCopy`. */
  @Since("1.4.1")
  override def copy(extra: ParamMap): GBTClassifier = defaultCopy(extra)
}
|