오늘은 앞서 배운 것을 기반으로 좀 더 간결한 형태(추상화된)로 모델을 학습시킬 수 있는 Lightning이라는 라이브러리를 소개하려고 한다.
간단한 라이트닝 AI 소개
세계에서 가장 직관적이고 사용하기 쉬우며 가장 빠른 AI 작업 플랫폼을 제공함으로써 "과학은 여러분이 엔지니어링은 저희가"라는 핵심정신을 가지고 있는 회사인데 제공하는 제품 중에는 AI Studio라고 해서 AI 모델 개발에 필요한 여러 도구들을 하나의 일관된 환경으로 통합하는 프로그램을 제공하고 있고, 파이토치에서 우리가 일일히 컨트롤해줘야하는 부분들을 추상화시켜서 최소한의 필요한 코드만 작성해서 훈련을 시킬 수 있도록 하는 lightning pytorch도 제공하고 있다.
장점
1. 코드의 추상화 및 하드웨어 호출 자동화
- 기존 파이토치는 모델, optimizer, training loop를 전부 따로따로 구현한다.
- 이를 Lightning Module 클래스 안에 모든 것을 한번에 구현하게 되어있다.
- .to(device)를 안해도 됨
2. 다양한 콜백 함수와 로깅
- 초기 학습률 자동으로 찾아주거나 조기 종료의 기능을 한줄의 코드로 구현할 수 있게 해줌
3. 16-bit precision
- 32비트를 16비트로 줄여서 메모리 사용량을 줄이고자 한 아이디어인데 이걸 제공해줌
위의 이 회사의 목표인 "과학은 여러분이, 엔지니어링은 저희가"와 같이 문제를 해결하는 모델을 만드는 것이 가장 중요하고 엔지니어링은 부수적인 목적이기 때문에 그 목적에 맞춰진 라이브러리라고 볼 수 있다.
일반적인 파이토치와의 비교
우리가 모델을 만들 때 nn.Module을 상속받아서 기본적으로 neural network를 만드는데 필요한 메서드들을 사용할 수 있게 해주는데 lightning은 이 nn.Module을 상속받아서 추가적으로 추상화할 수 있는 메소드를 만들어서 이걸 상속받아서 활용한다.
1. Data : 일반 파이토치와 동일하거나 LightningDataModule을 사용하는 방법도 있지만 사실 이게 큰 장점이 있는지는 모르겠다.
- 내부 문서에 따르면 이 데이터 처리까지도 Model 코드에 prepare_data, setup(데이터 분리), train/val/test dataloader 함수로 모델에 넣는 걸 권장하는데, 이를 분리해서 데이터에 의존하지 않는 독립적인 모델을 만들 때 이런 LightningDataModule을 사용한다.
- 물론 나는 기존 pytorch 코드를 그대로 바꾸기 위해서 prepare_data, setup, dataloader들을 사용하지는 않는다.
- 관련 링크 : https://lightning.ai/docs/pytorch/stable/notebooks/lightning_examples/datamodules.html#Using-DataModules
2. Model : 모델에서 한번에 Training, Evaluation, Test에 필요한 코드 정의(단, 코드 줄 자체는 매우 간소화)
- model.train(), model.eval(), to(device), zero_grad(), loss_backward(), optimizer.step()과 같은 반복적인 코드는 Lightning에서 자체 처리해줘서 생략 가능
- 자체 로깅 기능 제공 및 배치 처리로 tqdm 및 for문 처리도 생략 가능
3. 학습 : Trainer를 활용하면 training_loop에서 변수 초기화 & Early Stopping과 같은 처리도 간단하게 라이브러리 적용으로 처리 가능
4. 하드웨어 학습 : 간편하게 GPU를 다룰 수 있게 해준다.
실제 코드 작성
위의 파이토치 코드와의 비교에서 보다시피 크게 Model 정의 , Trainer(학습)만 작성해주면 나머지는 알아서 작동한다.
1. import
오늘 실습할 코드는 기본적으로 MNIST 데이터셋을 활용해, CNN 모델을 만들려고 한다. 라이트닝에는 여러 라이브러리가 있는데 크게 5개 정도를 사용한다.
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, EarlyStopping
from pytorch_lightning.loggers import CSVLogger
import torchmetrics
#from pytorch_lightning import LightningDataModule <- 데이터까지 이걸 활용하려면 추가하면 됨
- LightningModule : 우리가 항상 모델을 만들 때 nn.Module을 상속받는 것처럼 여기서는 LightningModule을 모델 클래스 생성시 상속받는다.
- Trainer : training_loop 역할을 하는 라이브러리로 몇가지 인자만 넣어주면 실행 즉시 학습이 시작되며 친절하게 학습 결과를 알려준다.
- LearningRateMonitor : lr을 스케줄링하게 되면 각 학습률이 언제일 때 loss나 accuracy가 어땠는지를 볼 수 있도록 제공해준다.
- EarlyStopping : 우리가 실제 여러 인자를 정의하고 코드를 짜서 구현한 EarlyStopping을 라이브러리 호출로 바로 제공해준다.
- CSVLogger : 에폭, 횟수당 loss, accuracy 등을 CSV 형태로 기록해준다.(별도로 tqdm이나 print 함수를 호출하지 않아도 됨)
2. 모델 정의
기존 파이토치 코드와 동일하게 init, forward에서는 동일한 코드를 사용한다. 대신 training, evaluation 함수 코드가 training_step, validation_step, test_step의 형태로 모델 안에서 정의해준다.

init, forward는 사실 크게 차이가 없어서 후반부에 전체 코드에서 공유드리는 걸로 하고 기존의 파이토치에서의 training 함수와 라이트닝에서의 training_step의 코드를 비교해보면 아래와 같다.
손수 optimizer나 loss에 처리해줬던 것들을 라이트닝이 모두 제공함으로써 코드에서는 핵심인 outputs, loss, acc만 계산하고 결과만 log로 적용하는 코드만 정리해준다.

validation, test, predict는 모두 코드 자체는 동일해서 함수명만 바꿔주면 된다. predict는 굳이 Log를 할 필요가 없어서 저기에서 self.log만 삭제하면 된다.

configure_optimizers는 optimizer와 scheduler를 정의하는데 기존과 동일하게 1줄 정도로 정의해주는 것만으로 동작한다.

3. Trainer : Training_loop 역할을 대체
모델의 학습 epoch이나 batch 등의 상태뿐만 아니라, 모델을 저장해 로그를 생성하는 부분까지 담당합니다. 실제로 코드에서는 pl.Trainer()라고 정의하면 끝입니다. 실제로 아래 그림처럼 pytorch와 비교했을 때 매우 간단하게 설정할 수 있는 것을 볼 수 있다.

모델은 학습 과정에서 지속적으로 구조를 개선하고 반복하기 때문에 모니터링하는게 필요한데 이 때 여러 툴을 제공한다.
- LearningRateMonitor : lr을 스케줄링하게 되면 각 학습률이 언제일 때 loss나 accuracy가 어땠는지를 볼 수 있도록 제공해준다.
- EarlyStopping : 우리가 실제 여러 인자를 정의하고 코드를 짜서 구현한 EarlyStopping을 라이브러리 호출로 바로 제공해준다.
- CSVLogger : 에폭, 횟수당 loss, accuracy 등을 CSV 형태로 기록해준다.(별도로 tqdm이나 print 함수를 호출하지 않아도 됨)
trainer는 이렇게 한줄로 작성해주는 것만으로도 바로 설정 가능하고, trainer.fit(model, train_dataloader, val_dataloader)으로 동작한다.
trainer = Trainer(max_epochs = 100, accelerator = 'auto',
callbacks = [early_stopping, lr_monitor], logger = csv_logger)
결과
실제 코드를 돌리면 아래와 같이 하드웨어 사양, 모델 개요(accuracy, loss function, 파라미터), metric이 나온다.

추가로 실행하면 CSVLoader를 통해 csv 파일이 저장되고 내용은 아래와 같이 저장된다.

개인 생각
실제 스타트업에서 빠르게 제품을 만들기 위해서는 이걸 사용하는게 좋은 것 같은데 진짜 리서쳐가 되고 싶어하는 학문적인 사람은 직접 파이토치로 모든 코드를 작성하는게 좋겠다는 생각이 들었다. 나는 그런 사람이기 보다는 빠르게 제품을 만들고 싶은 사람이라서 이런 툴들이 계속 나오는 건 좋은 것 같다.
다만, 여기서 추상화시키거나 csvlogger와 같은 기능들은 결국 내가 직접 구현하지 않고 한줄의 코드로 적용한 것이기 때문에 그 결과가 어떤 것을 의미하는지에 대해서는 추가적인 공부와 적응은 필요한 것 같다.
전체 코드
import torch
import torch.nn as nn
from pytorch_lightning import LightningModule, Trainer, LightningDataModule
from pytorch_lightning.callbacks import LearningRateMonitor, EarlyStopping
from pytorch_lightning.loggers import CSVLogger
import torchmetrics
class Classifier(LightningModule):
def __init__(self, num_classes, dropout_ratio, lr = 0.001):
super().__init__()
self.learning_rate = lr
self.accuracy = torchmetrics.Accuracy(task = 'multiclass', num_classes = num_classes)
self.criterion = nn.CrossEntropyLoss()
self.num_classes = num_classes
self.dropout_ratio = dropout_ratio
self.layer = nn.Sequential(
nn.Conv2d(in_channels = 1, out_channels = 16, kernel_size = 5),
nn.ReLU(),
nn.Conv2d(in_channels = 16, out_channels = 32, kernel_size = 5),
nn.ReLU(),
nn.MaxPool2d(kernel_size = 2),
nn.Dropout(self.dropout_ratio),
nn.Conv2d(in_channels = 32, out_channels = 64, kernel_size = 5),
nn.ReLU(),
nn.MaxPool2d(kernel_size = 2),
nn.Dropout(self.dropout_ratio)
)
self.fc_layer = nn.Linear(576, self.num_classes)
self.classifier = nn.LogSoftmax(dim = 1)
def forward(self, x): #동일함.
out = self.layer(x)
out = out.view(out.size(0), -1)
pred = self.fc_layer(out)
return pred
def configure_optimizers(self):
optimizer = optim.Adam(self.parameters(), lr = self.learning_rate)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size = 1, gamma = 0.9)
return [optimizer], [scheduler]
def training_step(self, batch, batch_idx):
images, labels = batch
outputs = self(images)
loss = self.criterion(outputs, labels)
acc = self.accuracy(outputs, labels)
self.log("train_loss", loss, on_step = False, on_epoch = True, logger = True)
self.log("train_acc", acc, on_step = False, on_epoch = True, logger = True)
return loss
def validation_step(self, batch, batch_idx):
images, labels = batch
outputs = self(images)
_, preds = torch.max(outputs, dim = 1)
acc = self.accuracy(preds, labels)
loss = self.criterion(outputs, labels)
self.log(f"valid_loss", loss, on_step = False, on_epoch = True, logger = True)
self.log(f"valid_acc", loss, on_step = False, on_epoch = True, logger = True)
def test_step(self, batch, batch_idx):
images, labels = batch
outputs = self(images)
_, preds = torch.max(outputs, dim = 1)
acc = self.accuracy(preds, labels)
loss = self.criterion(outputs, labels)
self.log(f"test_loss", loss, on_step = False, on_epoch = True)
self.log(f"test_acc", loss, on_step = False, on_epoch = True)
def predict_step(self, batch, batch_idx):
images, labels = batch
outputs = self(images)
_, preds = torch.max(outputs, dim = 1)
return preds
# from lightning.pytorch.callbacks import LearningRateMonitor, EarlyStopping
# from lightning.pytorch.loggers import CSVLogger
model = Classifier(num_classes = 10, dropout_ratio = 0.2
)
lr_monitor = LearningRateMonitor(logging_interval = "epoch")
early_stopping = EarlyStopping(monitor = 'valid_loss', mode = 'min')
csv_logger = CSVLogger(save_dir = "./csv_logger", name = 'test')
trainer = Trainer(max_epochs = 100, accelerator = 'auto',
callbacks = [early_stopping, lr_monitor], logger = csv_logger)
trainer.fit(model, train_dataloader, val_dataloader)
trainer.test(model, test_dataloader)
공부하면서 든 질문들과 답
1. lightning을 불러오는데 메뉴얼과 공부했던 자료에서 서로 다르게 불러오는데 그 차이가 무엇일까?
from pytorch_lightning과 from lightning.pytorch는 PyTorch Lightning 라이브러리를 임포트하는 두 가지 다른 방식을 나타내는 것처럼 보일 수 있지만, 실제로는 중요한 차이가 있습니다.
- from pytorch_lightning: 이 구문은 PyTorch Lightning 라이브러리를 임포트하는 표준적인 방법입니다. PyTorch Lightning은 PyTorch를 위한 고수준 인터페이스를 제공하며, 더 간결하고 구조화된 방식으로 딥러닝 모델을 훈련하고 실험하는 데 사용됩니다. 여기서 pytorch_lightning은 라이브러리의 이름입니다.
- from lightning.pytorch: 이 구문은 일반적으로 PyTorch Lightning 라이브러리와 관련이 없습니다. 만약 이러한 구문이 사용된다면, lightning이라는 이름의 다른 패키지나 모듈 내에 있는 pytorch라는 하위 모듈을 가리키게 됩니다. 이는 표준 PyTorch Lightning 설치와는 관련이 없으며, 특정 프로젝트나 코드베이스에서 정의된 모듈일 수 있습니다.(하지만 실제로 정식 Docs에서는 이렇게 사용하라고 되어있음..)
따라서, PyTorch Lightning 라이브러리를 사용하고자 한다면, from pytorch_lightning을 사용해야 합니다. from lightning.pytorch는 특정 프로젝트에서 정의된 다른 것을 가리키는 것일 수 있으므로, 이를 사용하기 전에 해당 코드나 라이브러리의 문서를 확인하는 것이 중요합니다.
2. 아래 코드에서 self(images)라는 표현이 있는데 이게 어떤 걸 의미하는지?
def training_step(self, batch, batch_idx):
images, labels = batch
outputs = self(images)
이 코드의 맥락에서, Classifier 클래스는 LightningModule을 상속받고 있으며, LightningModule은 torch.nn.Module을 상속받는다. PyTorch에서 nn.Module 클래스는 __call__ 메서드를 구현하고 있어, 이를 상속받은 모든 클래스(여기서는 Classifier) 인스턴스는 함수처럼 호출될 수 있습니다. outputs = self(images) 구문은 Classifier 클래스의 인스턴스(self)를 사용하여 images를 입력으로 받아, forward 메서드를 호출한다. 이 때, forward 메서드는 모델의 신경망을 통과하는 입력 데이터의 순전파 과정을 정의합니다. 결과적으로, 이 코드는 입력 이미지 데이터를 모델에 통과시켜 결과를 출력하는 과정을 나타냅니다.
3. 콜백 함수란?
컴퓨터 프로그래밍에서 "콜백(callback)"은 일반적으로 다른 코드의 인자로 전달되는 실행 가능한 코드 조각을 의미합니다. 콜백은 특정 이벤트가 발생하거나 특정 조건이 만족될 때 호출됩니다. 콜백은 함수 포인터, 함수 참조, 또는 람다 표현식과 같은 형태로 사용될 수 있습니다.
콜백의 주요 사용 사례는 다음과 같습니다:
- 이벤트 처리: 사용자 인터페이스에서 버튼 클릭, 키보드 입력과 같은 이벤트가 발생했을 때 실행할 함수를 정의하는 데 콜백이 사용됩니다.
- 비동기 처리: 네트워크 요청, 파일 입출력 등 비동기적인 작업이 완료되었을 때 실행할 작업을 정의하는 데 콜백이 사용됩니다. 예를 들어, 데이터를 서버에서 가져온 후 처리할 작업을 콜백으로 정의할 수 있습니다.
- 고차 함수: 함수형 프로그래밍에서, 함수는 다른 함수를 인자로 받거나 함수를 결과로 반환할 수 있습니다. 이러한 고차 함수에서 콜백 함수가 사용됩니다.
- 프레임워크와 라이브러리: 많은 프레임워크와 라이브러리에서 사용자가 정의한 콜백 함수를 사용하여 사용자 정의 동작을 수행합니다. 예를 들어, PyTorch Lightning에서는 학습 과정 중 특정 이벤트(에포크 시작, 배치 처리 완료 등)에 사용자가 정의한 콜백 함수를 호출할 수 있습니다.
콜백은 프로그램의 유연성을 증가시키고 사용자 정의 동작을 쉽게 통합할 수 있게 해주는 강력한 도구입니다. 그러나 콜백을 사용할 때는 콜백 지옥(callback hell)이나 복잡한 의존성 관리와 같은 문제에 주의해야 합니다.
참고 링크
1. https://lightning.ai/docs/pytorch/stable/starter/introduction.html
2. https://baeseongsu.github.io/posts/pytorch-lightning-introduction/
'Machine Learning > 개념 정리' 카테고리의 다른 글
[파이토치] 0. 에러 아카이브 (0) | 2024.01.02 |
---|---|
[파이토치] 6. TensorBoard와 WandB (2) | 2024.01.02 |
[파이토치] 4. 여러 모듈 사용 (2) | 2023.12.29 |
[파이토치] 2. 데이터 처리 (0) | 2023.12.29 |
[파이토치] 1. 모델 선언 (1) | 2023.12.29 |