Saver
텐서플로우의 Saver 모듈은 모델과 파라미터를 저장하고 불러오는 기능을 제공한다.
모델을 학습시키고 나면 학습된 모델을 토대로 기능을 수행해야하므로 저장이 필요하다.
그리고 보통 머신러닝 학습을 진행하면 시간이 오래 소요되는데,
Saver 의 save 함수를 이용하면 학습 중간중간 특정 시점마다 모델 저장 또한 가능하다.
(학습 시간이 마냥 길어진다고 모델의 성능이 좋아지는 것은 아니기에
이렇게 중간 저장해둔 모델이 유용하게 쓰일 수 있다.)
모델이 저장되면 checkpoint 파일들과 meta graph 파일(.meta 확장자) 이 남게 되는데
checkpoint 파일에는 weights, biases, grandients 등의 정보가 저장되며
meta graph 파일에는 말 그대로 모델 학습에 사용된 연산 그래프 정보가 저장된다. (variables, operations 등)
이렇게 저장된 모델은 다시 불러올 수 있는데,
이 때 위에서 언급한 meta graph 파일은 tf.train.import_meta_graph 함수를 이용해서 불러오고
checkpoint 파일은 tf.train.Saver.restore 함수를 이용하여 불러올 수 있다.
모델을 저장할 때는 다음과 같이 사용한다.
tf.train.Saver().save(session, checkpoint)
# AttributeError: module 'tensorflow._api.v2.train' has no attribute 'Saver'
그런데 텐서플로우 2.0 버전 이상에서 실행하면 AttributeError 가 발생한다.
모델을 불러올 때도 마찬가지이다.
meta graph 파일을 import 한 뒤, checkpoint 파일을 restore 해오게 되는데
tf.train.import_meta_graph(checkpoint + '.meta').restore(session, checkpoint)
# AttributeError: module 'tensorflow._api.v2.train' has no attribute 'import_meta_graph'
마찬가지로 AttirubteError 가 발생한다.
텐서플로우 v2.0 에서도 모델 저장은 필수일텐데,
어떻게 버전업 코드에 맞게 변환하여 사용할 수 있을까?
1. tf.saved_model, tf.train.Checkpoint 모듈 활용
tf.saved_model 의 save, load 함수를 사용하면 모델 저장과 불러오기가 가능하다.
tf.saved_model.save(model, '/tmp/adder')
restored = tf.saved_model.load('/tmp/adder')
checkpoint 파일을 저장하고 관리하기 위해서는 다음과 같이 tf.train.Checkpoint 를 사용하면 된다.
checkpoint = tf.train.Checkpoint(model)
manager = tf.train.CheckpointManager(checkpoint)
status = checkpoint.restore(manager.latest_checkpoint)
v2.0 부터 meta graph 는 생성/관리가 잘 지원되지 않는 듯했다.
이 경우 2번의 방법을 이용해서 기존의 코드를 그대로 사용해야 할 것 같다.
2. v1.0 코드 그대로 사용하기: compat 모듈 사용
지난 포스팅들에도 계속 언급했듯 compat 모듈을 사용하여 v2.0 기능을 끄도록 설정하고
기존 v1.0 코드를 그대로 사용할 수 있다.
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
tf.train.Saver().save(session, checkpoint)
위 두 줄 처럼 코드 상단에 정의해두고 기존 코드를 그대로 사용하면 된다.
':: ai > tensorflow' 카테고리의 다른 글
tensorflow :: 텐서플로우 ConfigProto (v1 -> v2 코드 업그레이드 8) (0) | 2022.06.14 |
---|---|
tensorflow :: 텐서플로우 Optimizer 개념 및 코드 업그레이드 (v1 -> v2 코드 업그레이드 6) (0) | 2022.04.17 |
tensorflow :: 텐서플로우 losses 의 regularization loss (v1 -> v2 코드 업그레이드 5) (0) | 2022.04.12 |
tensorflow :: 텐서플로우 layers (v1 -> v2 코드 업그레이드 4) (0) | 2022.03.31 |
tensorflow :: 텐서플로우 contrib (v1 -> v2 코드 업그레이드 3) (0) | 2022.03.30 |