tensorflow :: 텐서플로우 Saver (v1 -> v2 코드 업그레이드 7)
Saver
텐서플로우의 Saver 모듈은 모델과 파라미터를 저장하고 불러오는 기능을 제공한다.
모델을 학습시키고 나면 학습된 모델을 토대로 기능을 수행해야하므로 저장이 필요하다.
그리고 보통 머신러닝 학습을 진행하면 시간이 오래 소요되는데,
Saver 의 save 함수를 이용하면 학습 중간중간 특정 시점마다 모델 저장 또한 가능하다.
(학습 시간이 마냥 길어진다고 모델의 성능이 좋아지는 것은 아니기에
이렇게 중간 저장해둔 모델이 유용하게 쓰일 수 있다.)
모델이 저장되면 checkpoint 파일들과 meta graph 파일(.meta 확장자) 이 남게 되는데
checkpoint 파일에는 weights, biases, grandients 등의 정보가 저장되며
meta graph 파일에는 말 그대로 모델 학습에 사용된 연산 그래프 정보가 저장된다. (variables, operations 등)
이렇게 저장된 모델은 다시 불러올 수 있는데,
이 때 위에서 언급한 meta graph 파일은 tf.train.import_meta_graph 함수를 이용해서 불러오고
checkpoint 파일은 tf.train.Saver.restore 함수를 이용하여 불러올 수 있다.
모델을 저장할 때는 다음과 같이 사용한다.
tf.train.Saver().save(session, checkpoint)
# AttributeError: module 'tensorflow._api.v2.train' has no attribute 'Saver'
그런데 텐서플로우 2.0 버전 이상에서 실행하면 AttributeError 가 발생한다.
모델을 불러올 때도 마찬가지이다.
meta graph 파일을 import 한 뒤, checkpoint 파일을 restore 해오게 되는데
tf.train.import_meta_graph(checkpoint + '.meta').restore(session, checkpoint)
# AttributeError: module 'tensorflow._api.v2.train' has no attribute 'import_meta_graph'
마찬가지로 AttirubteError 가 발생한다.
텐서플로우 v2.0 에서도 모델 저장은 필수일텐데,
어떻게 버전업 코드에 맞게 변환하여 사용할 수 있을까?
1. tf.saved_model, tf.train.Checkpoint 모듈 활용
tf.saved_model 의 save, load 함수를 사용하면 모델 저장과 불러오기가 가능하다.
tf.saved_model.save(model, '/tmp/adder')
restored = tf.saved_model.load('/tmp/adder')
checkpoint 파일을 저장하고 관리하기 위해서는 다음과 같이 tf.train.Checkpoint 를 사용하면 된다.
checkpoint = tf.train.Checkpoint(model)
manager = tf.train.CheckpointManager(checkpoint)
status = checkpoint.restore(manager.latest_checkpoint)
v2.0 부터 meta graph 는 생성/관리가 잘 지원되지 않는 듯했다.
이 경우 2번의 방법을 이용해서 기존의 코드를 그대로 사용해야 할 것 같다.
2. v1.0 코드 그대로 사용하기: compat 모듈 사용
지난 포스팅들에도 계속 언급했듯 compat 모듈을 사용하여 v2.0 기능을 끄도록 설정하고
기존 v1.0 코드를 그대로 사용할 수 있다.
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
tf.train.Saver().save(session, checkpoint)
위 두 줄 처럼 코드 상단에 정의해두고 기존 코드를 그대로 사용하면 된다.