Flight School

Training a Text Classifier
with Create ML and the
Natural Language Framework

From Zero to Core ML Model in Minutes

Machine Learning can be difficult to get your head around as a programmer. But aside from all the advanced mathematics and tooling, perhaps most difficult of all is learning how to let go.

Many of us have spent years honing our craft of writing code: expressive, type-safe, unit-tested, refactored, clean, DRY code. So it’s tough to hear that pretty much anyone with enough data (and patience) can use machine learning to solve hard problems without understanding why or how it works.

But this isn’t news to you. In fact, ML is at the top of your list of things to learn next! You have Andrew Ng’s course bookmarked in Safari and a load of unread PDFs littering your Downloads folder. All you need is a free weekend and…

If that hits a bit close to home, then you’ll love Create ML.

Create ML is a new framework that makes it easy to train machine learning models. How easy? Drag pictures of dogs into a “Dogs” folder and pictures of cats into a “Cats” folder, write a few lines of Swift code, wait a couple of minutes for training, and boom: you have an image classifier that can tell you whether a picture contains a cat or a dog. You can do the same for text classification, or for regression on tabular data.
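
Those “few lines of Swift” might look something like the following. This is a minimal sketch rather than the session’s exact code, and the Pets folder and file paths are hypothetical:

import CreateML
import Foundation

// Hypothetical folder containing "Cats" and "Dogs" subfolders full of images.
let trainingDirectory = URL(fileURLWithPath: "/Users/you/Desktop/Pets")

// Create ML treats each subfolder name as the label for the images inside it.
let classifier = try MLImageClassifier(
    trainingData: .labeledDirectories(at: trainingDirectory)
)

// Save the trained model so it can be added to an Xcode project.
try classifier.write(to: URL(fileURLWithPath: "/Users/you/Desktop/PetClassifier.mlmodel"))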

If you haven’t already, go ahead and watch the Introducing Create ML session from this year’s WWDC. It’s like magic.

Pulling files from labeled directories makes for a nice demo, but what if your data set isn’t so nicely organized? This article shows how to build a corpus by hand from a heterogeneous data set and use Create ML to train a text classifier that predicts the programming language of unknown source code.

You can find the complete training script and a demo playground here.


Acquiring a Corpus of Data

First things first: we need some examples of source code.

If you’re anything like the author, you might have thought to use GitHub search to find code by language. And if so, you’d eventually realize that the GitHub API (in both its REST and GraphQL versions) requires code search results to be scoped to a user or project. At that point, you’d probably wonder why you decided to build this yourself before looking for something off-the-shelf, like this project by Source Foundry.

Our corpus includes labeled code samples in C, C++, Go, Java, JavaScript, Objective-C, PHP, Python, Ruby, Rust, and Swift. Each directory in the project root corresponds to a language and contains flattened checkouts of a handful of popular open source repositories for that language:

$ tree -L 2 code-corpora/swift
code-corpora/swift
├── alamofire
│   ├── Alamofire.h
│   ├── Alamofire.swift
│   ├── AppDelegate.swift
│   ├── AuthenticationTests.swift
# ...

Notice that Objective-C header in the Swift project, though. We can’t rely entirely on the top-level directories as labels for their contents, because most projects include other auxiliary scripts and source files (as well as README, LICENSE, and other repository miscellany).

Our training script uses the containing directory and file extension to determine the correct label for each file in our corpus. For example, .h files in the c directory are labeled as C, .h files in the cc directory are labeled as C++, and any file with the .go extension is labeled as Go:

switch (directory, fileExtension) {
case ("c", "h"), (_, "c"): label = "C"
case ("cc", "h"), (_, "cc"), (_, "cpp"): label = "C++"
case (_, "go"): label = "Go"
// ...
default:
    break // unknown file type; skip it
}
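
In the full training script, this switch lives inside a failable initializer on a ProgrammingLanguage enum, which the enumeration loop in the next section uses to label each file. Here’s a sketch of what that type could look like; the case list is abbreviated, and the path handling is an approximation rather than a verbatim excerpt:

import Foundation

// An approximation of the ProgrammingLanguage type used by the training loop.
enum ProgrammingLanguage: String {
    case c = "C"
    case cPlusPlus = "C++"
    case go = "Go"
    // ...
    case swift = "Swift"

    init?(for fileURL: URL, at level: Int) {
        let components = fileURL.pathComponents
        // `level` is the file's depth below the corpus root, so the top-level
        // language directory sits `level` components back from the end of the path.
        guard components.count > level else { return nil }
        let directory = components[components.count - level].lowercased()
        let fileExtension = fileURL.pathExtension.lowercased()

        switch (directory, fileExtension) {
        case ("c", "h"), (_, "c"): self = .c
        case ("cc", "h"), (_, "cc"), (_, "cpp"): self = .cPlusPlus
        case (_, "go"): self = .go
        // ...
        case (_, "swift"): self = .swift
        default: return nil
        }
    }
}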

Training the Model

To build our data table, we recursively enumerate the contents of the corpus directory and append the contents of each source file that we can identify:

var corpus: [(text: String, label: String)] = []

let enumerator = FileManager.default.enumerator(
                    at: corpusURL,
                    includingPropertiesForKeys: [.isDirectoryKey]
                 )!

for case let resource as URL in enumerator {
    guard !resource.hasDirectoryPath,
        let language = ProgrammingLanguage(for: resource,
                                           at: enumerator.level),
        let text = try? String(contentsOf: resource)
    else {
        continue
    }
    corpus.append((text: text, label: language.rawValue))
}

let (texts, labels): ([String], [String]) =
    corpus.reduce(into: ([], [])) { (columns, row) in
        columns.0.append(row.text)
        columns.1.append(row.label)
    }

let dataTable =
    try MLDataTable(dictionary: ["text": texts, "label": labels])

Our original implementation appended MLDataTable objects one at a time instead of initializing a single data table from an accumulated array. We found this to have nonlinear performance characteristics, which caused training to take closer to an hour than a few minutes.

With our data table in hand, we use the randomSplit(by:seed:) method to segment our training and testing data. The former is used immediately, passed into the MLTextClassifier initializer; the latter will be used next to evaluate the model.

let (trainingData, testingData) =
    dataTable.randomSplit(by: 0.8, seed: 0)

let classifier = try MLTextClassifier(trainingData: trainingData,
                                      textColumn: "text",
                                      labelColumn: "label")

Creating an MLTextClassifier object takes a while, but you can track the progress by tailing STDOUT:

Automatically generating validation set from 5% of the data.
Tokenizing data and extracting features
10% complete
20% complete
30% complete
40% complete
50% complete
60% complete
70% complete
80% complete
90% complete
100% complete
Starting MaxEnt training with 8584 samples
Iteration 1 training accuracy 0.285182
Iteration 2 training accuracy 0.946295
Iteration 3 training accuracy 0.988001
Iteration 4 training accuracy 0.997554
Iteration 5 training accuracy 0.998602
Iteration 6 training accuracy 0.999185
Iteration 7 training accuracy 0.999651
Iteration 8 training accuracy 0.999767
Finished MaxEnt training in 7.12 seconds

The resulting model seems large for what it can do, weighing in at 3MB. However, it’s able to classify a file in ~20ms, which should be fast enough for most use cases.

Evaluating the Model

Let’s see how our classifier performs by calling the evaluation(on:) method and passing the testingData that we segmented before.

let evaluation = classifier.evaluation(on: testingData)
print(evaluation)

Accuracy

At the top of our evaluation, we get a summary with the number of examples, the number of classes, and the accuracy:

Number of Examples 1138
Number of Classes 10
Accuracy 99.56%

99.56% accuracy. That’s good, right? Let’s dig into the numbers to get a better understanding of how this behaves.


When you print(_:) an MLClassifierMetrics object, it shows a summary of the overall accuracy as well as a confusion matrix and a precision / recall table.
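
If you’d rather work with those numbers programmatically than read the printed summary, the same information is exposed as properties on MLClassifierMetrics: classificationError for the overall error rate, and MLDataTable values for the confusion matrix and precision/recall figures.

let accuracy = (1.0 - evaluation.classificationError) * 100
print("Accuracy: \(accuracy)%")

// The confusion matrix and per-class precision / recall, as MLDataTable values.
print(evaluation.confusion)
print(evaluation.precisionRecall)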

Confusion Matrix

A confusion matrix is a tool for visualizing the accuracy of predictions. Each column shows a predicted class, and each row shows an actual class:

             C    C++     Go   Java     JS  Obj-C    PHP   Ruby   Rust  Swift
C          122      0      0      0      0      0      0      0      0      0
C++          0     73      0      0      0      2      0      0      0      0
Go           1      0    333      0      0      0      0      0      0      0
Java         0      0      0    137      0      0      0      0      0      0
JS           0      0      0      0     55      0      0      0      0      0
Obj-C        0      0      0      0      0     97      0      0      0      0
PHP          0      0      0      0      0      0     95      0      0      0
Ruby         0      0      0      0      0      0      0    136      0      0
Rust         0      0      0      0      0      0      0      0     73      0
Swift        0      0      0      0      0      0      0      0      0     12

A model with 100% accuracy would have values only along the diagonal, where the predicted and actual classes match, and zeroes everywhere else. Our accuracy isn’t perfect, though, so we have a few stray figures. From the table, we can see that Go was mistaken for C once and that C++ was incorrectly labeled as Objective-C twice.

Precision and Recall

Another way of analyzing our results is in terms of precision and recall.

Class        Precision(%)   Recall(%)
C                   99.19      100.00
C++                100.00       97.33
Go                 100.00       98.94
Java               100.00      100.00
JavaScript         100.00      100.00
Objective-C         98.91      100.00
PHP                100.00      100.00
Ruby               100.00      100.00
Rust               100.00      100.00
Swift              100.00      100.00

Precision measures how often the model’s predictions for a given class are correct: of all the files it labeled as C++, how many actually are C++? Our model has perfect precision for C++ because it never misidentified another language’s source files as C++, but it has imperfect precision for C because it incorrectly identified a Go file as C.

Recall measures how much of a given class the model actually finds: of all the C++ files in the testing data, how many did the model label as C++? Our model has perfect recall for C because it correctly identified every C source file in the testing data, and imperfect recall for C++ because it mislabeled two C++ files (as Objective-C).
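
To make those definitions concrete, here’s the arithmetic behind the C row of the table, using the counts from the confusion matrix above:

// For C: 122 files correctly labeled C (true positives),
// one Go file incorrectly labeled C (false positive),
// and no C files labeled as anything else (false negatives).
let truePositives = 122.0
let falsePositives = 1.0
let falseNegatives = 0.0

let precision = truePositives / (truePositives + falsePositives) // 122 / 123 ≈ 0.9919
let recall = truePositives / (truePositives + falseNegatives)    // 122 / 122 = 1.0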

Writing the Model to Disk

So, we have our classifier, and we’ve evaluated it and found it satisfactory. The only thing left to do is write it to disk:

let modelPath = URL(fileURLWithPath: destinationPath)
let metadata = MLModelMetadata(
    author: "Mattt",
    shortDescription: "A model trained to classify programming languages",
    version: "1.0"
)
try classifier.write(to: modelPath, metadata: metadata)

All told, training, evaluating, and writing the model took less than 5 minutes:

$ time swift ./Trainer.swift
281.84 real       275.51 user         5.60 sys

Testing Out the Model in a Playground

To use our model from a playground, we need to compile it first. When you add a .mlmodel file to an iOS or Mac app target, Xcode compiles it automatically and generates a programmatic interface for you. In a playground, we have to do the compilation step ourselves.

Call the coremlc tool using the xcrun command, specifying the compile action, the .mlmodel file, and the current directory as the output destination:

$ xcrun coremlc compile ProgrammingLanguageClassifier.mlmodel .

Take the resulting .mlmodelc bundle (it’ll look like a normal directory in Finder) and move it into the Resources folder of your playground. From there, you can use it to initialize an NLModel from the Natural Language framework:

import Foundation
import NaturalLanguage

let url = Bundle.main.url(
    forResource: "ProgrammingLanguageClassifier",
    withExtension: "mlmodelc"
)!
let model = try! NLModel(contentsOf: url)

Now you can call the predictedLabel(for:) method to predict the programming language of a string containing code:

let code = """
struct Plane: Codable {
    var manufacturer: String
    var model: String
    var seats: Int
}
"""

model.predictedLabel(for: code) // Swift
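
If you want to sanity-check the ~20ms classification time mentioned earlier, you can wrap the call in a quick (and admittedly unscientific) timing measurement. This continues from the snippet above, reusing the model and code values:

let start = Date()
let predicted = model.predictedLabel(for: code)
let milliseconds = Date().timeIntervalSince(start) * 1000
print("Predicted \(predicted ?? "unknown") in \(milliseconds) ms")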

The sample code project for this post wraps this up with a fun drag-and-drop UI, so you can easily test out your model with whatever source files you have littering your Desktop.

Screenshot of Classifier Example

Conclusion

There’s no way this actually works… right? Source code isn’t like other kinds of text, and weighing keywords and punctuation equally with comments and variable names is obviously a flawed approach.

The way classifiers work, our model may well be fixating on irrelevant details like license comments in the file header. Heck, that 99% accuracy we saw could be more a reflection of file similarity within the same project than of the model’s actual predictive ability.

All of that said, it might just be good enough.

Consider this: in under an hour, we went from nothing to a working solution without any significant programming. That’s pretty incredible.

Create ML is a powerful way to prototype new features quickly. If a minimum viable product is good enough, then your job is done. And if you need to go further, there are all kinds of optimizations to be had in terms of model size, accuracy, and precision by using something like TensorFlow or Turi Create.