# !pip install -q -U tfx==1.4.0
# !pip install tensorflow==2.7.0
# !pip install tensorflow_decision_forests==0.2.1
# creating required directories
!mkdir -p saved_data
!mkdir -p model
!mkdir -p data
import tempfile
import tensorflow as tf
import urllib.request
import os
import pandas as pd
import shutil
import tensorflow_data_validation as tfdv
import tensorflow_model_analysis as tfma
import tensorflow_decision_forests as tfdf
from absl import logging
from pathlib import Path
from tfx import v1 as tfx
from tensorflow_metadata.proto.v0 import schema_pb2
from tfx.proto import example_gen_pb2
print(f"Tensorflow Version: {tf.__version__}")
print(f"TFX Version: {tfx.__version__}")
print(f"TFDF Version: {tfdf.__version__}")
print(f"Tensorflow Data Validation Version: {tfdv.__version__}")
logging.set_verbosity(logging.INFO)
DATA_DIRECTORY = 'data'
DATA_SOURCE_PATH = Path(DATA_DIRECTORY) / 'Social_Network_Ads.csv'
SAVED_DATA = 'saved_data'
DATA_TRAIN_FILENAME = Path(SAVED_DATA) / 'train.csv'
DATA_TEST_FILENAME = Path(SAVED_DATA) / 'test.csv'
PIPELINE_NAME = 'sample-pipeline'
PIPELINE_DIRECTORY = os.path.join(Path('pipelines'), PIPELINE_NAME)
METADATA_PATH = Path("metadata") / PIPELINE_NAME / "metadata.db"
SCHEMA_DIRECTORY = os.path.join(PIPELINE_DIRECTORY, 'schema')
SCHEMA_FILE_NAME = os.path.join(SCHEMA_DIRECTORY, 'schema.pbtxt')
MODEL_DIRECTORY = Path('model')
# Module Paths
CONSTANTS_MODULE_PATH = 'constants.py'
TRANSFORM_MODULE_PATH = 'transform.py'
TRAINER_MODULE_PATH = 'trainer.py'
# the Social Network Ads dataset is expected at data/Social_Network_Ads.csv
data_df = pd.read_csv(DATA_SOURCE_PATH)
# shuffling, then splitting the data 70/30 for training and testing
data_df = data_df.sample(frac=1).reset_index(drop=True)
# dropping the unneeded `User ID` column before splitting, so we never
# mutate a slice of the original frame (avoids SettingWithCopyWarning)
data_df = data_df.drop(columns=['User ID'])
train_df = data_df[: int(len(data_df) * 0.7)]
test_df = data_df[int(len(data_df) * 0.7):]
train_df.to_csv(DATA_TRAIN_FILENAME, index=False)
test_df.to_csv(DATA_TEST_FILENAME, index=False)
# peeking at the data
train_df.info()
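# optional sanity check (not part of the pipeline): the label distribution
# of the training split should be reasonably balanced
print(train_df['Purchased'].value_counts(normalize=True))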
# creating the useful conversion functions
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[str(value).encode("utf-8")]))
def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
# This function will be required to convert our test set
# to schema-compatible serialized examples for inference
def _examples(df):
    examples = []
    for index, row in df.iterrows():
        features = {
            "Gender": _bytes_feature(row['Gender']),
            "Age": _int64_feature(row['Age']),
            "EstimatedSalary": _int64_feature(row['EstimatedSalary']),
        }
        example_proto = tf.train.Example(features=tf.train.Features(feature=features))
        examples.append(example_proto.SerializeToString())
    return examples
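# quick round-trip check of the helpers above (a minimal sketch; the feature
# spec mirrors the serving-time schema, i.e. everything except the label)
sample_serialized = _examples(test_df.head(1))[0]
parsed = tf.io.parse_single_example(
    sample_serialized,
    {
        "Gender": tf.io.FixedLenFeature([], tf.string),
        "Age": tf.io.FixedLenFeature([], tf.int64),
        "EstimatedSalary": tf.io.FixedLenFeature([], tf.int64),
    },
)
print(parsed)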
from tfx.orchestration.experimental.interactive.interactive_context import (InteractiveContext)
context = InteractiveContext()
output = tfx.proto.Output(
    split_config=tfx.proto.SplitConfig(splits=[
        tfx.proto.SplitConfig.Split(name="train", hash_buckets=3),
        tfx.proto.SplitConfig.Split(name="eval", hash_buckets=1),
    ])
)
example_gen = tfx.components.CsvExampleGen(input_base=SAVED_DATA, output_config=output)
context.run(example_gen)
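# optional: peek at where CsvExampleGen materialized the splits; the artifact
# directory should contain 'Split-train' and 'Split-eval' subdirectories
examples_uri = example_gen.outputs['examples'].get()[0].uri
print(os.listdir(examples_uri))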
statistics_gen = tfx.components.StatisticsGen(
    examples=example_gen.outputs['examples']
)
context.run(statistics_gen)
context.show(statistics_gen.outputs['statistics'])
schema_gen = tfx.components.SchemaGen(
    statistics=statistics_gen.outputs['statistics'],
    infer_feature_shape=True
)
context.run(schema_gen)
context.show(schema_gen.outputs['schema'])
schema = tfdv.load_schema_text(
    os.path.join(schema_gen.outputs['schema'].get()[0].uri, "schema.pbtxt")
)
# adding the needed environments
schema.default_environment.append("TRAINING")
schema.default_environment.append("SERVING")
# removing the `Purchased` column from the serving environment
tfdv.get_feature(schema, "Purchased").not_in_environment.append("SERVING")
tfdv.display_schema(schema=schema)
!mkdir -p {SCHEMA_DIRECTORY}
tfdv.write_schema_text(schema, SCHEMA_FILE_NAME)
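# optional sanity check: the schema should survive the write/read round trip
# unchanged (protobuf messages support direct equality comparison)
assert tfdv.load_schema_text(SCHEMA_FILE_NAME) == schema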
# loading the updated schema using the importer node.
schema_importer = tfx.dsl.Importer(
    source_uri=str(SCHEMA_DIRECTORY),
    artifact_type=tfx.types.standard_artifacts.Schema
).with_id("schema_importer")
context.run(schema_importer)
context.show(schema_importer.outputs['result'])
example_validator = tfx.components.ExampleValidator(
    statistics=statistics_gen.outputs['statistics'],
    schema=schema_importer.outputs['result']
)
context.run(example_validator)
context.show(example_validator.outputs['anomalies'])
%%writefile {CONSTANTS_MODULE_PATH}
LABEL = 'Purchased'
%%writefile {TRANSFORM_MODULE_PATH}
import tensorflow as tf
import tensorflow_transform as tft
import constants
LABEL = constants.LABEL
def preprocessing_fn(inputs):
    outputs = dict()
    outputs['Age'] = inputs['Age']
    outputs['EstimatedSalary'] = inputs['EstimatedSalary']
    # converting the `Gender` column into a label-encoded column; the
    # string domain values in the schema are 'Male' and 'Female'
    outputs['Gender'] = tf.cast(tf.equal(inputs['Gender'], 'Male'), tf.int64)
    outputs[LABEL] = inputs[LABEL]
    return outputs
transform = tfx.components.Transform(
    examples=example_gen.outputs['examples'],
    schema=schema_importer.outputs['result'],
    module_file=os.path.abspath(TRANSFORM_MODULE_PATH),
)
context.run(transform, enable_cache=False)
train_uri = os.path.join(
    transform.outputs['transformed_examples'].get()[0].uri,
    'Split-train'
)
tfrecord_filenames = [
                      os.path.join(train_uri, name) for name in os.listdir(train_uri)
]
dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type='GZIP')
for tfrecord in dataset.take(3):
    serialized_example = tfrecord.numpy()
    example = tf.train.Example()
    example.ParseFromString(serialized_example)
    print(example)
%%writefile {TRAINER_MODULE_PATH}
import tensorflow as tf
import tensorflow_decision_forests as tfdf
import tensorflow_transform as tft
from absl import logging
from tensorflow.keras import layers, Model, optimizers, losses, metrics
from tfx import v1 as tfx
from tfx_bsl.public import tfxio
from typing import List, Text
import constants
LABEL = constants.LABEL
BATCH_SIZE = 32
EPOCHS = 50
def _input_fn(
    file_pattern: List[Text],
    data_accessor: tfx.components.DataAccessor,
    tf_transform_output: tft.TFTransformOutput,
    batch_size: int,
) -> tf.data.Dataset:
    """
    Generates a dataset of features that can be used to train
    and evaluate the model.
    Args:
        file_pattern: List of paths or patterns of input data files.
        data_accessor: An instance of DataAccessor that we can use to
            convert the input to a RecordBatch.
        tf_transform_output: The transformation output.
        batch_size: The number of consecutive elements that we should
            combine in a single batch.
    Returns:
        A dataset that contains a tuple of (features, labels), where
            features is a dictionary of Tensors and labels is a single
            Tensor of label values.
    """
    dataset = data_accessor.tf_dataset_factory(
        file_pattern,
        tfxio.TensorFlowDatasetOptions(batch_size=batch_size),
        schema=tf_transform_output.raw_metadata.schema,
    )
    tft_layer = tf_transform_output.transform_features_layer()
    def apply_transform(raw_features):
        transformed_features = tft_layer(raw_features)
        transformed_label = transformed_features.pop(LABEL)
        return transformed_features, transformed_label
    
    return dataset.map(apply_transform).repeat()
def _get_serve_tf_examples_fn(model, tf_transform_output):
    """
    Returns a function that parses a serialized tf.Example and applies
    the transformations during inference.
    Args:
        model: The model that we are serving.
        tf_transform_output: The transformation output that we want to 
            include with the model.
    """
    
    model.tft_layer = tf_transform_output.transform_features_layer()

    # the serving signature must be a tf.function with an explicit input
    # signature so that it can be attached to the SavedModel
    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None], dtype=tf.string, name="examples")
    ])
    def serve_tf_examples_fn(serialized_tf_examples):
        feature_spec = tf_transform_output.raw_feature_spec()
        
        required_feature_spec = {
            k: v for k, v in feature_spec.items() if k != LABEL
        }
        parsed_features = tf.io.parse_example(
            serialized_tf_examples,
            required_feature_spec
        )
        transformed_features = model.tft_layer(parsed_features)
        return model(transformed_features)
    return serve_tf_examples_fn
def _model() -> tf.keras.Model:
    inputs = [
              layers.Input(shape=(1,), name="Age"),
              layers.Input(shape=(1,), name="EstimatedSalary"),
              layers.Input(shape=(1,), name="Gender")
    ]
    x = layers.concatenate(inputs)
    x = layers.Dense(8, activation="relu")(x)
    x = layers.Dense(8, activation="relu")(x)
    outputs = layers.Dense(1, activation="sigmoid")(x)
    model = Model(inputs=inputs, outputs=outputs)
    model.compile(
        optimizer=optimizers.Adam(1e-2),
        loss="binary_crossentropy",
        metrics=[metrics.BinaryAccuracy()],
    )
    model.summary(print_fn=logging.info)
    return model
def run_fn(fn_args: tfx.components.FnArgs):
    """
    The callback function that will be called by the Trainer component
    to train the model using the supplied arguments.
    Args:
        fn_args: A collection of name/value pairs representing the 
            arguments to train the model.
    """
    tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)
    train_dataset = _input_fn(
        fn_args.train_files,
        fn_args.data_accessor,
        tf_transform_output,
        batch_size=BATCH_SIZE,
    )
    eval_dataset = _input_fn(
        fn_args.eval_files,
        fn_args.data_accessor,
        tf_transform_output,
        batch_size=BATCH_SIZE,
    )
    model = _model()
    model.fit(
        train_dataset,
        steps_per_epoch=fn_args.train_steps,
        validation_data=eval_dataset,
        validation_steps=fn_args.eval_steps,
        epochs=EPOCHS
    )
    # We need to modify the default signature to include the transform layer in 
    # the computational graph.
    signatures = {
        "serving_default": _get_serve_tf_examples_fn(model, tf_transform_output),
    }
    model.save(fn_args.serving_model_dir, save_format="tf", signatures=signatures)
trainer = tfx.components.Trainer(
    examples=example_gen.outputs["examples"],
    transform_graph=transform.outputs["transform_graph"],
    train_args=tfx.proto.TrainArgs(num_steps=100),
    eval_args=tfx.proto.EvalArgs(num_steps=5),
    module_file=os.path.abspath(TRAINER_MODULE_PATH),
)
context.run(trainer, enable_cache=False)
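# optional: the Trainer writes the SavedModel under the model artifact's
# 'Format-Serving' subdirectory (a TFX convention for Keras models)
trained_model_uri = trainer.outputs['model'].get()[0].uri
print(os.listdir(trained_model_uri))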
eval_config = tfma.EvalConfig(
    model_specs=[
                 tfma.ModelSpec(
                     signature_name="serving_default",
                     preprocessing_function_names=['tft_layer'],
                     label_key="Purchased",
                 )
    ],
    metrics_specs = [
                     tfma.MetricsSpec(
                         per_slice_thresholds={
                             "binary_accuracy": tfma.PerSliceMetricThresholds(
                                 thresholds=[
                                             tfma.PerSliceMetricThreshold(
                                                 slicing_specs=[tfma.SlicingSpec()],
                                                 threshold=tfma.MetricThreshold(
                                                     value_threshold=tfma.GenericValueThreshold(
                                                         lower_bound={"value":0.7}
                                                     ),
                                                     change_threshold=tfma.GenericChangeThreshold(
                                                         direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                                                         absolute={"value": -1e-10},
                                                     ),
                                                 ),
                                             )
                                 ]
                             ),
                         }
                     )
    ],
    slicing_specs=[
                   tfma.SlicingSpec(),
                   tfma.SlicingSpec(feature_keys=["Gender"])
    ]
)
model_resolver = tfx.dsl.Resolver(
    strategy_class=tfx.dsl.experimental.LatestBlessedModelStrategy,
    model=tfx.dsl.Channel(type=tfx.types.standard_artifacts.Model),
    model_blessing=tfx.dsl.Channel(type=tfx.types.standard_artifacts.ModelBlessing),
).with_id("latest_blessed_model_resolver")
context.run(model_resolver)
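# Note: on the very first run there is no previously blessed model, so the
# resolver returns an empty channel and the Evaluator only applies the
# absolute value_threshold (the change_threshold needs a baseline model).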
evaluator = tfx.components.Evaluator(
    examples=example_gen.outputs["examples"],
    model=trainer.outputs["model"],
    eval_config=eval_config,
    baseline_model=model_resolver.outputs["model"],
)
context.run(evaluator, enable_cache=False)
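# optional: load the evaluation results to see the sliced metrics and check
# whether the candidate cleared the thresholds
eval_result = tfma.load_eval_result(evaluator.outputs['evaluation'].get()[0].uri)
print(eval_result.slicing_metrics)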
pusher = tfx.components.Pusher(
    model=trainer.outputs["model"],
    model_blessing=evaluator.outputs["blessing"],
    push_destination=tfx.proto.PushDestination(
        filesystem=tfx.proto.PushDestination.Filesystem(
            base_directory=str(MODEL_DIRECTORY)
        )
    ),
)
context.run(pusher)
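# the Pusher copies the blessed model into MODEL_DIRECTORY under a
# timestamped subdirectory; a quick look at the push destination
print(os.listdir(MODEL_DIRECTORY))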
def get_inference_fn(model_directory):
    # the Pusher exports each blessed model under a timestamped subdirectory,
    # so the most recent push is the directory with the largest integer name
    model_directories = (d for d in os.scandir(model_directory) if d.is_dir())
    model_path = max(model_directories, key=lambda i: int(i.name)).path
    loaded_model = tf.keras.models.load_model(model_path)
    return loaded_model.signatures["serving_default"]
inference_fn = get_inference_fn(MODEL_DIRECTORY)
result = inference_fn(examples=tf.constant(_examples(test_df)))
print(result["output_0"].numpy())
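# quick sanity check (not part of the pipeline): thresholding the sigmoid
# outputs at 0.5 and comparing them against the held-out labels
predictions = (result["output_0"].numpy() > 0.5).astype(int).flatten()
accuracy = (predictions == test_df['Purchased'].values).mean()
print(f"Held-out accuracy: {accuracy:.3f}")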
%%writefile {SCHEMA_FILE_NAME}
feature {
  name: "Gender"
  type: BYTES
  domain: "Gender"
  presence {
    min_fraction: 1.0
    min_count: 1
  }
  shape {
    dim {
      size: 1
    }
  }
}
feature {
  name: "Age"
  type: INT
  presence {
    min_fraction: 1.0
    min_count: 1
  }
  shape {
    dim {
      size: 1
    }
  }
}
feature {
  name: "EstimatedSalary"
  type: INT
  presence {
    min_fraction: 1.0
    min_count: 1
  }
  shape {
    dim {
      size: 1
    }
  }
}
feature {
  name: "Purchased"
  type: INT
  bool_domain {
  }
  presence {
    min_fraction: 1.0
    min_count: 1
  }
  not_in_environment: "SERVING"
  shape {
    dim {
      size: 1
    }
  }
}
string_domain {
  name: "Gender"
  value: "Female"
  value: "Male"
}
default_environment: "TRAINING"
default_environment: "SERVING"
def create_pipeline(
    pipeline_name: str,
    pipeline_directory: str,
    data_directory: str,
    schema_path: str,
    model_directory: str,
    metadata_path: str,
    transform_module_path: str,
    trainer_module_path: str,
) -> tfx.dsl.Pipeline:
    output = tfx.proto.Output(
        split_config=tfx.proto.SplitConfig(splits=[
            tfx.proto.SplitConfig.Split(name="train", hash_buckets=3),
            tfx.proto.SplitConfig.Split(name="eval", hash_buckets=1),
        ])
    )
    example_gen = tfx.components.CsvExampleGen(input_base=data_directory, output_config=output)
    statistics_gen = tfx.components.StatisticsGen(
        examples=example_gen.outputs['examples']
    )
    schema_importer = tfx.dsl.Importer(
        source_uri=schema_path,
        artifact_type=tfx.types.standard_artifacts.Schema
    ).with_id("schema_importer")
    example_validator = tfx.components.ExampleValidator(
        statistics=statistics_gen.outputs['statistics'],
        schema=schema_importer.outputs['result']
    )
    transform = tfx.components.Transform(
        examples=example_gen.outputs['examples'],
        schema=schema_importer.outputs['result'],
        module_file=os.path.abspath(transform_module_path),
    )
    trainer = tfx.components.Trainer(
        examples=example_gen.outputs["examples"],
        transform_graph=transform.outputs["transform_graph"],
        train_args=tfx.proto.TrainArgs(num_steps=100),
        eval_args=tfx.proto.EvalArgs(num_steps=5),
        module_file=os.path.abspath(trainer_module_path),
    )
    eval_config = tfma.EvalConfig(
        model_specs=[
                    tfma.ModelSpec(
                        signature_name="serving_default",
                        preprocessing_function_names=['tft_layer'],
                        label_key="Purchased",
                    )
        ],
        metrics_specs = [
                        tfma.MetricsSpec(
                            per_slice_thresholds={
                                "binary_accuracy": tfma.PerSliceMetricThresholds(
                                    thresholds=[
                                                tfma.PerSliceMetricThreshold(
                                                    slicing_specs=[tfma.SlicingSpec()],
                                                    threshold=tfma.MetricThreshold(
                                                        value_threshold=tfma.GenericValueThreshold(
                                                            lower_bound={"value":0.7}
                                                        ),
                                                        change_threshold=tfma.GenericChangeThreshold(
                                                            direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                                                            absolute={"value": -1e-10},
                                                        ),
                                                    ),
                                                )
                                    ]
                                ),
                            }
                        )
        ],
        slicing_specs=[
                    tfma.SlicingSpec(),
                    tfma.SlicingSpec(feature_keys=["Gender"])
        ]
    )
    model_resolver = tfx.dsl.Resolver(
        strategy_class=tfx.dsl.experimental.LatestBlessedModelStrategy,
        model=tfx.dsl.Channel(type=tfx.types.standard_artifacts.Model),
        model_blessing=tfx.dsl.Channel(type=tfx.types.standard_artifacts.ModelBlessing),
    ).with_id("latest_blessed_model_resolver")
    evaluator = tfx.components.Evaluator(
        examples=example_gen.outputs["examples"],
        model=trainer.outputs["model"],
        eval_config=eval_config,
        baseline_model=model_resolver.outputs["model"],
    )
    pusher = tfx.components.Pusher(
        model=trainer.outputs["model"],
        model_blessing=evaluator.outputs["blessing"],
        push_destination=tfx.proto.PushDestination(
            filesystem=tfx.proto.PushDestination.Filesystem(
                base_directory=model_directory
            )
        ),
    )
    components = [
                  example_gen,
                  statistics_gen,
                  schema_importer,
                  example_validator,
                  transform,
                  trainer,
                  model_resolver,
                  evaluator,
                  pusher
    ]
    return tfx.dsl.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_directory,
        metadata_connection_config=tfx.orchestration.metadata.sqlite_metadata_connection_config(
            metadata_path
        ),
        components=components,
    )
tfx.orchestration.LocalDagRunner().run(
    create_pipeline(
        pipeline_name=PIPELINE_NAME,
        pipeline_directory=str(PIPELINE_DIRECTORY),
        data_directory=SAVED_DATA,  # the pre-split CSVs written earlier, not the raw data
        schema_path=str(SCHEMA_DIRECTORY),
        model_directory=str(MODEL_DIRECTORY),
        metadata_path=str(METADATA_PATH),
        transform_module_path=TRANSFORM_MODULE_PATH,
        trainer_module_path=TRAINER_MODULE_PATH
    )
)
inference_fn = get_inference_fn(MODEL_DIRECTORY)
result = inference_fn(examples=tf.constant(_examples(test_df)))
print(result["output_0"].numpy())