Machine Learning (ML) has brought significant promise to many fields in both academia and industry. Day by day, ML expands its reach across a comprehensive list of applications such as image and speech recognition, pattern recognition, optimization, natural language processing, recommendations, and many others.
Programming computers to learn from experience should eventually eliminate the need for much of this detailed programming effort. — Arthur Samuel, 1959.
Machine Learning can be divided into four main techniques: regression, classification, clustering, and reinforcement learning. These techniques solve problems of different natures in mainly two forms: supervised and unsupervised learning. Supervised learning requires the data to be labeled and prepared before training the model. Unsupervised learning comes in handy for unlabeled data or data with unknown characteristics. This article does not describe ML’s concepts or go in depth into the terms used in this field. If you are entirely new to ML, please look at my previous article about starting your ML learning journey.
Machine Learning Libraries in Java
Here is a list of well-known Java libraries for ML. We will describe them one by one and give real-world examples using some of these frameworks.
- Weka
- Apache Mahout
- Deeplearning4j
- Mallet
- Spark MLlib
- The Encog Machine Learning Framework
- MOA
Weka
Weka is an open-source library developed at the University of Waikato in New Zealand. Weka is written in Java, and it is very well known for general-purpose machine learning. Weka uses its own data file format, called ARFF. An ARFF file is split into two parts: a header and the actual data. The header describes the attributes and their data types.
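For illustration, here is a minimal sketch of what an ARFF file can look like. It is a trimmed-down, hypothetical version of the diabetes dataset used later in this article (the real file has more attributes and 768 rows):

@relation diabetes

@attribute preg numeric
@attribute plas numeric
@attribute mass numeric
@attribute class {tested_negative, tested_positive}

@data
6, 148, 33.6, tested_positive
1, 85, 26.6, tested_negative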
Apache Mahout
Apache Mahout provides a scalable machine learning library. Mahout uses the MapReduce paradigm and can be used for classification, collaborative filtering, and clustering. Mahout utilizes Apache Hadoop to process multiple tasks in parallel. In addition to classification and clustering, Mahout provides recommendation algorithms such as collaborative filtering, which makes it easy to build scalable models quickly.
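To give a flavor of the API, here is a minimal sketch of a user-based collaborative-filtering recommender built with Mahout's Taste API. The ratings file name and the neighborhood size are placeholder assumptions, and note that the Taste recommenders have been deprecated in recent Mahout releases:

import org.apache.mahout.cf.taste.impl.model.file.FileDataModel;
import org.apache.mahout.cf.taste.impl.neighborhood.NearestNUserNeighborhood;
import org.apache.mahout.cf.taste.impl.recommender.GenericUserBasedRecommender;
import org.apache.mahout.cf.taste.impl.similarity.PearsonCorrelationSimilarity;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.neighborhood.UserNeighborhood;
import org.apache.mahout.cf.taste.recommender.RecommendedItem;
import org.apache.mahout.cf.taste.recommender.Recommender;
import org.apache.mahout.cf.taste.similarity.UserSimilarity;

import java.io.File;
import java.util.List;

public class MahoutRecommenderSketch {
    public static void main(String[] args) throws Exception {
        // ratings.csv is a hypothetical file with lines of the form: userID,itemID,rating
        DataModel model = new FileDataModel(new File("ratings.csv"));
        // Comparing users with Pearson correlation over their rating vectors
        UserSimilarity similarity = new PearsonCorrelationSimilarity(model);
        // Considering the 10 most similar users as the neighborhood
        UserNeighborhood neighborhood = new NearestNUserNeighborhood(10, similarity, model);
        Recommender recommender = new GenericUserBasedRecommender(model, neighborhood, similarity);
        // Recommending 3 items for user 1
        List<RecommendedItem> recommendations = recommender.recommend(1, 3);
        for (RecommendedItem item : recommendations) {
            System.out.println(item);
        }
    }
}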
Deeplearning4j
Deeplearning4j is another Java library, this one focusing on deep learning. It is one of the great open-source deep learning libraries for Java. It is written in Java and Scala and can be integrated with Hadoop and Spark, providing high processing capabilities. The current release is still in beta, but it comes with excellent documentation and quick-start examples.
Mallet
Mallet stands for Machine Learning for Language Toolkit. It is one of the few specialized toolkits for natural language processing. It provides capabilities for topic modeling, document classification, clustering, and information extraction. With Mallet, we can build ML models to process textual documents.
Spark MLlib
Spark is very well known for accelerating the scalability and overall performance of processing massive amounts of data. Spark MLlib provides high-performance algorithms that run on Spark and plug into Hadoop workflows.
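To give a flavor of the API, here is a minimal sketch of training a logistic regression classifier with Spark MLlib in Java. The input path, the local master setting, and the hyperparameter values are placeholder assumptions:

import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class SparkMLlibSketch {
    public static void main(String[] args) {
        // Running Spark locally for the sake of the example
        SparkSession spark = SparkSession.builder()
                .appName("SparkMLlibSketch")
                .master("local[*]")
                .getOrCreate();

        // Loading a labeled dataset stored in LibSVM format (hypothetical path)
        Dataset<Row> training = spark.read().format("libsvm").load("sample_libsvm_data.txt");

        // Configuring and fitting a logistic regression model
        LogisticRegression lr = new LogisticRegression()
                .setMaxIter(10)
                .setRegParam(0.01);
        LogisticRegressionModel model = lr.fit(training);

        System.out.println("Model coefficients: " + model.coefficients());
        spark.stop();
    }
}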
The Encog Machine Learning Framework
Encog is a Java and C# framework for ML. Encog has libraries for building SVMs, neural networks, Bayesian networks, HMMs, and genetic algorithms. Encog started as a research project and has been cited nearly a thousand times on Google Scholar.
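As a quick illustration, here is a minimal sketch of Encog's classic XOR example, which trains a small feed-forward neural network with resilient propagation. The layer sizes, error threshold, and iteration limit are illustrative choices:

import org.encog.Encog;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation;

public class EncogXorSketch {
    public static void main(String[] args) {
        // The XOR truth table as training data
        double[][] input = { {0, 0}, {0, 1}, {1, 0}, {1, 1} };
        double[][] ideal = { {0}, {1}, {1}, {0} };
        MLDataSet trainingSet = new BasicMLDataSet(input, ideal);

        // A 2-3-1 feed-forward network with sigmoid activations
        BasicNetwork network = new BasicNetwork();
        network.addLayer(new BasicLayer(null, true, 2));
        network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 3));
        network.addLayer(new BasicLayer(new ActivationSigmoid(), false, 1));
        network.getStructure().finalizeStructure();
        network.reset();

        // Training with resilient propagation until the error is small enough
        ResilientPropagation train = new ResilientPropagation(network, trainingSet);
        int epoch = 1;
        do {
            train.iteration();
            epoch++;
        } while (train.getError() > 0.01 && epoch < 5000);
        train.finishTraining();

        Encog.getInstance().shutdown();
    }
}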
MOA
Massive Online Analysis (MOA) provides algorithms for classification, regression, clustering, and recommendation. It also provides libraries for outlier detection and drift detection. It is designed for real-time processing of data streams as they are produced.
Weka Example:
We are going to use a small diabetes dataset. We will first load the data using Weka:
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class Main {

    // A logger is assumed for the output throughout the examples; here we use SLF4J
    private static final Logger log = LoggerFactory.getLogger(Main.class);

    public static void main(String[] args) throws Exception {
        // Specifying the datasource
        DataSource dataSource = new DataSource("data.arff");
        // Loading the dataset
        Instances dataInstances = dataSource.getDataSet();
        // Displaying the number of instances
        log.info("The number of loaded instances is: " + dataInstances.numInstances());
        log.info("data:" + dataInstances.toString());
    }
}
There are 768 instances in the dataset. Let’s see how to get the number of attributes (features), which should be 9.
log.info("The number of attributes in the dataset: " + dataInstances.numAttributes());
Before building any model, we want to identify which column is the target column and see how many classes are found in this column:
// Identifying the label index (the last attribute is the target)
dataInstances.setClassIndex(dataInstances.numAttributes() - 1);
// Getting the number of classes in the target attribute
log.info("The number of classes: " + dataInstances.numClasses());
After loading the dataset and identifying our target attribute, it is now time to build the model. Let's make a simple tree classifier, J48.
// Creating a decision tree classifier (requires: import weka.classifiers.trees.J48;)
J48 treeClassifier = new J48();
// The "-U" option builds an unpruned tree
treeClassifier.setOptions(new String[] { "-U" });
treeClassifier.buildClassifier(dataInstances);
In the three lines above, we created a J48 classifier, set the "-U" option to build an unpruned tree, and provided the data instances for model training. If we print the tree structure of the trained model, we can follow how the model internally built its rules.
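The tree can be printed with a single call, since J48's toString() returns a human-readable representation of the trained model:

// Printing the structure of the trained tree
log.info(treeClassifier.toString());

This produces the following output: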
plas <= 127
|   mass <= 26.4
|   |   preg <= 7: tested_negative (117.0/1.0)
|   |   preg > 7
|   |   |   mass <= 0: tested_positive (2.0)
|   |   |   mass > 0: tested_negative (13.0)
|   mass > 26.4
|   |   age <= 28: tested_negative (180.0/22.0)
|   |   age > 28
|   |   |   plas <= 99: tested_negative (55.0/10.0)
|   |   |   plas > 99
|   |   |   |   pedi <= 0.56: tested_negative (84.0/34.0)
|   |   |   |   pedi > 0.56
|   |   |   |   |   preg <= 6
|   |   |   |   |   |   age <= 30: tested_positive (4.0)
|   |   |   |   |   |   age > 30
|   |   |   |   |   |   |   age <= 34: tested_negative (7.0/1.0)
|   |   |   |   |   |   |   age > 34
|   |   |   |   |   |   |   |   mass <= 33.1: tested_positive (6.0)
|   |   |   |   |   |   |   |   mass > 33.1: tested_negative (4.0/1.0)
|   |   |   |   |   preg > 6: tested_positive (13.0)
plas > 127
|   mass <= 29.9
|   |   plas <= 145: tested_negative (41.0/6.0)
|   |   plas > 145
|   |   |   age <= 25: tested_negative (4.0)
|   |   |   age > 25
|   |   |   |   age <= 61
|   |   |   |   |   mass <= 27.1: tested_positive (12.0/1.0)
|   |   |   |   |   mass > 27.1
|   |   |   |   |   |   pres <= 82
|   |   |   |   |   |   |   pedi <= 0.396: tested_positive (8.0/1.0)
|   |   |   |   |   |   |   pedi > 0.396: tested_negative (3.0)
|   |   |   |   |   |   pres > 82: tested_negative (4.0)
|   |   |   |   age > 61: tested_negative (4.0)
|   mass > 29.9
|   |   plas <= 157
|   |   |   pres <= 61: tested_positive (15.0/1.0)
|   |   |   pres > 61
|   |   |   |   age <= 30: tested_negative (40.0/13.0)
|   |   |   |   age > 30: tested_positive (60.0/17.0)
|   |   plas > 157: tested_positive (92.0/12.0)

Number of Leaves : 22
Size of the tree : 43
Deeplearning4j Example:
This example will build a Convolutional Neural Network (CNN) model to classify the MNIST dataset. If you are not familiar with MNIST or how a CNN works to classify handwritten digits, I recommend you have a quick look at my earlier post, which describes these aspects in detail.
As always, we will load the dataset and display its size.
// batchSize and seed are assumed to be defined earlier (e.g., batchSize = 64, seed = 123)
DataSetIterator MNISTTrain = new MnistDataSetIterator(batchSize, true, seed);
DataSetIterator MNISTTest = new MnistDataSetIterator(batchSize, false, seed);
Let's double-check that we get ten unique labels from the dataset:
log.info("The number of total labels found in the training dataset " + MNISTTrain.totalOutcomes());
log.info("The number of total labels found in the test dataset " + MNISTTest.totalOutcomes());
Next, let’s configure the architecture of the model. We will use two convolution layers, each followed by a max-pooling (subsampling) layer, then a dense layer and a softmax output layer. Deeplearning4j has several options that you can use to initialize the weight scheme.
// Building the CNN model (nChannels = 1 for grayscale input, outputNum = 10 digit classes)
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(seed) // random seed
.l2(0.0005) // regularization
.weightInit(WeightInit.XAVIER) // initialization of the weight scheme
.updater(new Adam(1e-3)) // Setting the optimization algorithm
.list()
.layer(new ConvolutionLayer.Builder(5, 5)
//Setting the stride, the kernel size, and the activation function.
.nIn(nChannels)
.stride(1,1)
.nOut(20)
.activation(Activation.IDENTITY)
.build())
.layer(new SubsamplingLayer.Builder(PoolingType.MAX) // downsampling the convolution
.kernelSize(2,2)
.stride(2,2)
.build())
.layer(new ConvolutionLayer.Builder(5, 5)
// Setting the stride, kernel size, and the activation function.
.stride(1,1)
.nOut(50)
.activation(Activation.IDENTITY)
.build())
.layer(new SubsamplingLayer.Builder(PoolingType.MAX) // downsampling the convolution
.kernelSize(2,2)
.stride(2,2)
.build())
.layer(new DenseLayer.Builder().activation(Activation.RELU)
.nOut(500).build())
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.nOut(outputNum)
.activation(Activation.SOFTMAX)
.build())
// the input is a flat 28x28 image with a depth of 1 (grayscale).
.setInputType(InputType.convolutionalFlat(28,28,1))
.build();
Once the architecture is set, we need to initialize the model, set the training dataset, and trigger the model training.
MultiLayerNetwork model = new MultiLayerNetwork(conf);
// initialize the model weights.
model.init();
log.info("Step2: start training the model");
// Logging the score every 10 iterations and evaluating on the test set at the end of every epoch
model.setListeners(new ScoreIterationListener(10), new EvaluativeListener(MNISTTest, 1, InvocationType.EPOCH_END));
// Training the model (nEpochs is assumed to be defined earlier, e.g., nEpochs = 10)
model.fit(MNISTTrain, nEpochs);
During the training, the evaluative listener will print an evaluation on the test set at the end of every epoch, including a confusion matrix. Let's see the accuracy after ten epochs of training:
=========================Confusion Matrix=========================
0 1 2 3 4 5 6 7 8 9
---------------------------------------------------
977 0 0 0 0 0 1 1 1 0 | 0 = 0
0 1131 0 1 0 1 2 0 0 0 | 1 = 1
1 2 1019 3 0 0 0 3 4 0 | 2 = 2
0 0 1 1004 0 1 0 1 3 0 | 3 = 3
0 0 0 0 977 0 2 0 1 2 | 4 = 4
1 0 0 9 0 879 1 0 1 1 | 5 = 5
4 2 0 0 1 1 949 0 1 0 | 6 = 6
0 4 2 1 1 0 0 1018 1 1 | 7 = 7
2 0 3 1 0 1 1 2 962 2 | 8 = 8
0 2 0 2 11 2 0 3 2 987 | 9 = 9
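If you prefer to run the evaluation yourself after training rather than relying on the listener, DL4J's built-in evaluate method returns the same statistics. A minimal sketch (the Evaluation class lives in the ND4J evaluation package in recent releases):

// Evaluating the trained model on the test iterator
Evaluation eval = model.evaluate(MNISTTest);
log.info(eval.stats());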
Mallet Example:
As mentioned earlier, Mallet is a powerful toolkit for natural language modeling. We will use a sample corpus by David Blei that ships with the Mallet package. Mallet has a specific library for annotating textual tokens for classification. Before we load our dataset, note that Mallet works with the concept of a pipeline: you define your pipeline first, and then the dataset is passed through it.
ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
The pipeline is defined as an ArrayList of Pipe objects and contains the typical steps we always perform before building a topic model. The text of each document passes through the following steps:
- Lowercase keywords
- Tokenize text
- Remove stopwords
- Map to features
pipeList.add( new CharSequenceLowercase() );
pipeList.add( new CharSequence2TokenSequence(Pattern.compile("\\p{L}[\\p{L}\\p{P}]+\\p{L}")) );
// Setting the dictionary of the stop words
URL stopWordsFile = getClass().getClassLoader().getResource("stoplists/en.txt");
pipeList.add( new TokenSequenceRemoveStopwords(new File(stopWordsFile.toURI()), "UTF-8", false, false, false) );
pipeList.add( new TokenSequence2FeatureSequence() );
Once the pipeline is defined, we create the instance list that will hold the original text of each document.
InstanceList instances = new InstanceList (new SerialPipes(pipeList));
The next step is to pass the input file through the pipeline to fill up the instance list.
URL inputFileURL = getClass().getClassLoader().getResource(inputFile);
Reader fileReader = new InputStreamReader(new FileInputStream(new File(inputFileURL.toURI())), "UTF-8");
instances.addThruPipe(new CsvIterator (fileReader, Pattern.compile("^(\\S*)[\\s,]*(\\S*)[\\s,]*(.*)$"),
3, 2, 1)); // data, label, name fields
From the last call, you can see that we told the iterator how the input file is structured. The source file, available in the resources folder, has around two thousand rows. Each line represents one original document and consists of three comma-separated attributes (name, label, and document content). We can print the number of instances found in the input file using the following command:
log.info("The number of instances found in the input file is: " + instances.size());
Now, let’s model the documents’ topics. Let’s assume that we have 100 different topics across those 2,000 documents. Mallet lets us set two hyperparameters: alpha and beta. Alpha controls the concentration of the document-topic distributions, and beta is the per-word prior weight on the topic-word distributions.
// defining the number of topics
int numTopics = 100;
// defining the model
ParallelTopicModel model = new ParallelTopicModel(numTopics, 1.0, 0.01);
// adding the instances to the model
model.addInstances(instances);
The model we chose in this example is an implementation of LDA (Latent Dirichlet Allocation). The algorithm uses the co-occurrence patterns of observed keywords to group documents into topics.
One of the things I like about Mallet is how easily its API lets you set up parallel processing. Here, we can define multithreaded processing, with each thread sampling its own subsample of the data.
model.setNumThreads(2);
The only two things left now are to define the number of iterations for the model training and to start the training.
model.setNumIterations(50);
model.estimate();
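For a quick look at the result, ParallelTopicModel can report the strongest keywords of each topic. A minimal sketch, assuming the getTopWords helper and printing only the first ten topics:

// Printing the top 5 keywords of the first 10 topics
Object[][] topWords = model.getTopWords(5);
for (int topic = 0; topic < 10; topic++) {
    StringBuilder sb = new StringBuilder("Topic " + topic + ":");
    for (Object word : topWords[topic]) {
        sb.append(" ").append(word);
    }
    log.info(sb.toString());
}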
I left more details on how to display the topic modeling results in the full example on GitHub.