# !pip install -q -U tfx==1.4.0
# !pip install tensorflow==2.7.0
# !pip install tensorflow_decision_forests==0.2.1
# creating required directories
!mkdir -p saved_data
!mkdir -p model
!mkdir -p data
import tempfile
import tensorflow as tf
import urllib.request
import os
import pandas as pd
import shutil
import tensorflow_data_validation as tfdv
import tensorflow_model_analysis as tfma
import tensorflow_decision_forests as tfdf
from absl import logging
from pathlib import Path
from tfx import v1 as tfx
from tensorflow_metadata.proto.v0 import schema_pb2
from tfx.proto import example_gen_pb2
print(f"Tensorflow Version: {tf.__version__}")
print(f"TFX Version: {tfx.__version__}")
print(f"TFDF Version: {tfdf.__version__}")
print(f"Tensorflow Data Validation Version: {tfdv.__version__}")
logging.set_verbosity(logging.INFO)
DATA_DIRECTORY = 'data'
DATA_SOURCE_PATH = Path(DATA_DIRECTORY) / 'Social_Network_Ads.csv'
SAVED_DATA = 'saved_data'
DATA_TRAIN_FILENAME = Path(SAVED_DATA) / 'train.csv'
DATA_TEST_FILENAME = Path(SAVED_DATA) / 'test.csv'
PIPELINE_NAME = 'sample-pipeline'
PIPELINE_DIRECTORY = os.path.join(Path('pipelines'), PIPELINE_NAME)
METADATA_PATH = Path("metadata") / PIPELINE_NAME / "metadata.db"
SCHEMA_DIRECTORY = os.path.join(PIPELINE_DIRECTORY, 'schema')
SCHEMA_FILE_NAME = str(os.path.join(SCHEMA_DIRECTORY, 'schema.pbtxt'))
MODEL_DIRECTORY = Path('model')
# Module Paths
CONSTANTS_MODULE_PATH = 'constants.py'
TRANSFORM_MODULE_PATH = 'transform.py'
TRAINER_MODULE_PATH = 'trainer.py'
data_df = pd.read_csv(DATA_SOURCE_PATH)
# shuffling and splitting the data for training and testing
data_df = data_df.sample(frac=1)
train_df = data_df[: int(len(data_df) * 0.7)].copy()
test_df = data_df[int(len(data_df) * 0.7):].copy()
# removing the undesired columns from all the datasets
datasets = [train_df, test_df]
drop_columns = ['User ID']
for dataset in datasets:
    dataset.drop(drop_columns, axis=1, inplace=True)
train_df.to_csv(DATA_TRAIN_FILENAME, index=False)
test_df.to_csv(DATA_TEST_FILENAME, index=False)
# peeking at the data
train_df.info()
# creating the useful conversion functions
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[bytes(str(value), encoding="raw_unicode_escape")]))

def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# This function will be required to convert our test set
# to compatible schema types for inferencing
def _examples(df):
    examples = []
    for index, row in df.iterrows():
        features = {
            "Gender": _bytes_feature(row['Gender']),
            "Age": _int64_feature(row['Age']),
            "EstimatedSalary": _int64_feature(row['EstimatedSalary']),
        }
        example_proto = tf.train.Example(features=tf.train.Features(feature=features))
        examples.append(example_proto.SerializeToString())
    return examples
from tfx.orchestration.experimental.interactive.interactive_context import (InteractiveContext)
context = InteractiveContext()
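# splitting the ingested examples into train and eval sets using hash buckets
# (3:1, roughly a 75% / 25% split)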
output = tfx.proto.Output(
split_config=example_gen_pb2.SplitConfig(splits=[
tfx.proto.SplitConfig.Split(name="train", hash_buckets=3),
tfx.proto.SplitConfig.Split(name="eval", hash_buckets=1)
]))
example_gen = tfx.components.CsvExampleGen(input_base=SAVED_DATA, output_config=output)
context.run(example_gen)
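# StatisticsGen computes descriptive statistics over every split produced by ExampleGen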
statistics_gen = tfx.components.StatisticsGen(
examples=example_gen.outputs['examples']
)
context.run(statistics_gen)
context.show(statistics_gen.outputs['statistics'])
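# SchemaGen infers an initial schema (feature types, domains, and presence) from the statistics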
schema_gen = tfx.components.SchemaGen(
statistics=statistics_gen.outputs['statistics'],
infer_feature_shape=True
)
context.run(schema_gen)
context.show(schema_gen.outputs['schema'])
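# loading the auto-generated schema so we can curate it by hand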
schema = tfdv.load_schema_text(
os.path.join(schema_gen.outputs['schema']._artifacts[0].uri, "schema.pbtxt")
)
# adding the needed environments
schema.default_environment.append("TRAINING")
schema.default_environment.append("SERVING")
# removing the `Purchased` column from the serving environment
tfdv.get_feature(schema, "Purchased").not_in_environment.append("SERVING")
tfdv.display_schema(schema=schema)
!mkdir -p {SCHEMA_DIRECTORY}
tfdv.write_schema_text(schema, SCHEMA_FILE_NAME)
# loading the updated schema using the importer node.
schema_importer = tfx.dsl.Importer(
source_uri=str(SCHEMA_DIRECTORY),
artifact_type=tfx.types.standard_artifacts.Schema
).with_id("schema_importer")
context.run(schema_importer)
context.show(schema_importer.outputs['result'])
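# ExampleValidator checks the statistics against the curated schema and reports any anomalies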
example_validator = tfx.components.ExampleValidator(
statistics=statistics_gen.outputs['statistics'],
schema=schema_importer.outputs['result']
)
context.run(example_validator)
context.show(example_validator.outputs['anomalies'])
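# The Transform and Trainer components receive their code through module files,
# so we write the constants, preprocessing, and training code to disk first.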
%%writefile {CONSTANTS_MODULE_PATH}
LABEL = 'Purchased'
%%writefile {TRANSFORM_MODULE_PATH}
import tensorflow as tf
import tensorflow_transform as tft
import constants
LABEL = constants.LABEL
def preprocessing_fn(inputs):
    outputs = dict()
    outputs['Age'] = inputs['Age']
    outputs['EstimatedSalary'] = inputs['EstimatedSalary']
    # converting `Gender` into a label-encoded column (the domain values are 'Male' and 'Female')
    outputs['Gender'] = tf.cast(tf.equal(inputs['Gender'], 'Male'), tf.int64)
    outputs[LABEL] = inputs[LABEL]
    return outputs
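# Transform applies `preprocessing_fn` to every example and produces a transform
# graph that is reused at serving time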
transform = tfx.components.Transform(
examples=example_gen.outputs['examples'],
schema=schema_importer.outputs['result'],
module_file=os.path.abspath(TRANSFORM_MODULE_PATH),
)
context.run(transform, enable_cache=False)
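# reading a few transformed examples back from the TFRecord output to sanity-check the transform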
train_uri = os.path.join(
transform.outputs['transformed_examples'].get()[0].uri,
'Split-train'
)
tfrecord_filenames = [
os.path.join(train_uri, name) for name in os.listdir(train_uri)
]
dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type='GZIP')
for tfrecord in dataset.take(3):
    serialized_example = tfrecord.numpy()
    example = tf.train.Example()
    example.ParseFromString(serialized_example)
    print(example)
%%writefile {TRAINER_MODULE_PATH}
import tensorflow as tf
import tensorflow_decision_forests as tfdf
import tensorflow_transform as tft
from absl import logging
from tensorflow.keras import layers, Model, optimizers, losses, metrics
from tfx import v1 as tfx
from tfx_bsl.public import tfxio
from typing import List, Text
import constants
LABEL = constants.LABEL
BATCH_SIZE = 32
EPOCHS = 50
def _input_fn(
    file_pattern: List[Text],
    data_accessor: tfx.components.DataAccessor,
    tf_transform_output: tft.TFTransformOutput,
    batch_size: int,
) -> tf.data.Dataset:
    """
    Generates a dataset of features that can be used to train
    and evaluate the model.

    Args:
        file_pattern: List of paths or patterns of input data files.
        data_accessor: An instance of DataAccessor that we can use to
            convert the input to a RecordBatch.
        tf_transform_output: The transformation output.
        batch_size: The number of consecutive elements that we should
            combine in a single batch.

    Returns:
        A dataset that contains a tuple of (features, indices) where
        features is a dictionary of Tensors, and indices is a single
        Tensor of label indices.
    """
    dataset = data_accessor.tf_dataset_factory(
        file_pattern,
        tfxio.TensorFlowDatasetOptions(batch_size=batch_size),
        schema=tf_transform_output.raw_metadata.schema,
    )
    tft_layer = tf_transform_output.transform_features_layer()

    def apply_transform(raw_features):
        transformed_features = tft_layer(raw_features)
        transformed_label = transformed_features.pop(LABEL)
        return transformed_features, transformed_label

    return dataset.map(apply_transform).repeat()
def _get_serve_tf_examples_fn(model, tf_transform_output):
    """
    Returns a function that parses a serialized tf.Example and applies
    the transformations during inference.

    Args:
        model: The model that we are serving.
        tf_transform_output: The transformation output that we want to
            include with the model.
    """
    model.tft_layer = tf_transform_output.transform_features_layer()

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string, name="examples")])
    def serve_tf_examples_fn(serialized_tf_examples):
        feature_spec = tf_transform_output.raw_feature_spec()
        required_feature_spec = {
            k: v for k, v in feature_spec.items() if k != LABEL
        }
        parsed_features = tf.io.parse_example(
            serialized_tf_examples,
            required_feature_spec
        )
        transformed_features = model.tft_layer(parsed_features)
        return model(transformed_features)

    return serve_tf_examples_fn
def _model() -> tf.keras.Model:
    inputs = [
        layers.Input(shape=(1,), name="Age"),
        layers.Input(shape=(1,), name="EstimatedSalary"),
        layers.Input(shape=(1,), name="Gender")
    ]
    x = layers.concatenate(inputs)
    x = layers.Dense(8, activation="relu")(x)
    x = layers.Dense(8, activation="relu")(x)
    outputs = layers.Dense(1, activation="sigmoid")(x)

    model = Model(inputs=inputs, outputs=outputs)
    model.compile(
        optimizer=optimizers.Adam(1e-2),
        loss="binary_crossentropy",
        metrics=[metrics.BinaryAccuracy()],
    )
    model.summary(print_fn=logging.info)
    return model
def run_fn(fn_args: tfx.components.FnArgs):
    """
    The callback function that will be called by the Trainer component
    to train the model using the supplied arguments.

    Args:
        fn_args: A collection of name/value pairs representing the
            arguments to train the model.
    """
    tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)

    train_dataset = _input_fn(
        fn_args.train_files,
        fn_args.data_accessor,
        tf_transform_output,
        batch_size=BATCH_SIZE,
    )
    eval_dataset = _input_fn(
        fn_args.eval_files,
        fn_args.data_accessor,
        tf_transform_output,
        batch_size=BATCH_SIZE,
    )

    model = _model()
    model.fit(
        train_dataset,
        steps_per_epoch=fn_args.train_steps,
        validation_data=eval_dataset,
        validation_steps=fn_args.eval_steps,
        epochs=EPOCHS
    )

    # We need to modify the default signature to include the transform layer in
    # the computational graph.
    signatures = {
        "serving_default": _get_serve_tf_examples_fn(model, tf_transform_output),
    }
    model.save(fn_args.serving_model_dir, save_format="tf", signatures=signatures)
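# The Trainer component calls `run_fn` from the module file above with the
# training and evaluation splits and the transform graph.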
trainer = tfx.components.Trainer(
examples=example_gen.outputs["examples"],
transform_graph=transform.outputs["transform_graph"],
train_args=tfx.proto.TrainArgs(num_steps=100),
eval_args=tfx.proto.EvalArgs(num_steps=5),
module_file=os.path.abspath(TRAINER_MODULE_PATH),
)
context.run(trainer, enable_cache=False)
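# The evaluation config computes metrics overall and sliced by `Gender`, and blesses
# the model only if the overall binary accuracy is at least 0.7 and not worse than
# the current baseline model.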
eval_config = tfma.EvalConfig(
    model_specs=[
        tfma.ModelSpec(
            signature_name="serving_default",
            preprocessing_function_names=['tft_layer'],
            label_key="Purchased",
        )
    ],
    metrics_specs=[
        tfma.MetricsSpec(
            per_slice_thresholds={
                "binary_accuracy": tfma.PerSliceMetricThresholds(
                    thresholds=[
                        tfma.PerSliceMetricThreshold(
                            slicing_specs=[tfma.SlicingSpec()],
                            threshold=tfma.MetricThreshold(
                                value_threshold=tfma.GenericValueThreshold(
                                    lower_bound={"value": 0.7}
                                ),
                                change_threshold=tfma.GenericChangeThreshold(
                                    direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                                    absolute={"value": -1e-10},
                                ),
                            ),
                        )
                    ]
                ),
            }
        )
    ],
    slicing_specs=[
        tfma.SlicingSpec(),
        tfma.SlicingSpec(feature_keys=["Gender"])
    ]
)
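# The Resolver fetches the latest blessed model to serve as the evaluation baseline;
# on the first run there is no baseline, so only the value threshold applies.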
model_resolver = tfx.dsl.Resolver(
strategy_class=tfx.dsl.experimental.LatestBlessedModelStrategy,
model=tfx.dsl.Channel(type=tfx.types.standard_artifacts.Model),
model_blessings=tfx.dsl.Channel(type=tfx.types.standard_artifacts.ModelBlessing),
).with_id("latest_blessed_model_resolver")
context.run(model_resolver)
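# Evaluator computes the metrics defined in `eval_config` and emits a blessing
# that downstream components can check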
evaluator = tfx.components.Evaluator(
examples=example_gen.outputs["examples"],
model=trainer.outputs["model"],
eval_config=eval_config,
baseline_model=model_resolver.outputs["model"],
)
context.run(evaluator, enable_cache=False)
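# A quick look at the evaluation artifacts. This is a minimal sketch: it assumes the
# standard layout of the Evaluator's `evaluation` output and the TFMA load helpers.
evaluation_uri = evaluator.outputs["evaluation"].get()[0].uri
eval_result = tfma.load_eval_result(evaluation_uri)
print(eval_result.slicing_metrics)
print(tfma.load_validation_result(evaluation_uri))

# Pusher copies the blessed model into MODEL_DIRECTORY under a timestamped subdirectory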
pusher = tfx.components.Pusher(
model=trainer.outputs["model"],
model_blessing=evaluator.outputs["blessing"],
push_destination=tfx.proto.PushDestination(
filesystem=tfx.proto.PushDestination.Filesystem(
base_directory=str(MODEL_DIRECTORY)
)
),
)
context.run(pusher)
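# The pushed models live in numbered (timestamped) directories, so we load the
# serving signature of the most recent one.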
def get_inference_fn(model_directory):
    model_directories = (d for d in os.scandir(model_directory) if d.is_dir())
    model_path = max(model_directories, key=lambda i: int(i.name)).path
    loaded_model = tf.keras.models.load_model(model_path)
    return loaded_model.signatures["serving_default"]
inference_fn = get_inference_fn(MODEL_DIRECTORY)
result = inference_fn(examples=tf.constant(_examples(test_df)))
print(result["output_0"].numpy())
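# A minimal sketch (assumption: thresholding the sigmoid outputs at 0.5) to turn the
# predicted probabilities into class labels and compare them with the held-out labels.
predictions = (result["output_0"].numpy() >= 0.5).astype(int).flatten()
accuracy = (predictions == test_df["Purchased"].values).mean()
print(f"Accuracy on the test set: {accuracy:.2f}")

# Persisting the curated schema as a pbtxt file so the standalone pipeline below can import it.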
%%writefile {SCHEMA_FILE_NAME}
feature {
  name: "Gender"
  type: BYTES
  domain: "Gender"
  presence {
    min_fraction: 1.0
    min_count: 1
  }
  shape {
    dim {
      size: 1
    }
  }
}
feature {
  name: "Age"
  type: INT
  presence {
    min_fraction: 1.0
    min_count: 1
  }
  shape {
    dim {
      size: 1
    }
  }
}
feature {
  name: "EstimatedSalary"
  type: INT
  presence {
    min_fraction: 1.0
    min_count: 1
  }
  shape {
    dim {
      size: 1
    }
  }
}
feature {
  name: "Purchased"
  type: INT
  bool_domain {
  }
  presence {
    min_fraction: 1.0
    min_count: 1
  }
  not_in_environment: "SERVING"
  shape {
    dim {
      size: 1
    }
  }
}
string_domain {
  name: "Gender"
  value: "Female"
  value: "Male"
}
default_environment: "TRAINING"
default_environment: "SERVING"
import tensorflow_model_analysis as tfma
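# assembling every component into a single pipeline definition that can be run
# end to end with a TFX orchestrator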
def create_pipeline(
    pipeline_name: str,
    pipeline_directory: str,
    data_directory: str,
    schema_path: str,
    model_directory: str,
    metadata_path: str,
    transform_module_path: str,
    trainer_module_path: str,
) -> tfx.dsl.Pipeline:
    output = tfx.proto.Output(
        split_config=example_gen_pb2.SplitConfig(splits=[
            tfx.proto.SplitConfig.Split(name="train", hash_buckets=3),
            tfx.proto.SplitConfig.Split(name="eval", hash_buckets=1)
        ]))
    example_gen = tfx.components.CsvExampleGen(input_base=data_directory, output_config=output)

    statistics_gen = tfx.components.StatisticsGen(
        examples=example_gen.outputs['examples']
    )

    schema_importer = tfx.dsl.Importer(
        source_uri=str(schema_path),
        artifact_type=tfx.types.standard_artifacts.Schema
    ).with_id("schema_importer")

    example_validator = tfx.components.ExampleValidator(
        statistics=statistics_gen.outputs['statistics'],
        schema=schema_importer.outputs['result']
    )

    transform = tfx.components.Transform(
        examples=example_gen.outputs['examples'],
        schema=schema_importer.outputs['result'],
        module_file=os.path.abspath(transform_module_path),
    )

    trainer = tfx.components.Trainer(
        examples=example_gen.outputs["examples"],
        transform_graph=transform.outputs["transform_graph"],
        train_args=tfx.proto.TrainArgs(num_steps=100),
        eval_args=tfx.proto.EvalArgs(num_steps=5),
        module_file=os.path.abspath(trainer_module_path),
    )

    eval_config = tfma.EvalConfig(
        model_specs=[
            tfma.ModelSpec(
                signature_name="serving_default",
                preprocessing_function_names=['tft_layer'],
                label_key="Purchased",
            )
        ],
        metrics_specs=[
            tfma.MetricsSpec(
                per_slice_thresholds={
                    "binary_accuracy": tfma.PerSliceMetricThresholds(
                        thresholds=[
                            tfma.PerSliceMetricThreshold(
                                slicing_specs=[tfma.SlicingSpec()],
                                threshold=tfma.MetricThreshold(
                                    value_threshold=tfma.GenericValueThreshold(
                                        lower_bound={"value": 0.7}
                                    ),
                                    change_threshold=tfma.GenericChangeThreshold(
                                        direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                                        absolute={"value": -1e-10},
                                    ),
                                ),
                            )
                        ]
                    ),
                }
            )
        ],
        slicing_specs=[
            tfma.SlicingSpec(),
            tfma.SlicingSpec(feature_keys=["Gender"])
        ]
    )

    model_resolver = tfx.dsl.Resolver(
        strategy_class=tfx.dsl.experimental.LatestBlessedModelStrategy,
        model=tfx.dsl.Channel(type=tfx.types.standard_artifacts.Model),
        model_blessings=tfx.dsl.Channel(type=tfx.types.standard_artifacts.ModelBlessing),
    ).with_id("latest_blessed_model_resolver")

    evaluator = tfx.components.Evaluator(
        examples=example_gen.outputs["examples"],
        model=trainer.outputs["model"],
        eval_config=eval_config,
        baseline_model=model_resolver.outputs["model"],
    )

    pusher = tfx.components.Pusher(
        model=trainer.outputs["model"],
        model_blessing=evaluator.outputs["blessing"],
        push_destination=tfx.proto.PushDestination(
            filesystem=tfx.proto.PushDestination.Filesystem(
                base_directory=str(model_directory)
            )
        ),
    )

    components = [
        example_gen,
        statistics_gen,
        schema_importer,
        example_validator,
        transform,
        trainer,
        model_resolver,
        evaluator,
        pusher
    ]

    return tfx.dsl.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_directory,
        metadata_connection_config=tfx.orchestration.metadata.sqlite_metadata_connection_config(
            metadata_path
        ),
        components=components,
    )
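# running the pipeline locally; execution metadata is recorded in the SQLite
# database at METADATA_PATH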
tfx.orchestration.LocalDagRunner().run(
    create_pipeline(
        pipeline_name=PIPELINE_NAME,
        pipeline_directory=str(PIPELINE_DIRECTORY),
        data_directory=SAVED_DATA,
        schema_path=str(SCHEMA_DIRECTORY),
        model_directory=str(MODEL_DIRECTORY),
        metadata_path=str(METADATA_PATH),
        transform_module_path=TRANSFORM_MODULE_PATH,
        trainer_module_path=TRAINER_MODULE_PATH
    )
)
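# loading the freshly pushed model and running inference on the test set again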
inference_fn = get_inference_fn(MODEL_DIRECTORY)
result = inference_fn(examples=tf.constant(_examples(test_df)))
print(result["output_0"].numpy())