Background Removal in Real-Time Video Chats using TensorflowJS, Part 1
An app that removes and replaces the background of webcam video streams in real time, all from within the browser! No need for a green screen or a uniform background. I built this project during my four weeks in the AI Program of Insight Data Science (Palo Alto).
Try it here!
There is a trend in AI to move from centralized cloud computing to edge computing [1], particularly for real-time services, where centralized cloud computing suffers from higher latency. Edge AI may also offer solutions for privacy-conscious consumers [2]. One tool likely to support this trend is TensorflowJS (TFJS), in short, TensorFlow for JavaScript. TFJS makes it possible to build AI apps whose training and prediction both run on the client side [3], [4].
In this demo, I limit the use case to a web-conference-style video stream with a single person in front of the webcam. At the frame level, each image is essentially a portrait. The goal is to perform semantic segmentation on each frame, labeling the person as class 1 and the background as class 0. The model is trained on a computer with an NVIDIA GPU, and inference runs in the browser on various devices.
Dataset
I used the Flickr portrait mask dataset (Shen et al. 2016): 1,900 portrait images and their corresponding masks. The dataset covers a variety of backgrounds (indoor, outdoor, uniform), subjects, and facial expressions. The average of all masks shows that in most portraits the person stands in the center of the frame (see below).
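Once the masks share a common size, the average mask is a one-line reduction. A minimal NumPy sketch, assuming the masks are stacked into an array of shape (N, H, W) with values 0/1; the helper names `mean_mask` and `center_fraction`, and the 25% margin that defines the "center", are hypothetical choices for illustration:

```python
import numpy as np

def mean_mask(masks: np.ndarray) -> np.ndarray:
    """Pixel-wise average of binary masks with shape (N, H, W).

    Each output pixel is the fraction of portraits in which that pixel
    is labeled as person (class 1)."""
    return masks.astype(np.float64).mean(axis=0)

def center_fraction(avg: np.ndarray, margin: float = 0.25) -> float:
    """Share of the average foreground mass inside the central region.

    The central region excludes `margin` of the height and width on each
    side (hypothetical definition, for illustration only)."""
    h, w = avg.shape
    top, left = int(h * margin), int(w * margin)
    center = avg[top:h - top, left:w - left]
    return float(center.sum() / avg.sum())

# Toy example: three 4x4 masks with the "person" in the middle.
masks = np.zeros((3, 4, 4), dtype=np.uint8)
masks[:, 1:3, 1:3] = 1
masks[0, 0, 0] = 1  # one stray foreground pixel in a corner
avg = mean_mask(masks)
print(avg[1, 1], avg[0, 0])
print(center_fraction(avg))
```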
The data preparation is relatively basic: cropping and resizing the original images and their masks.
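The key point when cropping and resizing is that each image and its mask receive exactly the same geometric transform, so the labels stay aligned with the pixels. A minimal NumPy sketch, assuming a centered square crop followed by a resize to 224 x 224 (the crop strategy and the function names are assumptions, not the author's code). Nearest-neighbor resizing is used here because it keeps a binary mask binary; for the RGB image, a smoother interpolation from an image library would normally be preferred.

```python
import numpy as np

def center_crop_square(img: np.ndarray) -> np.ndarray:
    """Crop the largest centered square from an (H, W, ...) array."""
    h, w = img.shape[:2]
    side = min(h, w)
    top = (h - side) // 2
    left = (w - side) // 2
    return img[top:top + side, left:left + side]

def resize_nearest(img: np.ndarray, size: int = 224) -> np.ndarray:
    """Nearest-neighbor resize of an (H, W, ...) array to (size, size, ...).

    Every output pixel copies an existing input pixel, so a 0/1 mask
    never acquires in-between values."""
    h, w = img.shape[:2]
    rows = np.arange(size) * h // size
    cols = np.arange(size) * w // size
    return img[rows[:, None], cols[None, :]]

def prepare_pair(image: np.ndarray, mask: np.ndarray, size: int = 224):
    """Apply the same crop and resize to an image and its mask."""
    return (resize_nearest(center_crop_square(image), size),
            resize_nearest(center_crop_square(mask), size))

# Usage with random stand-in data (a 600 x 800 RGB image and its mask).
rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(600, 800, 3), dtype=np.uint8)
mask = (rng.random((600, 800)) > 0.5).astype(np.uint8)
img224, mask224 = prepare_pair(image, mask)
print(img224.shape, mask224.shape)
```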
The model
The model uses a U-Net-style architecture built from MobileNet's depthwise separable convolution blocks, inspired by the MobileUnet implementation. The input is an RGB image (224 x 224) and the output is a binary mask of the same size. For this first implementation, I replaced BiLinearUpsampling2D, a custom layer in MobileUnet, with UpSampling2D, which TFJS supports. The model has a small footprint: 6 million parameters and ~28 MB.
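Why do separable convolution blocks keep the footprint small? A depthwise separable convolution splits a k x k convolution into a per-channel k x k depthwise filter followed by a 1 x 1 pointwise convolution. The sketch below simply counts weights for both layer types; the 3 x 3 kernel and 256 channels are arbitrary illustrative values, and the counts assume biases are included and a depth multiplier of 1 (the Keras `SeparableConv2D` default).

```python
def conv2d_params(k: int, c_in: int, c_out: int, bias: bool = True) -> int:
    """Weights of a standard k x k convolution from c_in to c_out channels."""
    return k * k * c_in * c_out + (c_out if bias else 0)

def separable_conv2d_params(k: int, c_in: int, c_out: int, bias: bool = True) -> int:
    """Depthwise k x k filter per input channel, then a 1 x 1 pointwise conv."""
    depthwise = k * k * c_in
    pointwise = c_in * c_out
    return depthwise + pointwise + (c_out if bias else 0)

std = conv2d_params(3, 256, 256)
sep = separable_conv2d_params(3, 256, 256)
print(std, sep, round(std / sep, 1))  # 590080 68096 8.7

# Raw float32 storage for 6 million weights (4 bytes each), in MB.
print(6_000_000 * 4 / 1e6)  # 24.0
```

At this width the separable layer needs roughly 8.7x fewer weights. The last line is a back-of-the-envelope check: 6 million float32 weights alone take about 24 MB, the same ballpark as the model size quoted above.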
The loss function is the log loss (pixel-wise binary cross-entropy), which is good enough for this first implementation.
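With two classes, the log loss is evaluated independently at every pixel and then averaged over the mask. A minimal NumPy sketch of that computation; clipping predictions to [eps, 1 - eps] avoids `log(0)` and is similar in spirit to what Keras does internally, though the exact epsilon here is an assumption:

```python
import numpy as np

def pixelwise_log_loss(y_true, y_pred, eps: float = 1e-7) -> float:
    """Mean binary cross-entropy over all pixels.

    y_true: ground-truth mask with values 0 (background) or 1 (person)
    y_pred: predicted probability of class 1, same shape as y_true
    """
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.clip(np.asarray(y_pred, dtype=np.float64), eps, 1.0 - eps)
    losses = -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    return float(losses.mean())

# An uninformative prediction of 0.5 everywhere costs ln(2) per pixel.
print(pixelwise_log_loss([[1, 0], [0, 1]], [[0.5, 0.5], [0.5, 0.5]]))
```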
Results
Training takes a little under 2 hours. The model performs relatively well on the test data (from the same dataset). Using the UpSampling2D layer in place of the BiLinearUpsampling2D layer results in pixelation at the mask edges.
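The pixelation comes from how the two layers fill in new pixels. Keras `UpSampling2D` defaults to nearest-neighbor interpolation, which copies each value into a block, so a sharp edge stays a hard staircase. Bilinear interpolation produces intermediate values that smooth the transition. The NumPy sketch below contrasts the two on a tiny edge; the bilinear version uses an align-corners convention, which is an illustrative choice and not necessarily what MobileUnet's custom layer does.

```python
import numpy as np

def upsample_nearest(img: np.ndarray, factor: int) -> np.ndarray:
    """Repeat each pixel factor x factor times (nearest-neighbor)."""
    return np.repeat(np.repeat(img, factor, axis=0), factor, axis=1)

def upsample_bilinear(img: np.ndarray, factor: int) -> np.ndarray:
    """Separable linear interpolation along x, then y (align-corners style)."""
    h, w = img.shape
    ys = np.linspace(0, h - 1, h * factor)
    xs = np.linspace(0, w - 1, w * factor)
    # Interpolate each input row along x.
    rows = np.stack([np.interp(xs, np.arange(w), img[r]) for r in range(h)])
    # Interpolate each resulting column along y.
    return np.stack(
        [np.interp(ys, np.arange(h), rows[:, c]) for c in range(rows.shape[1])],
        axis=1,
    )

edge = np.array([[0.0, 1.0], [0.0, 1.0]])  # left half background, right half person
print(upsample_nearest(edge, 2))   # only 0s and 1s: a hard, blocky edge
print(upsample_bilinear(edge, 2))  # rows of 0, 1/3, 2/3, 1: a graded edge
```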
On my MacBook Pro, inference on a single frame takes a little over 300 ms, which limits the frame rate to about 3 fps. That is low, considering that smooth video requires at least 15–20 fps.
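The relationship between per-frame latency and frame rate is a simple reciprocal, which also gives the latency budget for a target frame rate. A small worked example using the figures above (the helper names are illustrative):

```python
def max_fps(latency_ms: float) -> float:
    """Upper bound on frame rate when each frame takes latency_ms to process."""
    return 1000.0 / latency_ms

def frame_budget_ms(target_fps: float) -> float:
    """Maximum per-frame processing time that still reaches target_fps."""
    return 1000.0 / target_fps

def required_speedup(latency_ms: float, target_fps: float) -> float:
    """How many times faster inference must get to reach target_fps."""
    return latency_ms / frame_budget_ms(target_fps)

print(round(max_fps(300), 2))       # 3.33
print(frame_budget_ms(20))          # 50.0
print(required_speedup(300, 20))    # 6.0
```

So reaching 15–20 fps on that laptop would require cutting per-frame inference from ~300 ms to roughly 50–67 ms, a 4.5–6x speedup.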
The trained model (weights + architecture) is then converted to the TFJS format and uploaded to a Google Cloud Storage bucket.
Running inference in the browser with TensorflowJS
First, the model is downloaded to the client. The webcam video is captured and rendered on the HTML page. Each frame is converted to a tensor and fed to the model, which outputs a binary mask tensor. With a few tensor manipulations, the background can then be replaced with one of a set of pre-selected background images (chosen from the drop-down menu). Inference runs frame by frame.
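The "few tensor manipulations" amount to a per-pixel selection: keep the frame where the mask says person, and take the replacement background elsewhere. In the app this happens with TFJS tensors in the browser; the NumPy sketch below only mirrors the arithmetic. It assumes the model output is a per-pixel probability already resized to the frame's height and width, and it uses a 0.5 threshold, which is an assumption rather than a documented setting.

```python
import numpy as np

def composite(frame: np.ndarray, background: np.ndarray,
              prob_mask: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Replace the background of a frame using a segmentation mask.

    frame, background: float arrays of shape (H, W, 3)
    prob_mask: model output of shape (H, W), probability of person (class 1)
    """
    # Binarize, then add a channel axis so (H, W, 1) broadcasts over RGB.
    person = (prob_mask > threshold).astype(frame.dtype)[..., None]
    return person * frame + (1.0 - person) * background

frame = np.full((2, 2, 3), 0.8)       # stand-in webcam frame
background = np.zeros((2, 2, 3))      # stand-in replacement background
prob = np.array([[0.9, 0.1], [0.2, 0.7]])
print(composite(frame, background, prob)[..., 0])
```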
The segmentation works relatively well even on images that the model has never seen. Even though the model was trained on static images without temporal correlations, the results on the live video stream are pretty promising.
The frame rate is, unsurprisingly, device-dependent:
- Desktop with a Titan X GPU: 26–28 fps
- MacBook Pro with a Radeon Pro GPU: ~3 fps on average
- Samsung Galaxy J8 phone: well below 1 fps
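Frame rates like these are usually reported as an average over recent frames rather than from a single frame, which would be noisy. A minimal sketch of a sliding-window FPS meter; the class name and the 30-frame window are hypothetical choices, and in a real loop one would pass `time.perf_counter()` after each processed frame.

```python
from collections import deque

class FpsMeter:
    """Frames per second over a sliding window of recent frame timestamps."""

    def __init__(self, window: int = 30):
        self.times = deque(maxlen=window)  # oldest timestamps drop off automatically

    def tick(self, now: float) -> float:
        """Record a frame finished at time `now` (seconds) and return the fps."""
        self.times.append(now)
        if len(self.times) < 2:
            return 0.0  # need at least two timestamps to measure an interval
        elapsed = self.times[-1] - self.times[0]
        return (len(self.times) - 1) / elapsed if elapsed > 0 else 0.0

# Simulated frames arriving every 0.25 s (i.e. 4 fps).
meter = FpsMeter(window=5)
for i in range(10):
    fps = meter.tick(i * 0.25)
print(fps)
```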
Note that the frame with the hidden background (right) is a bit darker than the original image (left). This is due to the crude method I used to superimpose the synthetic background, the mask, and the original image. I am working on improving this, and it will be covered in Part 2.
Try a live demo here, and let me know what you think!
In Part 2 (coming soon), I will discuss improvements to the model that increase the frame rate on the MacBook Pro, as well as data augmentation schemes to improve generalization. Stay tuned!