인공지능/Tensorflow Extended

[살아 움직이는 머신러닝 파이프라인 설계] 6. 모델 학습

TFX에서 학습할 모델 정의하기

TFX에서 텐서플로우로 직접 만든 모델이나 케라스 모델을 모델 파이프라인에서 학습시킬수 있다. 이때 케라스 모델의 입력 레이어는 name에 전처리된 특성의 이름이 들어가야 자동으로 해당하는 입력 레이어에 맞춰 들어간다.

Trainer 컴포넌트에게 필요한 입력

머신러닝 파이프라인에서 모델 학습을 진행할 때 Trainer 컴포넌트를 추가한다. 이때 Trainer에 5개의 입력을 집어넣어 정의한다.

  • 데이터 스키마
  • 전처리된 데이터
  • 전처리 그래프
  • 학습 하이퍼 파라미터
  • 학습 방식을 정의한 run_fn() 이 저장된 소스코드 파일

run_fn()

run_fn()은 TFT의 preprocessing_fn()처럼 사용자가 직접 정의하는 함수다. run_fn()Trainer가 모델을 학습시킬때 실행되는 함수로 다음과 같은 과정이 포함되어야 한다.

  • 학습 및 검증 데이터셋 불러오기
  • 모델 정의 및 컴파일
  • 학습
  • 학습된 모델 반환

run_fn()의 파라미터는 파이썬 딕셔너리 단 하나만 받는다. 이 파라미터에서 전처리된 데이터셋이나 하이퍼 파라미터를 가져와야 한다. 다만 모델은 run_fn()에서 생성하고 학습까지하기 때문에 파라미터에서는 가져오지 않는다.
학습 루프 자체를 run_fn()에서 직접 실행하기 때문에 텐서플로우가 아닌 파이토치나 scikit-learn 등과 같은 다른 머신러닝 라이브러리를 사용할 수 있다.

머신러닝 파이프라인에서 하이퍼 파라미터 튜닝하기

머신러닝 모델이 학습 성능과 예측 성능 모두 뛰어나려면 적합한 하이퍼 파라미터를 사용해야한다. 그렇기 때문에 여러 하이퍼 파라미터 값들을 실험하면서 어떤 값이 최고의 성능이 나오는지 실험한다. 이를 하이퍼 파라미터 튜닝이라 부른다.
하이퍼 파라미터 튜닝에는 대표적으로 두가지 전략이 있다. 첫번째 전략인 그리드 검색(Grid Search) 은 정해진 간격으로 하이퍼 파라미터 값을 정해 실험하는 방식을 말하고, 두번째 전략인 랜덤 검색(Random Search) 은 말 그대로 하이퍼 파라미터 값을 일정 범위 내에서 샘플을 뽑아 실험하는 방식을 말한다.
책이 쓰여질 당시에는 하이퍼 파라미터 튜닝을 하는 컴포넌트인 Tunner가 막 출시된 참이기 때문에 사용법에 대한 내용은 나오지 않았다. 대신 학습할 때 여러개의 모델을 학습한 후 그 중 가장 효과가 좋은 하이퍼 파라미터 조합을 골라 다시 학습하는 방법을 안내한다.