Developing ML in Java

Machine Learning Libraries in Java

Weka

[Screenshot: the Weka GUI toolkit]

Apache Mahout

Deeplearning4j

Mallet

Spark MLlib

The Encog Machine Learning Framework

MOA

Weka Example:

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import weka.classifiers.trees.J48;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class Main {

    private static final Logger log = LoggerFactory.getLogger(Main.class);

    public static void main(String[] args) throws Exception {
        // Specifying the data source
        DataSource dataSource = new DataSource("data.arff");
        // Loading the dataset
        Instances dataInstances = dataSource.getDataSet();
        // Displaying the number of instances
        log.info("The number of loaded instances is: " + dataInstances.numInstances());
        log.info("data: " + dataInstances.toString());

        log.info("The number of attributes in the dataset: " + dataInstances.numAttributes());
        // Identifying the label index (here, the last attribute)
        dataInstances.setClassIndex(dataInstances.numAttributes() - 1);
        // Getting the number of classes
        log.info("The number of classes: " + dataInstances.numClasses());

        // Creating a decision tree classifier (-U builds an unpruned tree)
        J48 treeClassifier = new J48();
        treeClassifier.setOptions(new String[] { "-U" });
        treeClassifier.buildClassifier(dataInstances);
    }
}

Printing the trained classifier (its toString() output) shows the decision tree that J48 learned:

plas <= 127
|   mass <= 26.4
|   |   preg <= 7: tested_negative (117.0/1.0)
|   |   preg > 7
|   |   |   mass <= 0: tested_positive (2.0)
|   |   |   mass > 0: tested_negative (13.0)
|   mass > 26.4
|   |   age <= 28: tested_negative (180.0/22.0)
|   |   age > 28
|   |   |   plas <= 99: tested_negative (55.0/10.0)
|   |   |   plas > 99
|   |   |   |   pedi <= 0.56: tested_negative (84.0/34.0)
|   |   |   |   pedi > 0.56
|   |   |   |   |   preg <= 6
|   |   |   |   |   |   age <= 30: tested_positive (4.0)
|   |   |   |   |   |   age > 30
|   |   |   |   |   |   |   age <= 34: tested_negative (7.0/1.0)
|   |   |   |   |   |   |   age > 34
|   |   |   |   |   |   |   |   mass <= 33.1: tested_positive (6.0)
|   |   |   |   |   |   |   |   mass > 33.1: tested_negative (4.0/1.0)
|   |   |   |   |   preg > 6: tested_positive (13.0)
plas > 127
|   mass <= 29.9
|   |   plas <= 145: tested_negative (41.0/6.0)
|   |   plas > 145
|   |   |   age <= 25: tested_negative (4.0)
|   |   |   age > 25
|   |   |   |   age <= 61
|   |   |   |   |   mass <= 27.1: tested_positive (12.0/1.0)
|   |   |   |   |   mass > 27.1
|   |   |   |   |   |   pres <= 82
|   |   |   |   |   |   |   pedi <= 0.396: tested_positive (8.0/1.0)
|   |   |   |   |   |   |   pedi > 0.396: tested_negative (3.0)
|   |   |   |   |   |   pres > 82: tested_negative (4.0)
|   |   |   |   age > 61: tested_negative (4.0)
|   mass > 29.9
|   |   plas <= 157
|   |   |   pres <= 61: tested_positive (15.0/1.0)
|   |   |   pres > 61
|   |   |   |   age <= 30: tested_negative (40.0/13.0)
|   |   |   |   age > 30: tested_positive (60.0/17.0)
|   |   plas > 157: tested_positive (92.0/12.0)

Number of Leaves  : 22

Size of the tree : 43
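
To estimate how well this tree generalizes, you can cross-validate it with Weka's Evaluation class. A minimal sketch, assuming the dataInstances variable from the snippet above; the fold count and random seed are illustrative:

import weka.classifiers.Evaluation;
import weka.classifiers.trees.J48;
import java.util.Random;

// 10-fold cross-validation of a fresh J48 tree on the loaded dataset
Evaluation evaluation = new Evaluation(dataInstances);
evaluation.crossValidateModel(new J48(), dataInstances, 10, new Random(1));
// Summary statistics (accuracy, error rates) and the confusion matrix
log.info(evaluation.toSummaryString());
log.info(evaluation.toMatrixString());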

Deeplearning4j Example:
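
The snippet below is a fragment: it assumes the usual Deeplearning4j/ND4J imports plus a handful of hyperparameters defined elsewhere. A minimal, illustrative setup for those values (the concrete numbers are assumptions, not from the source):

// Illustrative values only; tune them for your own runs
int batchSize = 64;   // mini-batch size for the MNIST iterators
int seed = 123;       // random seed for reproducibility
int nChannels = 1;    // MNIST images have a single (grayscale) channel
int outputNum = 10;   // ten digit classes
int nEpochs = 1;      // number of training epochs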

// Loading the MNIST training and test sets
DataSetIterator MNISTTrain = new MnistDataSetIterator(batchSize,true,seed);
DataSetIterator MNISTTest = new MnistDataSetIterator(batchSize,false,seed);
// Displaying the number of labels in each split
log.info("The number of total labels found in the training dataset " + MNISTTrain.totalOutcomes());
log.info("The number of total labels found in the test dataset " + MNISTTest.totalOutcomes());
// Building the CNN model
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(seed) // random seed
        .l2(0.0005) // regularization
        .weightInit(WeightInit.XAVIER) // initialization of the weight scheme
        .updater(new Adam(1e-3)) // Setting the optimization algorithm
        .list()
        .layer(new ConvolutionLayer.Builder(5, 5)
                //Setting the stride, the kernel size, and the activation function.
                .nIn(nChannels)
                .stride(1,1)
                .nOut(20)
                .activation(Activation.IDENTITY)
                .build())
        .layer(new SubsamplingLayer.Builder(PoolingType.MAX) // downsampling the convolution
                .kernelSize(2,2)
                .stride(2,2)
                .build())
        .layer(new ConvolutionLayer.Builder(5, 5)
                // Setting the stride, kernel size, and the activation function.
                .stride(1,1)
                .nOut(50)
                .activation(Activation.IDENTITY)
                .build())
        .layer(new SubsamplingLayer.Builder(PoolingType.MAX) // downsampling the convolution
                .kernelSize(2,2)
                .stride(2,2)
                .build())
        .layer(new DenseLayer.Builder().activation(Activation.RELU)
                .nOut(500).build())
        .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .nOut(outputNum)
                .activation(Activation.SOFTMAX)
                .build())
        // The input type is a flattened 28x28 grayscale image (1 channel)
        .setInputType(InputType.convolutionalFlat(28,28,1))
        .build();
MultiLayerNetwork model = new MultiLayerNetwork(conf);
// initialize the model weights.
model.init();

log.info("Step2: start training the model");
//Setting a listener every 10 iterations and evaluate on test set on every epoch
model.setListeners(new ScoreIterationListener(10), new EvaluativeListener(MNISTTest, 1, InvocationType.EPOCH_END));
// Training the model
model.fit(MNISTTrain, nEpochs);
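
The EvaluativeListener above already evaluates the network on the test set at the end of each epoch. You can also run the evaluation explicitly once training has finished; a minimal sketch, assuming the model and MNISTTest objects from above:

// Evaluating the trained network on the test iterator
Evaluation evaluation = model.evaluate(MNISTTest);
// stats() reports accuracy, precision, recall, F1, and the confusion matrix
log.info(evaluation.stats());

Either way, the printed evaluation ends with a confusion matrix like the one below.
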
=========================Confusion Matrix=========================
    0    1    2    3    4    5    6    7    8    9
---------------------------------------------------
  977    0    0    0    0    0    1    1    1    0 | 0 = 0
    0 1131    0    1    0    1    2    0    0    0 | 1 = 1
    1    2 1019    3    0    0    0    3    4    0 | 2 = 2
    0    0    1 1004    0    1    0    1    3    0 | 3 = 3
    0    0    0    0  977    0    2    0    1    2 | 4 = 4
    1    0    0    9    0  879    1    0    1    1 | 5 = 5
    4    2    0    0    1    1  949    0    1    0 | 6 = 6
    0    4    2    1    1    0    0 1018    1    1 | 7 = 7
    2    0    3    1    0    1    1    2  962    2 | 8 = 8
    0    2    0    2   11    2    0    3    2  987 | 9 = 9

Mallet Example:
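
The fragment below assumes the standard MALLET imports and that inputFile names a resource with one document per line in the form "name label text" (which is what the CsvIterator pattern extracts). For reference, a sketch of the imports it relies on:

import java.io.File;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.net.URL;
import java.util.ArrayList;
import java.util.regex.Pattern;

import cc.mallet.pipe.CharSequence2TokenSequence;
import cc.mallet.pipe.CharSequenceLowercase;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.pipe.TokenSequence2FeatureSequence;
import cc.mallet.pipe.TokenSequenceRemoveStopwords;
import cc.mallet.pipe.iterator.CsvIterator;
import cc.mallet.topics.ParallelTopicModel;
import cc.mallet.types.InstanceList;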

// Building the text preprocessing pipeline: lowercase, tokenize, remove stop words, extract features
ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
pipeList.add( new CharSequenceLowercase() );
pipeList.add( new CharSequence2TokenSequence(Pattern.compile("\\p{L}[\\p{L}\\p{P}]+\\p{L}")) );
// Setting the dictionary of the stop words
URL stopWordsFile = getClass().getClassLoader().getResource("stoplists/en.txt");
pipeList.add( new TokenSequenceRemoveStopwords(new File(stopWordsFile.toURI()), "UTF-8", false, false, false) );

pipeList.add( new TokenSequence2FeatureSequence() );
InstanceList instances = new InstanceList (new SerialPipes(pipeList));
URL inputFileURL = getClass().getClassLoader().getResource(inputFile);
Reader fileReader = new InputStreamReader(new FileInputStream(new File(inputFileURL.toURI())), "UTF-8");
instances.addThruPipe(new CsvIterator (fileReader, Pattern.compile("^(\\S*)[\\s,]*(\\S*)[\\s,]*(.*)$"),
        3, 2, 1)); // data, label, name fields
log.info("The number of instances found in the input file is: " + instances.size());

// Defining an LDA topic model with 100 topics (alphaSum = 1.0, beta = 0.01)
int numTopics = 100;
ParallelTopicModel model = new ParallelTopicModel(numTopics, 1.0, 0.01);
// adding the instances to the model
model.addInstances(instances);
model.setNumThreads(2);
model.setNumIterations(50);
model.estimate();
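
After estimate() has finished, you will usually want to inspect the discovered topics. A minimal sketch, assuming the model from above and MALLET's printTopWords(PrintStream, int, boolean) method (check the exact signature against your MALLET version); the choice of 10 words per topic is arbitrary:

// Printing the top 10 words of each topic to standard output
model.printTopWords(System.out, 10, false);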
