티스토리 뷰

TFX Trainer

TFX Trainer 컴포넌트는 파이프라인의 학습 단계를 처리한다. 이 컴포넌트의 주요 프로세스에 대해서 정리하면 아래와 같다.

  1. run_fn() 함수를 찾아서 학습 프로세스를 실행하는 진입점으로 사용한다. 
  2. run_fn() 함수는 input_fn() 함수를 찾아 사용할 데이터 로딩을 일괄적으로 수행한다. 이 함수는 transform 단계에서 생성한 압축되고 전처리된 데이터셋을 로드할 수 있다.
  3. get_model()을 사용하여 컴파일된 케라스 모델을 가져온다.
  4. fit함수로 학습한다.
  5. 학습된 모델을 SavedModel 형식으로 저장하고 내보낸다.

예제 프로젝트의 run_fn() 함수 코드는 아래와 같다.

def run_fn(fn_args):

    tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)
    # input_fn 을 호출하여 데이터 생성기를 가져옵니다.
    train_dataset = input_fn(fn_args.train_files, tf_transform_output)
    eval_dataset = input_fn(fn_args.eval_files, tf_transform_output)

    # get_model함수를 호출하여 컴파일된 케라스 모델을 가져옵니다.
    model = get_model()
    model.fit(
        train_dataset,
        steps_per_epoch=fn_args.train_steps,
        validation_data=eval_dataset,
        # Trainer 컴포넌트가 통과한 학습 및 평가 단계 수를 사용하여 모델을 학습합니다.
        validation_steps=fn_args.eval_steps)

    # 나중에 설명할 서빙 피처를 포함하는 모델 서명을 정의합니다.
    signatures = {
        'serving_default':
            _get_serve_tf_examples_fn(
                model,
                tf_transform_output).get_concrete_function(
                tf.TensorSpec(
                    shape=[None],
                    dtype=tf.string,
                    name='examples')
            )
    }
    model.save(fn_args.serving_model_dir,
               save_format='tf', signatures=signatures)

컴포넌트 실행 코드는 아래와 같다.

trainer = Trainer(
    module_file=os.path.abspath("trainer.py"),
    examples=transform.outputs["transformed_examples"],
    transform_graph=transform.outputs["transform_graph"],
    schema=schema_gen.outputs["schema"],
    train_args=tfx.proto.TrainArgs(num_steps=100),
    eval_args=tfx.proto.EvalArgs(num_steps=50))
context.run(trainer)

transform의 outputs를 사용하여 변환된 examples와 그래프를 입력으로 사용한다. 모델학습과 내보내기가 완료되면 모델의 경로를 메타데이터스토어에 등록한다.

 

파이프라인 외부에서 SavedModel 사용하기

아래의 코드를 이용하여 모델을 로드하면 SavedModel을 사용할 수 있다.

import tensorflow as tf
model_path = trainer.outputs["model"].get()[0].uri + "/Format-Serving"
model = tf.saved_model.load(export_dir=model_path)
predict_fn = model.signatures["serving_default"]

run_fn()함수에 콜백을 추가하여 학습을 기록하면 tensorboard를 사용하여 시각화하여 분석할 수 있다.

    # tensorboard를 위한 log기록
    log_dir = os.path.join(os.path.dirname(fn_args.serving_model_dir), 'logs')
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, update_freq='batch')

TFX model을 텐서보드로 시각화한 결과 on jupyter notebook

댓글
공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
«   2025/02   »
1
2 3 4 5 6 7 8
9 10 11 12 13 14 15
16 17 18 19 20 21 22
23 24 25 26 27 28
글 보관함