Training a Text Classifier
with Create ML and the
Natural Language Framework
From Zero to Core ML Model in Minutes
Machine Learning can be difficult to get your head around as a programmer. But aside from all the advanced mathematics and tooling, perhaps most difficult of all is learning how to let go.
Many of us have spent years honing our craft of writing code: expressive, type-safe, unit-tested, refactored, clean, DRY code. So it’s tough to hear that pretty much anyone with enough data (and patience) can use machine learning to solve hard problems without understanding why or how it works.
But this isn’t news to you. In fact, ML is at the top of your list of things to learn next! You have Andrew Ng’s course bookmarked in Safari and a load of unread PDFs littering your Downloads folder. All you need is a free weekend and…
If that strikes a bit too close to home, then you’ll love Create ML.
Create ML is a new framework that makes it easy to train machine learning models. How easy? Drag pictures of dogs into a “Dogs” folder and pictures of cats into a “Cats” folder, write a few lines of Swift code, wait a couple of minutes for the model to train, and boom: you have an image classifier that can tell you whether a picture contains a cat or a dog. You can even do the same for text classification, or regression for data tables.
If you haven’t already, go ahead and watch the Introducing Create ML session from this year’s WWDC. It’s like magic.
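To give a sense of just how little code that is, here’s roughly what the cats-and-dogs example boils down to. The folder layout and output path are hypothetical, and this is a sketch rather than the exact sample from the session:

import CreateML
import Foundation

// Hypothetical layout: ~/Pets/Cats/*.jpg and ~/Pets/Dogs/*.jpg
let trainingDirectory = URL(fileURLWithPath: "/Users/you/Pets")

// Create ML infers the labels ("Cats", "Dogs") from the folder names.
let petClassifier = try MLImageClassifier(
    trainingData: .labeledDirectories(at: trainingDirectory)
)

// Save the trained model as a Core ML file.
try petClassifier.write(to: URL(fileURLWithPath: "/Users/you/Pets.mlmodel"))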
Pulling files from labeled directories makes for a nice demo, but what if your data set isn’t so nicely organized? This article shows how you can use Create ML to train a text classifier that predicts the programming language of unknown source code by manually creating the corpus from a heterogeneous data set.
You can find the complete training script and a demo playground here.
#Acquiring a Corpus of Data
First things first: we need some examples of source code.
If you’re anything like the author, you might have thought to use GitHub search to find code by language. And if so, you’d eventually realize that the GitHub API (in both its REST and GraphQL versions) requires code search results to be scoped to a user or project. At that point, you’d probably wonder why you decided to build this yourself before finding something off-the-shelf, like this project by Source Foundry.
Our corpus includes labeled code samples in C, C++, Go, Java, JavaScript, Objective-C, PHP, Python, Ruby, Rust, and Swift. Each directory in the project root corresponds to a language and contains flattened checkouts of a handful of popular open source repositories for that language:
$ tree -L 2 code-corpora/swift
code-corpora/swift
├── alamofire
│   ├── Alamofire.h
│   ├── Alamofire.swift
│   ├── AppDelegate.swift
│   ├── AuthenticationTests.swift
# ...
Notice that Objective-C header in the Swift project, though. We can’t rely entirely on the top-level directories as labels for their contents, because most projects include auxiliary scripts and source files in other languages (as well as README, LICENSE, and other repository miscellany).
Our training script uses the containing directory and file extension to determine the correct label for each file in our corpus. For example, .h files in the c directory are labeled as C, .h files in the cc directory are labeled as C++, and any file with the .go extension is labeled as Go:
switch (directory, fileExtension) {
case ("c", "h"), (_, "c"): label = "C"
case ("cc", "h"), (_, "cc"), (_, "cpp"): label = "C++"
case (_, "go"): label = "Go"
// ...
default:
    break // Unknown; skip this file
}
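The training script below wraps this heuristic up in a ProgrammingLanguage type, which isn’t reproduced in the article. As a rough sketch, you can imagine it as a String-backed enum with a failable initializer; the case names and the way the enumeration level is used here are our own assumptions:

import Foundation

enum ProgrammingLanguage: String {
    case c = "C", cPlusPlus = "C++", go = "Go", swift = "Swift" // ...

    init?(for fileURL: URL, at level: Int) {
        // Walk back up to the directory just beneath the corpus root;
        // `level` is the enumerator's depth for this file (1 = directly in the root).
        let components = fileURL.pathComponents
        guard components.count > level else { return nil }
        let directory = components[components.count - level]
        let fileExtension = fileURL.pathExtension

        switch (directory, fileExtension) {
        case ("c", "h"), (_, "c"): self = .c
        case ("cc", "h"), (_, "cc"), (_, "cpp"): self = .cPlusPlus
        case (_, "go"): self = .go
        case (_, "swift"): self = .swift
        // ...
        default:
            return nil // Unknown; skip this file
        }
    }
}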
#Training the Model
To build our data table, we recursively enumerate the contents of the corpus directory and append the contents of each source file that we can identify:
var corpus: [(text: String, label: String)] = []

let enumerator = FileManager.default.enumerator(
    at: corpusURL,
    includingPropertiesForKeys: [.isDirectoryKey]
)!

for case let resource as URL in enumerator {
    guard !resource.hasDirectoryPath,
          let language = ProgrammingLanguage(for: resource,
                                             at: enumerator.level),
          let text = try? String(contentsOf: resource)
    else {
        continue
    }

    corpus.append((text: text, label: language.rawValue))
}
let (texts, labels): ([String], [String]) =
    corpus.reduce(into: ([], [])) { (columns, row) in
        columns.0.append(row.text)
        columns.1.append(row.label)
    }

let dataTable =
    try MLDataTable(dictionary: ["text": texts, "label": labels])
Our original implementation appended MLDataTable objects instead of initializing a single data table from an accumulated array. We found this to have nonlinear performance characteristics, which caused training to take closer to an hour than a few minutes.
With our data table in hand, we use the randomSplit(by:seed:) method to segment our training and testing data. The former is used immediately, passed into the MLTextClassifier initializer; the latter will be used next to evaluate the model.
let (trainingData, testingData) =
    dataTable.randomSplit(by: 0.8, seed: 0)

let classifier = try MLTextClassifier(trainingData: trainingData,
                                      textColumn: "text",
                                      labelColumn: "label")
Creating an MLTextClassifier object takes a while, but you can track the progress by tailing standard output:
Automatically generating validation set from 5% of the data.
Tokenizing data and extracting features
10% complete
20% complete
30% complete
40% complete
50% complete
60% complete
70% complete
80% complete
90% complete
100% complete
Starting MaxEnt training with 8584 samples
Iteration 1 training accuracy 0.285182
Iteration 2 training accuracy 0.946295
Iteration 3 training accuracy 0.988001
Iteration 4 training accuracy 0.997554
Iteration 5 training accuracy 0.998602
Iteration 6 training accuracy 0.999185
Iteration 7 training accuracy 0.999651
Iteration 8 training accuracy 0.999767
Finished MaxEnt training in 7.12 seconds
The resulting model seems large for what it can do, weighing in at 3MB. However, it’s able to classify a file in ~20ms, which should be fast enough for most use cases.
#Evaluating the Model
Let’s see how our classifier performs by calling the evaluation(on:) method and passing the testing data that we segmented before:
let evaluation = classifier.evaluation(on: testingData)
print(evaluation)
#Accuracy
At the top of our evaluation, we get a summary with the number of examples, the number of classes, and the accuracy:
Number of Examples | 1138 |
Number of Classes | 10 |
Accuracy | 99.56% |
99.56% accuracy. That’s good, right? Let’s dig into the numbers to get a better understanding of how this behaves.
When you print(_:) an MLClassifierMetrics object, it shows a summary of the overall accuracy as well as a confusion matrix and a precision / recall table.
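If you’d rather work with those numbers programmatically than read them off the console, the metrics object exposes them as properties. A quick sketch, using the classificationError, confusion, and precisionRecall properties that MLClassifierMetrics provides:

// Overall accuracy is reported as a classification error rate.
let accuracy = (1.0 - evaluation.classificationError) * 100
print("Accuracy: \(String(format: "%.2f", accuracy))%")

// The confusion matrix and per-class precision / recall figures
// are available as data tables, too.
print(evaluation.confusion)
print(evaluation.precisionRecall)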
#Confusion Matrix
A confusion matrix is a tool for visualizing the accuracy of predictions. Each column shows the predicted classes, and each row shows the actual class:
 | C | C++ | Go | Java | JS | Obj-C | PHP | Ruby | Rust | Swift |
---|---|---|---|---|---|---|---|---|---|---|
C | 122 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
C++ | 0 | 73 | 0 | 0 | 0 | 2 | 0 | 0 | 0 | 0 |
Go | 1 | 0 | 333 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
Java | 0 | 0 | 0 | 137 | 0 | 0 | 0 | 0 | 0 | 0 |
JS | 0 | 0 | 0 | 0 | 55 | 0 | 0 | 0 | 0 | 0 |
Obj-C | 0 | 0 | 0 | 0 | 0 | 97 | 0 | 0 | 0 | 0 |
PHP | 0 | 0 | 0 | 0 | 0 | 0 | 95 | 0 | 0 | 0 |
Ruby | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 136 | 0 | 0 |
Rust | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 73 | 0 |
Swift | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 12 |
A model with 100% accuracy would have values only along the diagonal, where the predicted and actual classes match, and zeroes everywhere else. However, our accuracy isn’t perfect, so we have a few stray figures. From the table, we can see that Go was mistaken for C once and C++ was incorrectly labeled as Objective-C twice.
#Precision and Recall
Another way of analyzing our results is in terms of precision and recall.
Class | Precision(%) | Recall(%) |
---|---|---|
C | 99.19 | 100.00 |
C++ | 100.00 | 97.33 |
Go | 100.00 | 98.94 |
Java | 100.00 | 100.00 |
JavaScript | 100.00 | 100.00 |
Objective-C | 98.91 | 100.00 |
PHP | 100.00 | 100.00 |
Ruby | 100.00 | 100.00 |
Rust | 100.00 | 100.00 |
Swift | 100.00 | 100.00 |
Precision measures the ability of the model to identify only the relevant classifications within a data set. For example, our model had perfect precision for C++ because it never misidentified any source files as C++; however, it had imperfect precision for C because it incorrectly identified a Go file as C.
Recall measures the ability of the model to identify all of the relevant classifications within a data set. For example, our model had perfect recall for C because it correctly identified all of the C source code in the testing data, and imperfect recall for C++ because it mislabeled two C++ files in the testing data.
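To make those definitions concrete, here’s the arithmetic behind two of the figures above, taken straight from the confusion matrix:

// Precision for C: of everything predicted to be C,
// 122 files really were C and 1 was actually Go.
let cPrecision = 122.0 / (122.0 + 1.0) * 100  // ≈ 99.19%

// Recall for C++: of all the actual C++ files,
// 73 were caught and 2 were mislabeled as Objective-C.
let cppRecall = 73.0 / (73.0 + 2.0) * 100     // ≈ 97.33%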
#Writing the Model to Disk
So, we have our classifier, and we’ve evaluated it and found it satisfactory. The only thing left to do is write it to disk:
let modelPath = URL(fileURLWithPath: destinationPath)
let metadata = MLModelMetadata(
    author: "Mattt",
    shortDescription: "A model trained to classify programming languages",
    version: "1.0"
)

try classifier.write(to: modelPath, metadata: metadata)
All told, training, evaluating, and writing the model took less than 5 minutes:
$ time swift ./Trainer.swift
281.84 real 275.51 user 5.60 sys
#Testing Out the Model in a Playground
In order to use our model from a Playground, we need to compile it first. For an iOS or Mac app, Xcode would automatically generate a programmatic interface for us. However, in a Playground, we need to do this ourselves.
Call the coremlc tool using the xcrun command, specifying the compile action on the .mlmodel file and targeting the current directory for the output:
$ xcrun coremlc compile ProgrammingLanguageClassifier.mlmodel .
Take the resulting .mlmodelc bundle (it’ll look like a normal directory in Finder) and move it into the Resources folder of your playground.
You can use this to initialize a Natural Language framework NLModel to classify text with the predictedLabel(for:) method:
let url = Bundle.main.url(
    forResource: "ProgrammingLanguageClassifier",
    withExtension: "mlmodelc"
)!

let model = try! NLModel(contentsOf: url)
Now you can call the predictedLabel(for:) method to predict the programming language of a string containing code:
let code = """
struct Plane: Codable {
    var manufacturer: String
    var model: String
    var seats: Int
}
"""
model.predictedLabel(for: code) // Swift
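Earlier we claimed a classification time of roughly 20ms per file. If you want to check that on your own machine, a quick (and unscientific) timing sketch in the playground might look like this:

// Time a single prediction; Date-based timing is crude but good enough
// for a ballpark figure in a playground.
let start = Date()
let label = model.predictedLabel(for: code)
let milliseconds = Date().timeIntervalSince(start) * 1000

print("Predicted \(label ?? "unknown") in \(String(format: "%.1f", milliseconds))ms")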
The sample code project for this post wraps this up with a fun drag-and-drop UI, so you can easily test out your model with whatever source files you have littering your Desktop.
#Conclusion
There’s no way this actually works… right? Source code isn’t like other kinds of text, and weighing keywords and punctuation equally with comments and variable names is obviously a flawed approach.
The way classifiers work, our model may well be fixating on irrelevant details like license comments in the file header. Heck, that 99% accuracy we saw could be more a reflection of file similarity within the same project than of the model’s actual predictive ability.
All of that said, it might just be good enough.
Consider this: in under an hour, we went from nothing to a working solution without any significant programming. That’s pretty incredible.
Create ML is a powerful way to prototype new features quickly. If a minimum-viable product is good enough, then your job is done. Or if you need to go even further, there are all kinds of optimizations to be had in terms of model size, accuracy, and precision by using something like TensorFlow or Turi Create.