# In [3]:
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow_datasets as tfds
import tensorflow as tf
import os

def scale(image, label):
    """Cast ``image`` to float32 and divide it by 255; return ``label`` untouched.

    Used as the ``Dataset.map`` callback in ``train``. Assumes 8-bit pixel
    values, so the output presumably lands in [0, 1] -- confirm for other data.
    """
    normalized = tf.cast(image, tf.float32) / 255
    return normalized, label

def build_and_compile_cnn_model():
    """Return a small compiled Keras CNN for 28x28x1 inputs and 10 classes.

    Layers: Conv2D(32 filters, 3x3, ReLU) -> MaxPooling2D -> Flatten ->
    Dense(64, ReLU) -> Dense(10, softmax). Compiled with sparse categorical
    cross-entropy (integer class labels), plain SGD at learning rate 0.001,
    and an accuracy metric. ``train`` calls this inside ``strategy.scope()``
    so the model's variables are created under the distribution strategy.
    """
    layers = tf.keras.layers
    stack = [
        layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        layers.MaxPooling2D(),
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        layers.Dense(10, activation='softmax'),
    ]
    model = tf.keras.Sequential(stack)

    optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)
    model.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.sparse_categorical_crossentropy,
        metrics=['accuracy'],
    )
    return model

def train(buffer_size=None, global_batch_size=None,
          data_dir='/mnt/tensorflow_datasets', epochs=10):
    """Train the MNIST CNN with ``MultiWorkerMirroredStrategy``.

    Intended to run inside each TFJob pod (the ``FAIRING_RUNTIME`` branch of the
    script entry point). Loads MNIST from an existing TFDS directory without
    downloading, then scales, caches, shuffles and batches the train split, and
    fits the model for ``epochs`` epochs.

    Args:
        buffer_size: Shuffle buffer size. Defaults to the module-level
            ``BUFFER_SIZE`` when the entry point defined it, else 10000.
        global_batch_size: Batch size summed over all replicas. Defaults to the
            module-level ``GLOBAL_BATCH_SIZE`` when defined, else
            ``100 * strategy.num_replicas_in_sync``.
        data_dir: TFDS data directory (expected to be on the mounted PVC).
        epochs: Number of training epochs.

    Returns:
        The ``History`` object returned by ``Model.fit``.
    """
    # Create the strategy before doing any other TensorFlow work; the TF
    # multi-worker guide asks for MultiWorkerMirroredStrategy to be built first.
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

    # BUFFER_SIZE / GLOBAL_BATCH_SIZE are only defined by the script entry
    # point, so fall back to defaults instead of raising NameError when
    # train() is imported and called directly.
    module_globals = globals()
    if buffer_size is None:
        buffer_size = module_globals.get('BUFFER_SIZE', 10000)
    if global_batch_size is None:
        global_batch_size = module_globals.get(
            'GLOBAL_BATCH_SIZE', 100 * strategy.num_replicas_in_sync)

    tfds.disable_progress_bar()
    datasets, _ = tfds.load(name='mnist', with_info=True, as_supervised=True,
                            data_dir=data_dir, download=False)
    train_dataset = (datasets['train']
                     .map(scale)
                     .cache()
                     .shuffle(buffer_size)
                     .batch(global_batch_size))

    # Auto-shard by element rather than by file. This is presumably because the
    # TFDS MNIST train split has too few files to give every worker its own
    # file -- confirm.
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = (
        tf.data.experimental.AutoShardPolicy.DATA)
    train_dataset = train_dataset.with_options(options)

    # The model's variables must be created inside the strategy scope so they
    # are mirrored across replicas.
    with strategy.scope():
        multi_worker_model = build_and_compile_cnn_model()

    history = multi_worker_model.fit(x=train_dataset, epochs=epochs)

    # TODO: persist the model, e.g.
    # multi_worker_model.save('/mnt/jupyter/saved_model')
    # (under MWMS, non-chief workers typically need to save to a temp path).
    return history

    # multi_worker_model.save('/mnt/jupyter/saved_model')


def fairing_run(num_workers=None, gpus_per_worker=None):
    """Build an image from this script and submit it as a Kubeflow TFJob.

    Runs in the launcher environment (``FAIRING_RUNTIME`` unset). It configures
    Fairing's 'append' builder against ``docker_registry`` and
    ``base_image``, then configures the 'tfjob' deployer with one chief and
    ``num_workers`` workers. Two pod-spec mutators are attached: one mounts
    the ``<project>-pvc`` PVC at /mnt, and one requests ``gpus_per_worker``
    NVIDIA GPUs (presumably applied to every pod -- confirm).

    Args:
        num_workers: Worker replicas in the TFJob. Defaults to the module-level
            ``num_workers`` set by the script entry point.
        gpus_per_worker: GPUs requested per pod. Defaults to the module-level
            ``gpus_per_worker`` set by the script entry point.

    Raises:
        ValueError: If a count is neither passed in nor defined at module level.
    """
    # The counts used to be read as bare globals that exist only under
    # __main__; resolve them explicitly so the function is importable.
    module_globals = globals()
    if num_workers is None:
        num_workers = module_globals.get('num_workers')
    if gpus_per_worker is None:
        gpus_per_worker = module_globals.get('gpus_per_worker')
    if num_workers is None or gpus_per_worker is None:
        raise ValueError('num_workers and gpus_per_worker must be passed in '
                         'or defined at module level')

    # Imported lazily: Fairing is only needed in the launcher environment.
    from kubeflow import fairing
    from kubeflow.fairing.kubernetes.utils import volume_mounts
    from kubeflow.fairing.kubernetes.utils import get_resource_mutator
    import uuid

    # WARNING: Kubeflow Fairing does not support docker registries that use a
    #          self-signed TLS certificate or certificate chaining, nor insecure
    #          (plaintext HTTP) registries.
    # NOTE(review): if the build fails with a JSONDecodeError, check the
    #               registry against this limitation first -- unconfirmed.
    project_name = 'mnist-mwms'

    docker_registry = 'repo.chelsea.kt.co.kr'
    # Custom image = tensorflow/tensorflow:2.5.0-gpu + `pip install tensorflow_datasets`.
    base_image = 'repo.chelsea.kt.co.kr/agp/tensorflow-custom:2.5.0-gpu'
    image_name = 'agp/' + project_name

    # A random 4-hex-digit suffix makes job-name collisions between
    # submissions unlikely.
    tfjob_name = f'{project_name}-training-{uuid.uuid4().hex[:4]}'
    num_chief = 1  # number of Chief replicas in the TFJob

    k8s_pvc_name = f'{project_name}-pvc'
    mount_name = "/mnt"  # train() reads TFDS data from /mnt/tensorflow_datasets

    fairing.config.set_builder(name='append',
                               registry=docker_registry,
                               base_image=base_image,
                               image_name=image_name)
    fairing.config.set_deployer(
        name='tfjob',
        job_name=tfjob_name,
        chief_count=num_chief,
        worker_count=num_workers,
        pod_spec_mutators=[
            volume_mounts(volume_type='pvc', volume_name=k8s_pvc_name,
                          mount_path=mount_name),
            get_resource_mutator(gpu=gpus_per_worker, gpu_vendor='nvidia'),
        ],
        stream_log=False,
    )
    fairing.config.run()
    
if __name__ == '__main__':
    # Replica layout shared by the launcher branch and the in-pod training run.
    num_workers = 2      # Worker replicas in the TFJob
    gpus_per_worker = 1  # GPUs requested per Worker

    # FAIRING_RUNTIME is presumably set by Fairing inside the built image.
    # When it is present we are inside a TFJob pod and should train;
    # otherwise we are the launcher and should submit the job.
    if 'FAIRING_RUNTIME' in os.environ:
        BUFFER_SIZE = 10000
        # NOTE(review): the TFJob also runs one chief, which presumably trains
        # as an extra replica under MultiWorkerMirroredStrategy; confirm whether
        # the global batch size should count it.
        GLOBAL_BATCH_SIZE = 100 * num_workers * gpus_per_worker
        train()
    else:
        fairing_run()

# Cell output: JSONDecodeError: Expecting value: line 1 column 1 (char 0)