Class DMatrix

java.lang.Object
ml.dmlc.xgboost4j.java.DMatrix

public class DMatrix extends Object
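The data matrix used by XGBoost. It holds the feature values together with per-row information such as labels, weights, base margins, and ranking groups.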
  • Nested Class Summary

    Nested Classes
    Modifier and Type | Class | Description
    static enum | DMatrix.SparseType | Sparse matrix type (CSR or CSC).
  • Field Summary

    Fields
    Modifier and Type | Field | Description
    protected long | handle |
  • Constructor Summary

    Constructors
    Modifier | Constructor | Description
    public | DMatrix(float[] data, int nrow, int ncol) | Create a DMatrix from a dense matrix.
    public | DMatrix(float[] data, int nrow, int ncol, float missing) | Create a DMatrix from a dense matrix, with a custom value that marks missing entries.
    protected | DMatrix(long handle) | Wrap an existing native handle; used by slice.
    public | DMatrix(long[] headers, int[] indices, float[] data, DMatrix.SparseType st) | Deprecated.
    public | DMatrix(long[] headers, int[] indices, float[] data, DMatrix.SparseType st, int shapeParam) | Create a DMatrix from a sparse matrix in CSR or CSC format.
    public | DMatrix(String dataPath) | Create a DMatrix by loading a LIBSVM file from dataPath.
    public | DMatrix(Iterator<ml.dmlc.xgboost4j.LabeledPoint> iter, String cacheInfo) | Create a DMatrix from an iterator.
  • Method Summary

    Modifier and Type | Method | Description
    void | dispose() |
    protected void | finalize() |
    private static float[] | flatten(float[][] mat) | Flatten a two-dimensional matrix into a one-dimensional array.
    float[] | getBaseMargin() | Get the base margin of the DMatrix.
    private float[] | getFloatInfo(String field) |
    long | getHandle() | Get the handle.
    private int[] | getIntInfo(String field) |
    float[] | getLabel() | Get the label values.
    float[] | getWeight() | Get the instance weights of the DMatrix.
    long | rowNum() | Get the number of rows in the DMatrix.
    void | saveBinary(String filePath) | Save the DMatrix to filePath.
    void | setBaseMargin(float[] baseMargin) | Set the base margin (initial prediction).
    void | setBaseMargin(float[][] baseMargin) | Set the base margin (initial prediction).
    void | setGroup(int[] group) | Set the group sizes of the DMatrix (used for ranking).
    void | setLabel(float[] labels) | Set the labels of the DMatrix.
    void | setWeight(float[] weights) | Set the weight of each instance.
    DMatrix | slice(int[] rowIndex) | Slice the DMatrix and return a new DMatrix that contains only the rows in rowIndex.

    Methods inherited from class java.lang.Object

    clone, equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
  • Field Details

    • handle

      protected long handle
  • Constructor Details

    • DMatrix

      public DMatrix(Iterator<ml.dmlc.xgboost4j.LabeledPoint> iter, String cacheInfo) throws XGBoostError
      Create a DMatrix from an iterator.
      Parameters:
      iter - An iterator that supplies the data in mini-batches.
      cacheInfo - Cache path information used for external-memory mode; may be null.
      Throws:
      XGBoostError
    • DMatrix

      public DMatrix(String dataPath) throws XGBoostError
      Create a DMatrix by loading a LIBSVM-format file from dataPath.
      Parameters:
      dataPath - The path to the data.
      Throws:
      XGBoostError
    • DMatrix

      @Deprecated public DMatrix(long[] headers, int[] indices, float[] data, DMatrix.SparseType st) throws XGBoostError
      Deprecated.
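      Create a DMatrix from a sparse matrix in CSR or CSC format. Unlike the overload that takes shapeParam, the number of columns (CSR) or rows (CSC) is not given explicitly.
      Parameters:
      headers - The offsets into indices and data at which each row (CSR) or column (CSC) starts.
      indices - The column (CSR) or row (CSC) index of each stored entry.
      data - The value of each stored entry.
      st - The sparse format, CSR or CSC.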
      Throws:
      XGBoostError
    • DMatrix

      public DMatrix(long[] headers, int[] indices, float[] data, DMatrix.SparseType st, int shapeParam) throws XGBoostError
      Create a DMatrix from a sparse matrix in CSR or CSC format.
      Parameters:
      headers - The offsets into indices and data at which each row (CSR) or column (CSC) starts.
      indices - The column (CSR) or row (CSC) index of each stored entry.
      data - The value of each stored entry.
      st - The sparse format, CSR or CSC.
      shapeParam - The number of columns when st is CSR, or the number of rows when st is CSC.
      Throws:
      XGBoostError
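      In CSR form (st = CSR), headers holds nrow + 1 offsets: the entries of row r sit in indices and data at positions headers[r] up to, but not including, headers[r + 1], and indices holds their column numbers; shapeParam then carries the column count. A minimal sketch that builds these three arrays from a dense row-major array, assuming NaN marks entries to leave out. The class CsrArrays is a hypothetical helper, not part of xgboost4j; CSC is the same construction with rows and columns swapped.

```java
import java.util.Arrays;

/** Hypothetical helper: converts a dense row-major matrix into CSR arrays, skipping NaN entries. */
final class CsrArrays {
  final long[] headers;  // row offsets, length nrow + 1
  final int[] indices;   // column index of each stored entry
  final float[] data;    // value of each stored entry

  private CsrArrays(long[] headers, int[] indices, float[] data) {
    this.headers = headers;
    this.indices = indices;
    this.data = data;
  }

  static CsrArrays fromDense(float[] dense, int nrow, int ncol) {
    // First pass: count the stored (non-NaN) entries.
    int nnz = 0;
    for (float v : dense) {
      if (!Float.isNaN(v)) nnz++;
    }
    long[] headers = new long[nrow + 1];
    int[] indices = new int[nnz];
    float[] data = new float[nnz];
    int pos = 0;
    for (int r = 0; r < nrow; r++) {
      headers[r] = pos; // first stored entry of row r
      for (int c = 0; c < ncol; c++) {
        float v = dense[r * ncol + c];
        if (!Float.isNaN(v)) {
          indices[pos] = c;
          data[pos] = v;
          pos++;
        }
      }
    }
    headers[nrow] = pos; // one past the last stored entry
    return new CsrArrays(headers, indices, data);
  }

  public static void main(String[] args) {
    float nan = Float.NaN;
    float[] dense = {1f, nan, 2f,
                     nan, nan, 3f};
    CsrArrays csr = fromDense(dense, 2, 3);
    System.out.println(Arrays.toString(csr.headers)); // [0, 2, 3]
    System.out.println(Arrays.toString(csr.indices)); // [0, 2, 2]
    System.out.println(Arrays.toString(csr.data));    // [1.0, 2.0, 3.0]
    // With xgboost4j on the classpath these would be passed as
    // new DMatrix(csr.headers, csr.indices, csr.data, DMatrix.SparseType.CSR, 3).
  }
}
```

Passing the column count through shapeParam keeps trailing all-missing columns in the shape, which the offsets alone cannot express.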
    • DMatrix

      public DMatrix(float[] data, int nrow, int ncol) throws XGBoostError
      Create a DMatrix from a dense matrix.
      Parameters:
      data - The matrix values in row-major order (nrow * ncol elements).
      nrow - The number of rows.
      ncol - The number of columns.
      Throws:
      XGBoostError - native error
    • DMatrix

      public DMatrix(float[] data, int nrow, int ncol, float missing) throws XGBoostError
      Create a DMatrix from a dense matrix, with a custom value that marks missing entries.
      Parameters:
      data - The matrix values in row-major order (nrow * ncol elements).
      nrow - The number of rows.
      ncol - The number of columns.
      missing - The value that marks a missing entry.
      Throws:
      XGBoostError
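      Both dense constructors take the whole matrix as one float[] of length nrow * ncol. A minimal sketch of packing a float[][] into that array in row-major order, which is the layout assumed here. DenseLayout is a hypothetical helper, not part of xgboost4j; the real constructor call appears only in a comment.

```java
import java.util.Arrays;

/** Hypothetical helper: packs a rectangular float[][] into the row-major float[] the dense constructors take. */
final class DenseLayout {
  private DenseLayout() {}

  static float[] toRowMajor(float[][] rows, int ncol) {
    float[] data = new float[rows.length * ncol];
    for (int r = 0; r < rows.length; r++) {
      if (rows[r].length != ncol) {
        throw new IllegalArgumentException(
            "row " + r + " has " + rows[r].length + " values, expected " + ncol);
      }
      // Row r occupies data[r * ncol] through data[r * ncol + ncol - 1].
      System.arraycopy(rows[r], 0, data, r * ncol, ncol);
    }
    return data;
  }

  public static void main(String[] args) {
    float[][] rows = {
      {1f, 2f, Float.NaN},
      {4f, 5f, 6f}
    };
    float[] data = toRowMajor(rows, 3);
    System.out.println(Arrays.toString(data)); // [1.0, 2.0, NaN, 4.0, 5.0, 6.0]
    // With xgboost4j on the classpath this would be passed as
    // new DMatrix(data, 2, 3, Float.NaN), so the NaN entry is treated as missing.
  }
}
```

Rejecting ragged rows up front avoids silently shifting values into the wrong column.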
    • DMatrix

      protected DMatrix(long handle)
      Create a DMatrix that wraps an existing native handle; used to build the result of slice.
  • Method Details

    • setLabel

      public void setLabel(float[] labels) throws XGBoostError
      Set the labels of the DMatrix.
      Parameters:
      labels - The label values.
      Throws:
      XGBoostError - native error
    • setWeight

      public void setWeight(float[] weights) throws XGBoostError
      Set the weight of each instance.
      Parameters:
      weights - The instance weights.
      Throws:
      XGBoostError - native error
    • setBaseMargin

      public void setBaseMargin(float[] baseMargin) throws XGBoostError
      Set the base margin (initial prediction). For a single-output model, the margin must have one element per row of this matrix.
      Throws:
      XGBoostError
    • setBaseMargin

      public void setBaseMargin(float[][] baseMargin) throws XGBoostError
      Set the base margin (initial prediction) from a two-dimensional array, which is flattened row by row before being set.
      Throws:
      XGBoostError
    • setGroup

      public void setGroup(int[] group) throws XGBoostError
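      Set the group sizes of the DMatrix (used for ranking). Each group is a run of consecutive rows, so the sizes must sum to the number of rows.
      Parameters:
      group - The size of each group, in row order.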
      Throws:
      XGBoostError - native error
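      Ranking data often comes with a query id per row rather than group sizes. A minimal sketch of turning per-row query ids into the int[] that setGroup takes, assuming rows of the same query are already adjacent; instead of sorting, it rejects ids whose rows are split apart. GroupSizes is a hypothetical helper, not part of xgboost4j.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/** Hypothetical helper: converts per-row query ids into the group sizes passed to setGroup. */
final class GroupSizes {
  private GroupSizes() {}

  static int[] fromQueryIds(int[] qids) {
    List<Integer> sizes = new ArrayList<>();
    Set<Integer> seen = new HashSet<>();
    int i = 0;
    while (i < qids.length) {
      int qid = qids[i];
      // A query id that reappears after another id means its rows are not contiguous.
      if (!seen.add(qid)) {
        throw new IllegalArgumentException("rows of query " + qid + " are not contiguous");
      }
      int j = i;
      while (j < qids.length && qids[j] == qid) j++; // skip the run of rows for this query
      sizes.add(j - i);
      i = j;
    }
    // Every row falls in exactly one run, so the sizes sum to qids.length.
    return sizes.stream().mapToInt(Integer::intValue).toArray();
  }

  public static void main(String[] args) {
    int[] qids = {7, 7, 7, 3, 3, 9};
    System.out.println(Arrays.toString(fromQueryIds(qids))); // [3, 2, 1]
  }
}
```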
    • getFloatInfo

      private float[] getFloatInfo(String field) throws XGBoostError
      Throws:
      XGBoostError
    • getIntInfo

      private int[] getIntInfo(String field) throws XGBoostError
      Throws:
      XGBoostError
    • getLabel

      public float[] getLabel() throws XGBoostError
      Get the label values.
      Returns:
      The labels.
      Throws:
      XGBoostError - native error
    • getWeight

      public float[] getWeight() throws XGBoostError
      Get the instance weights of the DMatrix.
      Returns:
      weights
      Throws:
      XGBoostError - native error
    • getBaseMargin

      public float[] getBaseMargin() throws XGBoostError
      Get base margin of the DMatrix.
      Throws:
      XGBoostError
    • slice

      public DMatrix slice(int[] rowIndex) throws XGBoostError
      Slice the DMatrix and return a new DMatrix that contains only the rows listed in rowIndex.
      Parameters:
      rowIndex - The indices of the rows to keep.
      Returns:
      A new DMatrix containing the selected rows.
      Throws:
      XGBoostError - native error
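      To make the row selection concrete without the native library, a minimal sketch of the same selection on a dense row-major array. It assumes the selected rows come out in the order given by rowIndex; RowSlice is a hypothetical helper, and the real slice works on the native matrix and returns a new DMatrix.

```java
import java.util.Arrays;

/** Hypothetical helper: the row selection of slice(rowIndex), shown on a dense row-major array. */
final class RowSlice {
  private RowSlice() {}

  static float[] selectRows(float[] data, int nrow, int ncol, int[] rowIndex) {
    float[] out = new float[rowIndex.length * ncol];
    for (int k = 0; k < rowIndex.length; k++) {
      int r = rowIndex[k];
      if (r < 0 || r >= nrow) {
        throw new IndexOutOfBoundsException("row " + r + " out of range");
      }
      // Output row k is a copy of input row r.
      System.arraycopy(data, r * ncol, out, k * ncol, ncol);
    }
    return out;
  }

  public static void main(String[] args) {
    float[] data = {1f, 2f, 3f, 4f, 5f, 6f}; // 3 rows x 2 columns
    System.out.println(Arrays.toString(selectRows(data, 3, 2, new int[]{2, 0}))); // [5.0, 6.0, 1.0, 2.0]
  }
}
```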
    • rowNum

      public long rowNum() throws XGBoostError
      Get the number of rows in the DMatrix.
      Returns:
      number of rows
      Throws:
      XGBoostError - native error
    • saveBinary

      public void saveBinary(String filePath)
      Save the DMatrix to filePath in XGBoost's binary format.
    • getHandle

      public long getHandle()
      Get the handle of the underlying native DMatrix.
    • flatten

      private static float[] flatten(float[][] mat)
      Flatten a two-dimensional matrix into a one-dimensional array by concatenating its rows.
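      A minimal sketch of row-by-row flattening, presumably the step that lets setBaseMargin(float[][]) hand a two-dimensional margin (for example, one row per instance and one column per model output) to the one-dimensional setter. This is an illustrative reimplementation under that assumption, not the private method's source.

```java
import java.util.Arrays;

/** Illustrative reimplementation of flattening a float[][] into a float[], row by row. */
final class Flatten {
  private Flatten() {}

  static float[] flatten(float[][] mat) {
    int size = 0;
    for (float[] row : mat) size += row.length; // total number of elements
    float[] result = new float[size];
    int pos = 0;
    for (float[] row : mat) {
      // Append this row directly after the previous one.
      System.arraycopy(row, 0, result, pos, row.length);
      pos += row.length;
    }
    return result;
  }

  public static void main(String[] args) {
    float[][] margin = {{0.5f, -1.0f, 0.0f}, {1.5f, 0.25f, -2.0f}}; // 2 rows, 3 outputs each
    System.out.println(Arrays.toString(flatten(margin))); // [0.5, -1.0, 0.0, 1.5, 0.25, -2.0]
  }
}
```

Summing the row lengths first means rows of unequal length are concatenated as-is rather than padded.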
    • finalize

      protected void finalize()
      Overrides:
      finalize in class Object
    • dispose

      public void dispose()
      Release the native resources held by this DMatrix.