지금까지 데이터를 불러오고-> layer을 쌓고(합치고) -> 그걸 model 에 넣는 작업까지 학습했습니다.
이번에는 model을 학습시키는 것을 배워 보겠습니다.
학습과정은,
데이터가 모델에 들어가서 ->
모델이 예측을 하고 ->
그 예측값과 정답을 비교한 뒤 ->
loss function으로 얼마나 틀렸는지를 계산합니다. ->
그를 바탕으로 Optimization을 하고, ->
그걸 model에 업로드 합니다. ->
이렇게 model이 업데이트 되었으면, 다시 데이터가 들어가고를 반복합니다.
1. Optimization
1) 모델을 optimization하기 전 설정해야 할 것들
1. Loss function
#1-1. binary classification일 때
tf.keras.losses.binary_crossentropy # 또는 loss = 'binary_crossentropy' 라고 해줘도 됩니다.
#1-2. categorical(multiple) Classification일 때
tf.keras.lossses.categorical_crossentropy # 또는 loss = 'categorical_crossentropy'라고 해줘도 됩니다.
#2-1. onehot encoding을 진행 하지 않은 데이터인 경우
tf.keras.losses.sparse_categorical_crossentropy
#2-2. onehot encoding을 진행 한 데이터인 경우
tf.keras.lossses.categorical_crossentropy
#MNIST의 경우, categorical이고, onehot도 안 해줬다면,
loss_func = tf.keras.losses.sparse_categorical_crossentropy
2. Optimization
#방법1. Sgd
tf.keras.optimizers.SGD()
#방법2. rmsProp
tf.keras.optimizers.RMSprop()
#방법3. Adam
optm = tf.keras.optimizers.Adam() #이걸로 정했다면,
3. Metrics
모델을 평가하는 방법입니다.
#전체 문제에서 몇개를 맞추었나
metrics = ['accuracy']
#또는
tf.keras.metrics.Accuracy()
tf.keras.metrics.Precision()
tf.keras.metrics.Recall()
2) 다 설정해 주었다면, model에 compile해 줍니다.
model.compile(optimizer = optm, loss = loss_func, metrics = metrics)
2. Training
1) 모델을 training하기 전 설정해야 할 것들 - 학습용 Hyperparameter 설정
1. num_epochs
num_epochs =10
2. batch_size
batch_size = 32
2) model fit
모델을 완성했다면 fit 메서드로 트레이닝을 시작합니다.
model.fit(X_train, y_train, batch_size = batch_size, shuffle = True, epochs = num_epochs) #학습을 할 때는 overfitting, bias를 막기 위해 shuffle이 필요합니다.
reference
https://datascienceschool.net/view-notebook/51e147088d474fe1bf32e394394eaea7/
'Data Science > Tensorflow2.0, Pytorch' 카테고리의 다른 글
8. Tensorflow2.0 Optimization & Training (Expert) (0) | 2020.03.27 |
---|---|
7. Tensorflow2.0 Optimization & Training (Keras) (0) | 2020.03.27 |
5. Tensorflow2.0 Build Model (0) | 2020.03.25 |
4. Tensorflow2.0 Layer Explaination (0) | 2020.03.25 |
3. Tensorflow2.0 (y_train) Data Preprocess (using MNIST) (0) | 2020.03.25 |