Scalable data architecture

TensorFlow Extended (TFX) for data validation in practice

Vincent Lepage

In 2019, Google open-sourced TensorFlow Extended (TFX), a set of libraries for building production-scale machine learning pipelines. Four years earlier, in 2015, when Google open-sourced TensorFlow, they already had the full stack (TF and TFX) in use internally, but chose to publish only the core of the technology at the time.

At Sarus, we are always using and evaluating various Machine Learning Operations (MLOps) technologies, and we’d like to share some work we’ve done on TFX. Hope this helps!

TFX: the big picture

First of all, don’t think of TFX as a standalone product or tool. It’s a collection of libraries that provide essential components for production-grade MLOps.

What I love about TFX is that it makes you think: it provides a very sensible high-level architecture, and you can pick which components you actually need, and how you’re going to implement them. Most machine learning workflows fit easily in this modular approach.

TFX components (from https://www.tensorflow.org/tfx/guide)

The main sets of components are:

  • Data validation: components related to statistics, schema generation and example validation
  • Transform: tools to preprocess data
  • Training: Estimator and TensorFlow models
  • Model analysis: Evaluating models before deploying them — a critical step
  • Pushing and serving models
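
To make this modular approach concrete, here is a minimal sketch of how the data-validation components are typically wired together in a pipeline. It assumes TFX 1.x and a local directory of CSV files; adapt the path and the orchestrator to your setup.

from tfx import v1 as tfx

# Ingest CSV files, compute statistics, infer a schema, and validate examples.
example_gen = tfx.components.CsvExampleGen(input_base="./data")
statistics_gen = tfx.components.StatisticsGen(
    examples=example_gen.outputs["examples"])
schema_gen = tfx.components.SchemaGen(
    statistics=statistics_gen.outputs["statistics"],
    infer_feature_shape=True)
example_validator = tfx.components.ExampleValidator(
    statistics=statistics_gen.outputs["statistics"],
    schema=schema_gen.outputs["schema"])

Each component only declares its inputs and outputs; the orchestrator (local runner, Airflow, Kubeflow, etc.) decides where and when they run.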

A focus on Data Validation

Data Validation components are available in the tensorflow_data_validation package.

Why data validation is important: a real-life anecdote

We all know that real-life data can be low-quality and full of surprises: missing values, measurement errors, poorly specified fields, or non-stationarities. On top of that, models often need to be retrained regularly on new training data.

My previous startup was an AI-driven adtech company. We were ingesting billions of events per day and retraining our ML models daily based on those events. The algorithm computed the optimal selling price of our customers’ assets, and a misstep could lead to significant financial losses. We pushed new models gradually and monitored model performance on a daily basis. It seemed very stable and robust.

However, one day, we began noticing lower performance. Revenue from some clients started to decline in some areas for no obvious reason. After investigating, we found out that our prediction models were behaving strangely, so we looked at the training datasets. OMG! We found a few events (out of billions) that showed extravagant values for the selling prices of some items — a billion times higher than average! Where were these values coming from?

It turned out that a developer, who was working on a supposedly unrelated reporting product, wanted to test it in production. It was easier for him to test by inputting extreme values. Unfortunately, it poisoned our datasets and drove our prediction engine crazy.

The issue was not that someone had injected outliers into the system. The issue was that we lacked proper data validation. We quickly implemented an outlier detection process to avoid such issues.

Because a model can be very complex, with many internal layers, it’s usually harder to understand its behavior than to check the training data. Model interpretability is tricky, so it’s better to implement validation processes that ensure data integrity.

Let’s code a bit

Let’s dive into the code. We’re going to build stats and validation schemas for the famous “census” dataset available on Kaggle. It’s based on US census data, and the usual exercise is to try to predict people’s annual income (there are only two classes: greater or lower than $50k).

Statistics

Data validation starts by computing statistics on the dataset. This is done by parsing and analyzing the types and distribution of data.

import tensorflow_data_validation as tfdv

# Compute dataset-level statistics directly from the CSV file.
csv_path = "./adult.csv"
train_stats = tfdv.generate_statistics_from_csv(data_location=csv_path)

# Render the interactive statistics visualization (in a notebook).
tfdv.visualize_statistics(train_stats)

We get a neat, interactive representation:

Unfortunately, these statistics are not differentially private and may leak sensitive information (think of a case where only one individual falls in a category). Therefore, at Sarus, we could not use these statistics directly (we built a custom component that generates differentially private statistics, which is basically our own StatisticsGen component). But we can still use the TFX statistics for schema inference.
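
To give a rough idea of what “differentially private statistics” means (this is only an illustrative sketch, not Sarus’s component): a simple count query has sensitivity 1, so adding Laplace noise with scale 1/epsilon yields an epsilon-differentially-private count.

import numpy as np

def dp_count(values, epsilon=1.0):
    # Sensitivity of a count is 1: adding or removing one individual changes
    # the count by at most 1, so Laplace(1/epsilon) noise suffices.
    return len(values) + np.random.laplace(loc=0.0, scale=1.0 / epsilon)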

Schema inference

A schema is a description of the types of your features. Typing is a powerful way to constrain the features in a way that can be easily defined and checked.

Let’s infer a schema from the stats we just computed.

schema = tfdv.infer_schema(statistics=train_stats)
tfdv.display_schema(schema=schema)

which gives us the following output:

A few comments:

  • Once you have the statistics, inferring the schema is very fast.
  • The schema is an instance of a protobuf defined here.
  • Categories are typed as STRING within a specific domain, as specified in the above-mentioned protobuf (see the snippet below for inspecting and adjusting a domain).
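
The inferred schema is a starting point, not a contract you must accept as-is. Here is a hedged sketch of manual adjustments using public TFDV utilities, assuming the census columns are named "age" and "workclass":

from tensorflow_metadata.proto.v0 import schema_pb2

# Inspect the value domain inferred for a categorical feature.
workclass_domain = tfdv.get_domain(schema, "workclass")
print(workclass_domain.value)  # the list of allowed strings

# Constrain a numerical feature to a plausible range.
tfdv.set_domain(schema, "age", schema_pb2.IntDomain(name="age", min=0, max=120))

# Tolerate a small fraction of out-of-domain strings in future batches.
tfdv.get_feature(schema, "workclass").distribution_constraints.min_domain_mass = 0.9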

Validating an example or a dataset against a schema

The TFX API gives you many ways to ask for validation; they all revolve around computing anomalies. Remember that TFX is production-oriented, so it favors working with batches (from PyArrow), TFRecord files, or the equivalent.
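
For a production pipeline reading TFRecord files, the entry point is analogous (the path is hypothetical):

serving_stats = tfdv.generate_statistics_from_tfrecord(data_location="./data/train.tfrecord")
anomalies = tfdv.validate_statistics(statistics=serving_stats, schema=schema)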

Here, for simplicity, we use a pandas DataFrame:

import pandas as pd

# Two records that deliberately violate the inferred schema.
invalid_example = [
    {"age": "40", "income": "error"},
    {"age": "another_error", "income": ">50K"},
]
df = pd.DataFrame.from_records(invalid_example)
errors_stats = tfdv.generate_statistics_from_dataframe(df)
anomalies = tfdv.validate_statistics(statistics=errors_stats, schema=schema)
tfdv.display_anomalies(anomalies)

The anomaly report is quite rich:

  • validation flags missing columns
  • type mismatches (INT vs. STRING)
  • strings outside of the specified domain
  • statistics on the number of errors

Based on these anomalies, we can decide to drop a batch or, if it is valid, ingest it into the production pipeline. We can also try to fix anomalies by applying a transformation to the batch (converting types, capping values, etc.) and re-checking afterwards, as sketched below.
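
Here is a minimal sketch of such a gate; the column name and the type-coercion step are illustrative assumptions, not a prescription:

def batch_is_valid(batch_df, schema):
    # Recompute statistics for the batch and compare them against the schema.
    stats = tfdv.generate_statistics_from_dataframe(batch_df)
    anomalies = tfdv.validate_statistics(statistics=stats, schema=schema)
    return len(anomalies.anomaly_info) == 0

# Attempt a repair: coerce "age" to a number and drop rows that cannot be parsed.
df_fixed = df.copy()
df_fixed["age"] = pd.to_numeric(df_fixed["age"], errors="coerce")
df_fixed = df_fixed.dropna(subset=["age"])
df_fixed["age"] = df_fixed["age"].astype(int)

if batch_is_valid(df_fixed, schema):
    print("batch accepted")  # hand it over to the production pipeline
else:
    print("batch rejected")  # quarantine it for manual inspection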

Conclusion

Data validation is essential in many production ML pipelines: it lets you detect, early and easily, errors that would hurt performance. It is the ML equivalent of unit testing in software engineering. TFX gives you scalable, out-of-the-box components to make it easier. It should suit most of your needs, and I’d encourage you to give it a try (and give us feedback in the comment section)!

We’re hiring!

If you’re a developer, a data scientist or if you are just passionate about machine learning, privacy technology, and how to use them for good, please apply here.

Resources

  • The code and data source are available with an Apache 2 license on our GitHub page.
  • Of course, you can find a lot of useful info on the TFX site.
  • An interesting video by Chris Fregly on TFX, Kubeflow and Differential Privacy

About the author

Vincent Lepage

Cofounder & CTO @ Sarus
