Optimizer
머신러닝 모델 학습에서 optimizer 는 꼭 필요한 개념이다.
기계는 주어진 연산에 따라 데이터를 학습시키면서 매 단계마다 손실값(loss)을 계산한다.
그리고 이 손실값을 줄여나가기 위해 변수(variable)들을 약간씩 조정하면서 다시 학습을 진행해 나간다.
손실값이 커진다는 것은 쉽게 말하자면
'주어진 학습데이터에 잘 맞지 않는 모델을 만들어가고 있다'는 건데,
그래서 기계가 중간중간 손실값을 체크하면서
이를 줄여나갈 수 있는 방향으로 최적화(optimization)를 진행하는 것이다.
이 최적화(optimization)를 진행하는 것이 옵티마이저(Optimizer) 이고
손실 함수를 통해 얻은 손실값으로부터 모델을 업데이트하는 방식을 의미한다.
옵티마이저의 종류는 다양한데 (SGD, Adam, RMSprop 등)
각각 손실값을 최적화해나가는 방식이 조금씩 다르다.
개인적으로 가장 많이 보고 또 사용한 옵티마이저는 Adam Optimizer 이다.
텐서플로우 v1 코드에서 adam optimizer 는 다음과 같이 선언하여 적용했다.
- learning rate 를 지정하여 AdamOptimizer 선언
- minimize 함수 호출
minimize 함수 호출 시 인자에 손실값을 부여하면
모델학습 시 손실값을 최소화하는 방향으로 adam optimizer 가 적용된다.
import tensorflow as tf
tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
AttributeError: module 'tensorflow._api.v2.train' has no attribute 'AdamOptimizer'
그런데 v2.0 으로 넘어가면서 위와 같이 에러가 발생한다.
이는 텐서플로우 v2.x 에서 AdamOptimizer 를 기존과 동일하게 제공하지 않기 때문이다.
1. tf.optimizers 활용
Optimizer 같은 중요한 개념은 v2.0 에서도 당연히 남아있다.
다만 텐서플로우 패키지 경로들에 변경이 좀 있는 듯하다.
optimizers 모듈을 활용하면 기존과 동일하게 사용할 수 있다.
train_op = tf.optimizers.Adam(learning_rate=0.001).minimize(loss, var_list)
minimize 함수 또한 존재한다.
그러나 기존과 다른 점은 인자에 'var_list' 가 추가되었다는 점이다.
var_list: list or tuple of `Variable` objects to update to minimize
`loss`, or a callable returning the list or tuple of `Variable` objects.
Use callable when the variable list would otherwise be incomplete before
`minimize` since the variables are created at the first time `loss` is
called.
2. compat 사용
지난 포스팅들에도 계속 언급했듯 compat 모듈을 사용하여 v2.0 기능을 끄도록 설정하여
기존의 v1.0 코드를 수정 없이 그대로 사용하는 방법도 있다.
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
위 두 줄 처럼 코드 상단에 정의해두고 기존 코드를 그대로 사용하면 된다.
':: ai > tensorflow' 카테고리의 다른 글
tensorflow :: 텐서플로우 ConfigProto (v1 -> v2 코드 업그레이드 8) (0) | 2022.06.14 |
---|---|
tensorflow :: 텐서플로우 Saver (v1 -> v2 코드 업그레이드 7) (0) | 2022.06.07 |
tensorflow :: 텐서플로우 losses 의 regularization loss (v1 -> v2 코드 업그레이드 5) (0) | 2022.04.12 |
tensorflow :: 텐서플로우 layers (v1 -> v2 코드 업그레이드 4) (0) | 2022.03.31 |
tensorflow :: 텐서플로우 contrib (v1 -> v2 코드 업그레이드 3) (0) | 2022.03.30 |