티스토리 뷰
TFX Trainer
TFX Trainer 컴포넌트는 파이프라인의 학습 단계를 처리한다. 이 컴포넌트의 주요 프로세스에 대해서 정리하면 아래와 같다.
- run_fn() 함수를 찾아서 학습 프로세스를 실행하는 진입점으로 사용한다.
- run_fn() 함수는 input_fn() 함수를 찾아 사용할 데이터 로딩을 일괄적으로 수행한다. 이 함수는 transform 단계에서 생성한 압축되고 전처리된 데이터셋을 로드할 수 있다.
- get_model()을 사용하여 컴파일된 케라스 모델을 가져온다.
- fit함수로 학습한다.
- 학습된 모델을 SavedModel 형식으로 저장하고 내보낸다.
예제 프로젝트의 run_fn() 함수 코드는 아래와 같다.
def run_fn(fn_args):
tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)
# input_fn 을 호출하여 데이터 생성기를 가져옵니다.
train_dataset = input_fn(fn_args.train_files, tf_transform_output)
eval_dataset = input_fn(fn_args.eval_files, tf_transform_output)
# get_model함수를 호출하여 컴파일된 케라스 모델을 가져옵니다.
model = get_model()
model.fit(
train_dataset,
steps_per_epoch=fn_args.train_steps,
validation_data=eval_dataset,
# Trainer 컴포넌트가 통과한 학습 및 평가 단계 수를 사용하여 모델을 학습합니다.
validation_steps=fn_args.eval_steps)
# 나중에 설명할 서빙 피처를 포함하는 모델 서명을 정의합니다.
signatures = {
'serving_default':
_get_serve_tf_examples_fn(
model,
tf_transform_output).get_concrete_function(
tf.TensorSpec(
shape=[None],
dtype=tf.string,
name='examples')
)
}
model.save(fn_args.serving_model_dir,
save_format='tf', signatures=signatures)
컴포넌트 실행 코드는 아래와 같다.
trainer = Trainer(
module_file=os.path.abspath("trainer.py"),
examples=transform.outputs["transformed_examples"],
transform_graph=transform.outputs["transform_graph"],
schema=schema_gen.outputs["schema"],
train_args=tfx.proto.TrainArgs(num_steps=100),
eval_args=tfx.proto.EvalArgs(num_steps=50))
context.run(trainer)
transform의 outputs를 사용하여 변환된 examples와 그래프를 입력으로 사용한다. 모델학습과 내보내기가 완료되면 모델의 경로를 메타데이터스토어에 등록한다.
파이프라인 외부에서 SavedModel 사용하기
아래의 코드를 이용하여 모델을 로드하면 SavedModel을 사용할 수 있다.
import tensorflow as tf
model_path = trainer.outputs["model"].get()[0].uri + "/Format-Serving"
model = tf.saved_model.load(export_dir=model_path)
predict_fn = model.signatures["serving_default"]
run_fn()함수에 콜백을 추가하여 학습을 기록하면 tensorboard를 사용하여 시각화하여 분석할 수 있다.
# tensorboard를 위한 log기록
log_dir = os.path.join(os.path.dirname(fn_args.serving_model_dir), 'logs')
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, update_freq='batch')
'Study > MLOps' 카테고리의 다른 글
도커와 쿠버네티스 (4) - 쿠버네티스 with minikube (0) | 2021.12.08 |
---|---|
도커와 쿠버네티스 (3) - 도커 이미지 (0) | 2021.12.01 |
도커와 쿠버네티스 (2) - 도커 설치 (0) | 2021.11.26 |
도커와 쿠버네티스 (1) - 개념과 실습환경 (0) | 2021.11.25 |
Building ML Pipelines 따라잡기 (5) - 데이터 전처리 (0) | 2021.11.24 |
댓글
공지사항
최근에 올라온 글
최근에 달린 댓글
- Total
- Today
- Yesterday
링크
TAG
- 머신러닝파이프라인
- torch
- ML
- DDUX
- productowner
- productmanager
- 딥러닝
- nlp
- 쿠버네티스
- mlpipeline
- MLOps
- 전처리
- container
- Tennis
- PM
- deeplearning
- productresearch
- 도커
- Kubernetes
- 자연어처리
- 인공지능
- Bert
- docker
- pmpo
- 파이프라인
- dl
- 머신러닝
- Oreilly
- 스타트업
- PO
일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | ||||||
2 | 3 | 4 | 5 | 6 | 7 | 8 |
9 | 10 | 11 | 12 | 13 | 14 | 15 |
16 | 17 | 18 | 19 | 20 | 21 | 22 |
23 | 24 | 25 | 26 | 27 | 28 |
글 보관함