Demystifying TFX Standard Components


When we think about ML we tend to focus on the model training part. But when we move to production we realize that there are many other pieces that are essential for the model to remain available and robust over its lifetime.

A production solution requires much more in order to deal with all the issues we face during ML development:


TFX is a flexible ML platform that lets users build ML pipelines using different orchestrators and different underlying execution engines. It also implements best practices to standardize ML model lifecycle management:


Conceptually, TFX has a layered architecture to coordinate the execution of its components. The layers are:


Metadata Store

At the heart of TFX is the Metadata Store, which is responsible for storing:


A TFX component is responsible for performing a specific task, for instance, data ingestion, model training with TensorFlow, or serving with TF Serving. Every component in TFX has three building blocks:


When TFX components are connected they form a pipeline through which data flows, e.g. from data ingestion to model serving. The communication happens via the Metadata Store: each component reads its dependencies from it and writes its outputs/artifacts back.
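As a rough mental model (not the real ML Metadata API, just a sketch with a plain dict standing in for the store), components can be pictured as functions that resolve their inputs from a shared store and publish their outputs back:

```python
# Conceptual sketch only: the real TFX Metadata Store is backed by
# ML Metadata (MLMD), not a plain dict.
metadata_store = {}

def run_component(name, input_keys, output_key, fn):
    """Resolve inputs from the store, execute, publish the artifact back."""
    inputs = [metadata_store[k] for k in input_keys]
    artifact = fn(*inputs)
    metadata_store[output_key] = artifact
    return artifact

# "ExampleGen" publishes raw examples; "StatisticsGen" consumes them.
metadata_store["examples"] = [1.0, 2.0, 3.0, 4.0]
run_component("StatisticsGen", ["examples"], "statistics",
              lambda xs: {"count": len(xs), "mean": sum(xs) / len(xs)})

print(metadata_store["statistics"])  # {'count': 4, 'mean': 2.5}
```

Because every artifact passes through the store, any downstream component (or a later pipeline run) can look up exactly which inputs produced which outputs.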

Standard Components

TFX ships with a set of standard components, which we can build upon or extend in a couple of different ways.

TFX-Canonical Pipeline

On the left we ingest data; flowing through the pipeline, we calculate statistics about it, make sure there are no problems with the data, understand what types of features we have, do feature engineering, train, and check the metrics. Then comes the question: should we push this new model to production (i.e. does the new model outperform the existing one)? Along the way we also have the ability to do bulk inference.


This component takes raw data as input and generates TensorFlow Examples; it can take many input formats (e.g. CSV, TFRecord). It also splits the examples for you into Train/Eval. It then passes the result to the StatisticsGen component.

examples = csv_input(os.path.join(data_root, 'simple'))
example_gen = CsvExampleGen(input_base=examples)
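The split behaviour can be pictured as a deterministic, hash-based partition (a conceptual sketch in the spirit of ExampleGen's default 2:1 train/eval split, not its actual implementation):

```python
# Conceptual sketch: deterministic hash-based split, similar in spirit to
# ExampleGen's default 2:1 train/eval partitioning.
import hashlib

def assign_split(record, train_buckets=2, total_buckets=3):
    digest = hashlib.sha256(record.encode()).hexdigest()
    bucket = int(digest, 16) % total_buckets
    return "train" if bucket < train_buckets else "eval"

records = [f"row-{i}" for i in range(9)]
splits = {"train": [], "eval": []}
for r in records:
    splits[assign_split(r)].append(r)

print(len(splits["train"]), len(splits["eval"]))
```

Hashing (rather than random sampling) makes the split reproducible: the same record always lands in the same split across runs.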


StatisticsGen generates useful statistics that help us dive into the data and understand its characteristics. It also comes with visualization tools.

For instance, in the following example the column trip_start_hour seems to have a time window between 5 am and 6 am where data is missing. Such a histogram helps determine the area we need to focus on to fix any data-related problems. In this case we need to get more data; otherwise, inference for 6 am data will be overgeneralized.


statistics_gen = StatisticsGen(input_data=example_gen.outputs.examples)
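The kind of per-feature summary StatisticsGen produces can be sketched in plain Python (a toy stand-in; the real component computes far richer statistics with TensorFlow Data Validation):

```python
# Conceptual sketch of per-feature statistics like those StatisticsGen computes.
def compute_stats(values):
    present = [v for v in values if v is not None]
    return {
        "count": len(values),
        "missing": len(values) - len(present),
        "min": min(present),
        "max": max(present),
        "mean": sum(present) / len(present),
    }

# Toy column with two missing readings, echoing the trip_start_hour example.
trip_start_hour = [5, 6, None, 6, 5, None, 7]
stats = compute_stats(trip_start_hour)
print(stats)  # {'count': 7, 'missing': 2, 'min': 5, 'max': 7, 'mean': 5.8}
```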


SchemaGen looks at the data type of each input feature: is it an int, a float, categorical, etc.? If it is categorical, what are the valid values? It also comes with a visualization tool to review the inferred schema and fix any issues.

infer_schema = SchemaGen(stats=statistics_gen.outputs.output)
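Schema inference can be sketched as guessing a type per feature and, for categorical features, a domain of valid values (a toy illustration, not SchemaGen's actual algorithm):

```python
# Conceptual sketch of schema inference: guess a type and, for strings,
# a domain of valid values, as SchemaGen does from the statistics.
def infer_feature_schema(values):
    if all(isinstance(v, int) for v in values):
        return {"type": "INT"}
    if all(isinstance(v, (int, float)) for v in values):
        return {"type": "FLOAT"}
    return {"type": "STRING", "domain": sorted(set(values))}

print(infer_feature_schema([1, 2, 3]))
# {'type': 'INT'}
print(infer_feature_schema(["cash", "card", "cash"]))
# {'type': 'STRING', 'domain': ['card', 'cash']}
```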


ExampleValidator takes the inputs and looks for problems in the data (missing values, zero values that should not be zero) and reports any anomalies.

validate_stats = ExampleValidator(stats=statistics_gen.outputs.output,
                                  schema=infer_schema.outputs.output)
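Conceptually, validation means checking fresh statistics against the expected schema and reporting every mismatch (a toy sketch, not TensorFlow Data Validation's real anomaly detection):

```python
# Conceptual sketch: flag anomalies by checking new statistics against an
# expected schema, in the spirit of ExampleValidator.
def find_anomalies(stats, schema):
    anomalies = []
    if stats["missing"] > schema.get("max_missing", 0):
        anomalies.append("too many missing values")
    if stats["min"] < schema["min"] or stats["max"] > schema["max"]:
        anomalies.append("value out of range")
    return anomalies

schema = {"min": 0, "max": 23, "max_missing": 0}  # e.g. an hour-of-day feature
stats = {"missing": 2, "min": 0, "max": 25}
print(find_anomalies(stats, schema))
# ['too many missing values', 'value out of range']
```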


Transform takes the data generated by the ExampleGen component and the schema generated by SchemaGen and implements arbitrarily complex logic, depending on the needs of the dataset and model, e.g. to perform feature engineering.

Note that the logic within this component cannot be executed eagerly, as it is turned into a graph that is prepended to the model. This means we do the same feature engineering with the same code during both training and serving, which eliminates training/serving skew.

transform = Transform(
    examples=example_gen.outputs.examples,
    schema=infer_schema.outputs.output,
    module_file=_taxi_module_file)

# Inside the module file's preprocessing_fn
# (assuming: import tensorflow_transform as tft):
outputs[_transformed_name(key)] = tft.scale_to_z_score(
    _fill_in_missing(inputs[key]))
# ...
taxi_fare = _fill_in_missing(inputs[_FARE_KEY])
tips = _fill_in_missing(inputs[_LABEL_KEY])
outputs[_transformed_name(_LABEL_KEY)] = tf.where(
    tf.is_nan(taxi_fare),
    tf.cast(tf.zeros_like(taxi_fare), tf.int64),
    # Test if the tip was > 20% of the fare
    tf.cast(tf.greater(tips, tf.multiply(taxi_fare, tf.constant(0.2))),
            tf.int64))
# ...
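Why the graph approach kills training/serving skew can be shown with a tiny stand-in for a tf.Transform analyzer: the statistics are computed once over the training data and then frozen into the transform, so training and serving apply the exact same function:

```python
# Conceptual sketch of why a transform graph eliminates training/serving
# skew: the scaling statistics are frozen at training time and the same
# function is applied in both phases.
def make_z_score_fn(train_values):
    mean = sum(train_values) / len(train_values)
    var = sum((v - mean) ** 2 for v in train_values) / len(train_values)
    std = var ** 0.5
    # The frozen (mean, std) play the role of the analyzer outputs that
    # tf.Transform bakes into the serving graph.
    return lambda v: (v - mean) / std

scale = make_z_score_fn([2.0, 4.0, 6.0, 8.0])
print(scale(5.0))  # training-time transform: 0.0 (5.0 is the training mean)
print(scale(5.0))  # identical result at serving time
```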


Trainer performs the training of the model. It uses TensorBoard to log performance metrics, which helps in understanding the training process and comparing execution runs.

trainer = Trainer(module_file=taxi_module_file,
                  train_steps=10000, eval_steps=5000, warm_starting=True)


Evaluator is a tool that lets us look not only at top-level metrics (RMSE, AUC) but also at individual slices of the dataset and slices of features within the dataset. Things like fairness become very manageable with this component.

model_analyzer = Evaluator(examples=example_gen.outputs['examples'],
                           model=trainer.outputs['model'])
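Sliced evaluation can be sketched as grouping examples by a feature value and computing the metric per group (a toy illustration of what the Evaluator does via TensorFlow Model Analysis):

```python
# Conceptual sketch of sliced evaluation: compute a metric per slice of a
# feature, not just one top-level number.
def sliced_accuracy(rows, slice_key):
    slices = {}
    for row in rows:
        slices.setdefault(row[slice_key], []).append(row)
    return {
        key: sum(r["label"] == r["prediction"] for r in group) / len(group)
        for key, group in slices.items()
    }

rows = [
    {"hour": 6, "label": 1, "prediction": 0},
    {"hour": 6, "label": 0, "prediction": 0},
    {"hour": 12, "label": 1, "prediction": 1},
    {"hour": 12, "label": 1, "prediction": 1},
]
print(sliced_accuracy(rows, "hour"))  # {6: 0.5, 12: 1.0}
```

A model can look fine on the top-level metric while one slice (here, 6 am trips) performs much worse; slicing is what surfaces that.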


This component helps to compare different versions of a model, e.g. a production model against a new model currently in development, using different validation modes:

model_validator = ModelValidator(examples=example_gen.outputs['examples'],
                                 model=trainer.outputs['model'])
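The core idea can be sketched as: bless the candidate model only if it beats the current baseline on the chosen metric (a toy version; the real component evaluates both models on the same examples):

```python
# Conceptual sketch of model validation: the candidate is "blessed" only
# if it outperforms the baseline (e.g. the current production model).
def validate(candidate_metrics, baseline_metrics, metric="auc"):
    blessed = candidate_metrics[metric] > baseline_metrics[metric]
    return {"blessed": blessed}

print(validate({"auc": 0.91}, {"auc": 0.88}))  # {'blessed': True}
print(validate({"auc": 0.85}, {"auc": 0.88}))  # {'blessed': False}
```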


This component is responsible for pushing the trained (and validated) model to different deployment targets:

It can be configured to block deployment based on the outcome of model validation.

pusher = Pusher(model=trainer.outputs['model'],
                model_blessing=model_validator.outputs['blessing'],
                serving_model_dir=serving_model_dir)
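The gating behaviour amounts to a simple rule (a conceptual sketch, not the Pusher's implementation): deploy only when the validation outcome is a blessing, otherwise keep serving the old model.

```python
# Conceptual sketch of the Pusher's gating: deploy only blessed models.
def push(model_uri, blessing, deployed):
    if not blessing["blessed"]:
        return deployed  # deployment blocked, keep serving the old model
    return deployed + [model_uri]

deployed = push("models/v2", {"blessed": True}, ["models/v1"])
deployed = push("models/v3", {"blessed": False}, deployed)
print(deployed)  # ['models/v1', 'models/v2']  -- v3 was never deployed
```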


BulkInferrer performs offline batch inference over inference examples. It outputs the features and predictions of the model.

It can be configured to block inference on a model validation outcome.

bulk_inferrer = BulkInferrer(examples=example_gen.outputs['examples'],
                             model=trainer.outputs['model'],
                             model_blessing=model_validator.outputs['blessing'])
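Conceptually, bulk inference runs the model over a batch of examples and emits features alongside predictions, skipping the run entirely if the model was not blessed (a toy sketch with a stand-in model function):

```python
# Conceptual sketch of bulk inference: batch predictions gated on the
# model validation outcome.
def bulk_infer(examples, model_fn, blessing):
    if not blessing["blessed"]:
        return []  # inference blocked by the validation outcome
    return [{"features": ex, "prediction": model_fn(ex)} for ex in examples]

results = bulk_infer([1.0, 2.0, 3.0], lambda x: x * 2, {"blessed": True})
print(results[0])  # {'features': 1.0, 'prediction': 2.0}
```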

TFX Pipeline

The previous standard components can be used together to create a pipeline. The following snippet illustrates a concrete example:

def _create_pipeline():
  """Implements a TFX pipeline."""
  csv_data = csv_input(os.path.join(data_root, 'simple'))
  example_gen = CsvExampleGen(input=csv_data)

  statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])
  infer_schema = SchemaGen(statistics=statistics_gen.outputs['statistics'])
  validate_stats = ExampleValidator(statistics=statistics_gen.outputs['statistics'], schema=infer_schema.outputs['schema'])

  # Performs feature engineering
  transform = Transform(examples=example_gen.outputs['examples'], schema=infer_schema.outputs['schema'], module_file=_taxi_module_file)

  trainer = Trainer(...)
  model_analyzer = Evaluator(examples=example_gen.outputs['examples'], model=trainer.outputs['model'])
  model_validator = ModelValidator(examples=example_gen.outputs['examples'], model=trainer.outputs['model'])
  pusher = Pusher(model=..., model_blessing=..., serving_model_dir=...)

  return pipeline.Pipeline(
      pipeline_name='taxi_pipeline',  # name and root assumed defined elsewhere
      pipeline_root=_pipeline_root,
      components=[example_gen, statistics_gen, infer_schema, validate_stats,
                  transform, trainer, model_analyzer, model_validator, pusher])

result = AirflowDAGRunner(_airflow_config).run(_create_pipeline())
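What the orchestrator (Airflow here, but equally Kubeflow or Beam) derives from the component list can be sketched as a topological ordering over artifact dependencies (a conceptual illustration, not the runner's actual code):

```python
# Conceptual sketch: an orchestrator runs components in an order that
# respects their artifact dependencies.
def execution_order(components):
    """components: name -> set of upstream component names (assumed acyclic)."""
    order, done = [], set()
    while len(done) < len(components):
        for name, deps in components.items():
            if name not in done and deps <= done:
                order.append(name)
                done.add(name)
    return order

deps = {
    "CsvExampleGen": set(),
    "StatisticsGen": {"CsvExampleGen"},
    "SchemaGen": {"StatisticsGen"},
    "Trainer": {"CsvExampleGen", "SchemaGen"},
}
print(execution_order(deps))
# ['CsvExampleGen', 'StatisticsGen', 'SchemaGen', 'Trainer']
```

Because the dependencies are explicit in the metadata, the same pipeline definition can be handed to different orchestrators without change.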