KotlinDL 0.2: Functional API, Model Zoo With ResNet and MobileNet, Idiomatic Kotlin DSL for Image Preprocessing, and Many New Layers

Introducing version 0.2 of our deep learning library, KotlinDL.

KotlinDL 0.2 is available now on Maven Central with a variety of new features. This release brings new layers, a Kotlin-idiomatic DSL for image preprocessing, several types of Datasets, a Model Zoo with support for the ResNet and MobileNet model families, and many more changes.

KotlinDL on GitHub

In this post, we’ll walk you through the changes to the Kotlin Deep Learning library in the 0.2 release.

Functional API

With the previous version of the library, you could only use the Sequential API to describe your model. Using the Sequential.of(..) method call, you could build a sequence of layers to describe models in a style similar to VGG.

Since 2014, many new architectures have addressed the disadvantages inherent in simple layer sequences, such as vanishing gradients or the degradation (accuracy saturation) problem. The famous residual neural networks (ResNet) use skip connections or shortcuts to jump over layers. In version 0.2, we’ve added a new Functional API that makes it possible for you to build models such as ResNet or MobileNet.

The Functional API provides a way to create models that are more flexible than the Sequential API. The Functional API can handle models with non-linear topology, shared layers, and even multiple inputs or outputs.

The main idea behind it is that a deep learning model is usually a directed acyclic graph (DAG) of layers. So the Functional API is a way to build graphs of layers.
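To make the graph idea concrete, here is a minimal sketch in plain Kotlin. The LayerNode class and the executionOrder function are hypothetical names invented for this illustration and are not part of KotlinDL. The sketch only shows how a graph of layers with a skip connection can be ordered so that every layer runs after its inputs.

```kotlin
// Hypothetical illustration only: these names are NOT part of KotlinDL.
// A node is a named layer plus the nodes whose outputs it consumes.
class LayerNode(val name: String, val inputs: List<LayerNode> = emptyList())

// Walks back from the output and returns the layer names in an order where
// every layer appears after all of its inputs (a topological order of the DAG).
fun executionOrder(output: LayerNode): List<String> {
    val visited = LinkedHashSet<LayerNode>()
    fun visit(node: LayerNode) {
        if (node in visited) return
        node.inputs.forEach { visit(it) }
        visited += node
    }
    visit(output)
    return visited.map { it.name }
}

fun main() {
    val input = LayerNode("input")
    val conv1 = LayerNode("conv1", listOf(input))
    val conv2 = LayerNode("conv2", listOf(conv1))
    // The skip connection: `add` consumes both conv2 and the earlier conv1 output.
    val add = LayerNode("add", listOf(conv2, conv1))
    val dense = LayerNode("dense", listOf(add))
    println(executionOrder(dense)) // [input, conv1, conv2, add, dense]
}
```

A purely sequential model is the special case in which every node has exactly one input.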

Let’s build a ToyResNet model for the FashionMnist dataset to demonstrate this:

val (train, test) = fashionMnist()

val inputs = Input(28, 28, 1)
val conv1 = Conv2D(32)(inputs)
val conv2 = Conv2D(64)(conv1)
val maxPool = MaxPool2D(poolSize = intArrayOf(1, 3, 3, 1), 
                        strides = intArrayOf(1, 3, 3, 1))(conv2)

val conv3 = Conv2D(64)(maxPool)
val conv4 = Conv2D(64)(conv3)
val add1 = Add()(conv4, maxPool)

val conv5 = Conv2D(64)(add1)
val conv6 = Conv2D(64)(conv5)
val add2 = Add()(conv6, add1)

val conv7 = Conv2D(64)(add2)
val globalAvgPool2D = GlobalAvgPool2D()(conv7)
val dense1 = Dense(256)(globalAvgPool2D)
val outputs = Dense(10, activation = Activations.Linear)(dense1)

val model = Functional.fromOutput(outputs)

model.use {
   it.compile(
       optimizer = Adam(),
       loss = Losses.SOFT_MAX_CROSS_ENTROPY_WITH_LOGITS,
       metric = Metrics.ACCURACY
   )

   it.summary()

   it.fit(dataset = train, epochs = 3, batchSize = 1000)

   val accuracy = it.evaluate(dataset = test, batchSize = 1000)
                    .metrics[Metrics.ACCURACY]

   println("Accuracy after: $accuracy")
}
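Note that the last Dense layer uses a linear activation, so the model outputs raw logits, and the loss name suggests that the softmax is applied inside the loss. The following plain-Kotlin sketch shows the usual definition of such a loss for a single example. The function name is an assumption made for this illustration, and the code is not KotlinDL's implementation.

```kotlin
import kotlin.math.exp
import kotlin.math.ln

// Illustrative sketch, not KotlinDL code: softmax cross-entropy computed
// directly from raw logits for one example with a known true class.
fun softmaxCrossEntropyWithLogits(logits: FloatArray, trueClass: Int): Double {
    require(trueClass in logits.indices) { "trueClass is out of range" }
    // Subtracting the maximum logit first keeps exp() from overflowing.
    val max = logits.fold(Float.NEGATIVE_INFINITY) { acc, v -> maxOf(acc, v) }.toDouble()
    val logSumExp = max + ln(logits.map { exp(it - max) }.sum())
    // -log(softmax(logits)[trueClass]) == logSumExp - logits[trueClass]
    return logSumExp - logits[trueClass]
}

fun main() {
    // Four equal logits give a uniform distribution, so the loss is ln(4).
    println(softmaxCrossEntropyWithLogits(floatArrayOf(0f, 0f, 0f, 0f), trueClass = 0))
}
```

Working from logits in this way is generally more numerically stable than applying a softmax in the last layer and taking the logarithm afterwards.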

Here’s a summary of the model:

And here is a representation of the model architecture typical of the whole ResNet model family:

The main design of this API is borrowed from the Keras library, but it is not a complete copy. If you find any discrepancies between our API and the Keras library API, please refer to our documentation.

Model Zoo: support for the ResNet and MobileNet model families

Starting with the 0.2 release, KotlinDL includes a Model Zoo, a collection of deep convolutional networks pre-trained on a large image dataset known as ImageNet.

The Model Zoo is important because modern architectures of convolutional neural networks can have hundreds of layers and tens of millions of parameters. Training models to an acceptable accuracy level (~70-80%) on ImageNet may require hundreds or thousands of hours of computation on a cluster of GPUs. With the Model Zoo, there is no need to train a model from scratch every time you need one. You can get a ready, pre-trained model from our repository of models and immediately use it for image recognition or transfer learning.

The following models are currently supported:

  • VGG’16
  • VGG’19
  • ResNet50
  • ResNet101
  • ResNet152
  • ResNet50v2
  • ResNet101v2
  • ResNet152v2
  • MobileNet
  • MobileNetv2

All the models in the Model Zoo include a special loader of model configs and model weights, as well as the special data preprocessing function that was applied when the models were trained on the ImageNet dataset.

Here’s an example of how to use one of these models, ResNet50, for prediction:

// specify the model type to be loaded, ResNet50, for example
val loader =
   ModelZoo(commonModelDirectory = File("cache/pretrainedModels"), modelType = ModelType.ResNet_50)

// obtain the model configuration
val model = loader.loadModel() as Functional

// load class labels (from ImageNet dataset in ResNet50 case)
val imageNetClassLabels = loader.loadClassLabels()

// load weights if required (for Transfer Learning purposes)
val hdfFile = loader.loadWeights()

Now you’ve got a model and its weights, and you can use them in KotlinDL.

NOTE: Don’t forget to apply model-specific preprocessing for the new data. All the preprocessing functions are included in the Model Zoo and can be called via the preprocessInput function.
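To give a feel for what model-specific preprocessing can involve, here is a plain-Kotlin sketch of one common ImageNet scheme: BGR channel order with per-channel mean subtraction, as used by the "caffe" mode in Keras for VGG and the original ResNet models. Whether preprocessInput does exactly this for a given model is an assumption, and the function and constant names below are invented for this illustration.

```kotlin
// Illustrative sketch, not the KotlinDL implementation.
// Per-channel ImageNet means in B, G, R order (the values used by Keras "caffe" mode).
val IMAGENET_BGR_MEANS = floatArrayOf(103.939f, 116.779f, 123.68f)

// Assumption: pixels are stored as interleaved B, G, R floats in the 0..255 range.
fun subtractChannelMeans(bgrPixels: FloatArray): FloatArray {
    require(bgrPixels.size % 3 == 0) { "Expected interleaved 3-channel data" }
    return FloatArray(bgrPixels.size) { i -> bgrPixels[i] - IMAGENET_BGR_MEANS[i % 3] }
}

fun main() {
    // A single pixel equal to the channel means is mapped to zeros.
    val centered = subtractChannelMeans(floatArrayOf(103.939f, 116.779f, 123.68f))
    println(centered.joinToString()) // 0.0, 0.0, 0.0
}
```

Feeding a pre-trained model data that was preprocessed differently from its training data typically degrades predictions, which is why the NOTE above matters.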

If you want to train VGG or ResNet models from scratch, you can simply load the model configuration or start from the full model code written in Kotlin. All Model Zoo models are available via top-level functions located in the org.jetbrains.kotlinx.dl.api.core.model package.

val model = resnet50Light(imageSize = 28, 
                          numberOfClasses = 10, 
                          numberOfChannels = 1, 
                          lastLayerActivation = Activations.Linear)

A full example of how to use VGG’19 for prediction and transfer learning with additional training on a custom dataset can be found in this tutorial.

DSL for image preprocessing

Python developers have access to a huge number of utilities and libraries for data preprocessing. However, in JVM languages there are specific difficulties with preprocessing images, videos, and music. Most libraries for image preprocessing in Java and Kotlin use the BufferedImage class, whose methods are sometimes inconsistent and at a very low level of abstraction. We decided to simplify the lives of Kotlin developers by making an easy and straightforward DSL using lambdas with receivers for setting the image preprocessing pipeline.
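For readers new to lambdas with receivers, the following minimal sketch shows the language mechanism behind DSLs of this kind. The ResizeConfig and PipelineBuilder classes and the pipeline function are hypothetical and much simpler than the real KotlinDL classes.

```kotlin
// Hypothetical classes for illustration; they are NOT KotlinDL's implementation.
class ResizeConfig {
    var outputWidth: Int = 0
    var outputHeight: Int = 0
}

class PipelineBuilder {
    val steps = mutableListOf<String>()

    // `ResizeConfig.() -> Unit` is a lambda with receiver: inside the block,
    // `this` is a ResizeConfig, so its properties can be assigned directly.
    fun resize(block: ResizeConfig.() -> Unit) {
        val config = ResizeConfig().apply(block)
        steps += "resize to ${config.outputWidth}x${config.outputHeight}"
    }
}

// Entry point of the mini-DSL: runs the block against a fresh builder.
fun pipeline(block: PipelineBuilder.() -> Unit): List<String> =
    PipelineBuilder().apply(block).steps

fun main() {
    val steps = pipeline {
        resize {
            outputWidth = 400
            outputHeight = 400
        }
    }
    println(steps) // [resize to 400x400]
}
```

Each nested block configures one object, which is what lets a pipeline read like a declarative description rather than a chain of method calls.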

The DSL for image preprocessing can use the following operations:

  • Load
  • Crop
  • Resize
  • Rotate
  • Rescale
  • Sharpen
  • Save

val preprocessing: Preprocessing = preprocess {
   transformImage {
       load {
           pathToData = imageDirectory
           imageShape = ImageShape(224, 224, 3)
           colorMode = ColorOrder.BGR
       }
       rotate {
           degrees = 30f
       }
       crop {
           left = 12
           right = 12
           top = 12
           bottom = 12
       }
       resize {
           outputWidth = 400
           outputHeight = 400
           interpolation = InterpolationType.NEAREST
       }
   }
   transformTensor {
       rescale {
           scalingCoefficient = 255f
       }
   }
}

As a result, basic augmentation can be implemented manually. Below are several images produced by the preprocessing pipeline above with different rotation angles and image sizes:

If you need additional image preprocessing steps, please feel free to make a feature request in our issue tracker.

New layers

We’ve implemented a variety of new layers that are required for ResNet and MobileNet models: 

  • BatchNorm
  • ActivationLayer
  • DepthwiseConv2D
  • SeparableConv2D
  • Merge (Add, Subtract, Multiply, Average, Concatenate, Maximum, Minimum)
  • GlobalAvgPool2D
  • Cropping2D
  • Reshape
  • ZeroPadding2D*

* Kudos to Anton Kosyakov for implementing ZeroPadding2D! 

If you would like to contribute a layer, we would be delighted to take a look at your pull requests.

Dataset API and its implementations: OnHeapDataset & OnFlyDataset

The standard way to run data through a neural network in forward mode is to load batches one by one into RAM and then into the memory area controlled by the TensorFlow computational graph.

We support this approach with an on-the-fly dataset (OnFlyDataset). It sequentially loads batch after batch into RAM during one training epoch, applying the preprocessing described in advance (if it was defined). 

But what if our data fits into RAM? You can use an OnHeapDataset to load and keep all this data in RAM without reading it repeatedly from the disk at each epoch.
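The trade-off between the two approaches can be sketched in plain Kotlin. The functions below are hypothetical helpers, not the OnFlyDataset or OnHeapDataset classes, and loadSample stands in for reading and preprocessing one sample from disk.

```kotlin
// Hypothetical helpers for illustration; these are NOT KotlinDL classes.

// On-the-fly: samples are loaded lazily, batch by batch, on every pass.
fun onTheFlyBatches(
    sampleCount: Int,
    batchSize: Int,
    loadSample: (Int) -> FloatArray
): Sequence<List<FloatArray>> =
    (0 until sampleCount).asSequence().map(loadSample).chunked(batchSize)

// On-heap: every sample is loaded once up front and kept in memory.
fun onHeapBatches(
    sampleCount: Int,
    batchSize: Int,
    loadSample: (Int) -> FloatArray
): List<List<FloatArray>> =
    (0 until sampleCount).map(loadSample).chunked(batchSize)

fun main() {
    var reads = 0
    val load = { i: Int -> reads++; floatArrayOf(i.toFloat()) }

    val streamed = onTheFlyBatches(sampleCount = 10, batchSize = 4, loadSample = load)
    repeat(2) { streamed.forEach { } } // two epochs re-read every sample
    println(reads)                     // 20

    reads = 0
    val inMemory = onHeapBatches(sampleCount = 10, batchSize = 4, loadSample = load)
    repeat(2) { inMemory.forEach { } } // two epochs reuse the in-memory data
    println(reads)                     // 10
}
```

The lazy variant keeps memory usage bounded by the batch size, while the eager variant pays the loading cost only once.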

Embedded datasets

For those of you who are just starting your journey in deep learning, we recommend practicing building your first neural networks on well-known datasets, such as a set of handwritten digits (MNIST dataset), a similar set of images of fashion items from Zalando (FashionMNIST), the famous Cifar’10 dataset (50,000 images), or a collection of photos of cats and dogs from one of the most popular Kaggle competitions (25,000 images of various sizes).

All these datasets are stored remotely and, if necessary, can be downloaded to a folder on your disk. If the dataset has already been downloaded, it will not be downloaded again and will be loaded immediately from the disk.
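A download-once cache of this kind can be sketched in a few lines of plain Kotlin. The fetchIfAbsent helper is a hypothetical name, not a KotlinDL function, and the download parameter stands in for the actual remote fetch.

```kotlin
import java.io.File

// Hypothetical helper for illustration; NOT a KotlinDL function.
// `download` is only invoked when the target file is not on disk yet.
fun fetchIfAbsent(target: File, download: () -> ByteArray): File {
    if (!target.exists()) {
        target.parentFile?.mkdirs()
        target.writeBytes(download())
    }
    return target
}

fun main() {
    val cacheDir = java.nio.file.Files.createTempDirectory("datasets").toFile()
    val target = File(cacheDir, "mnist/train-images.bin") // made-up file name
    var downloads = 0
    // The second call finds the file on disk and skips the download.
    repeat(2) { fetchIfAbsent(target) { downloads++; byteArrayOf(1, 2, 3) } }
    println(downloads) // 1
}
```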

Adding KotlinDL to your project

To use KotlinDL in your project, you need to add the following dependency to your build.gradle file:

repositories {
    mavenCentral()
}

dependencies {
    implementation 'org.jetbrains.kotlinx:kotlin-deeplearning-api:0.2.0'
}
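If your build uses the Gradle Kotlin DSL instead, the equivalent build.gradle.kts fragment would look like this, with the same coordinates and version as above:

```kotlin
// build.gradle.kts
repositories {
    mavenCentral()
}

dependencies {
    implementation("org.jetbrains.kotlinx:kotlin-deeplearning-api:0.2.0")
}
```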

You can also take advantage of KotlinDL functionality in any existing Java project, even if you don’t have any other Kotlin code in it yet. Here is an example of the LeNet-5 model written completely in Java.

Learn more and give feedback

We hope you enjoyed this brief overview of the new features in KotlinDL version 0.2!

  • For more information, see GitHub.
  • Check out the KotlinDL guide, which covers the library’s basic and advanced features.
  • Join the #kotlindl channel in Kotlin Slack (get an invite here).
  • If you have previously used KotlinDL, use the changelog for migration.
  • Check out this talk from Alexey Zinoviev, which offers a closer look at the library’s design, ideology, etc.
  • The issue tracker is here.

Let’s Kotlin!