[모두를 위한 딥러닝] PyTorch 2장 Linear Regression

2023. 1. 12. 17:04DL Study

출처 : 모두를 위한 딥러닝 (https://youtu.be/kyjBMuNM1DI)

라이브러리 import

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

Random seed 고정

torch.manual_seed(1)

x_train 과 y_train 정의

x_train = torch.FloatTensor([[1], [2], [3]])
y_train = torch.FloatTensor([[1], [2], [3]])

W,b 모델 초기화 

W = torch.zeros(1, requires_grad=True)
b = torch.zeros(1, requires_grad=True)

or
#Multivariate
W = torch.zeros((3, 1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)

 

 

optimizer 설정

optimizer = optim.SGD([W, b], lr=0.01)    #optimizer 정의

optimizer.zero_grad()    #gradient를 0 으로 초기화
cost.backward()    #gradient 계산
optimizer.step()    #gradient decent 개선하는것

 

Multivariate Linear Regression

 

Hypothsis 정의

hypothesis = x_train.matmul(W) + b    #행렬연산을 내장함수로 간편하게

W,b 모델 초기화 및 정의 nn.Module

class MultivariateLinearRegressionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 1)

    def forward(self, x):
        return self.linear(x)       
.
.
.
prediction = model(x_train)

Cost 계산 F.mse_loss

cost = F.mse_loss(prediction, y_train)

PyTorch Dataset

from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self):
        self.x_data = [[73, 80, 75],
                       [93, 88, 93],
                       [89, 91, 90],
                       [96, 98, 100],
                       [73, 66, 70]]
        self.y_data = [[152], [185], [180], [196], [142]]

    def __len__(self):    #이 데이터셋의 총 데이터수
        return len(self.x_data)

    def __getitem__(self, idx):    #어떠한 인덱스를 받았을때, 상응하는 입출력 데이터 반환
        x = torch.FloatTensor(self.x_data[idx])
        y = torch.FloatTensor(self.y_data[idx])

        return x,y
.
.
.
dataset = CustomDataset

PyTorch DataLoader

from torch.utils.data import DataLoader

dataloader = DataLoader(
    dataser,
    batch_size = 2,   #각 베치 사이즈 통상적으로 2의 배수
    sheffle = True,    #Epoch 마다 데이터셋을 섞어서, 데이터가 학습되는 순서를 바꾼다.
)