[모두를 위한 딥러닝] PyTorch 2장 Linear Regression
2023. 1. 12. 17:04ㆍDL Study
출처 : 모두를 위한 딥러닝 (https://youtu.be/kyjBMuNM1DI)
라이브러리 import
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
Random seed 고정
torch.manual_seed(1)
x_train 과 y_train 정의
x_train = torch.FloatTensor([[1], [2], [3]])
y_train = torch.FloatTensor([[1], [2], [3]])
W,b 모델 초기화
W = torch.zeros(1, requires_grad=True)
b = torch.zeros(1, requires_grad=True)
or
#Multivariate
W = torch.zeros((3, 1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
optimizer 설정
optimizer = optim.SGD([W, b], lr=0.01) #optimizer 정의
optimizer.zero_grad() #gradient를 0 으로 초기화
cost.backward() #gradient 계산
optimizer.step() #gradient decent 개선하는것
Multivariate Linear Regression
Hypothsis 정의
hypothesis = x_train.matmul(W) + b #행렬연산을 내장함수로 간편하게
W,b 모델 초기화 및 정의 nn.Module
class MultivariateLinearRegressionModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(3, 1)
def forward(self, x):
return self.linear(x)
.
.
.
prediction = model(x_train)
Cost 계산 F.mse_loss
cost = F.mse_loss(prediction, y_train)
PyTorch Dataset
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self):
self.x_data = [[73, 80, 75],
[93, 88, 93],
[89, 91, 90],
[96, 98, 100],
[73, 66, 70]]
self.y_data = [[152], [185], [180], [196], [142]]
def __len__(self): #이 데이터셋의 총 데이터수
return len(self.x_data)
def __getitem__(self, idx): #어떠한 인덱스를 받았을때, 상응하는 입출력 데이터 반환
x = torch.FloatTensor(self.x_data[idx])
y = torch.FloatTensor(self.y_data[idx])
return x,y
.
.
.
dataset = CustomDataset
PyTorch DataLoader
from torch.utils.data import DataLoader
dataloader = DataLoader(
dataser,
batch_size = 2, #각 베치 사이즈 통상적으로 2의 배수
sheffle = True, #Epoch 마다 데이터셋을 섞어서, 데이터가 학습되는 순서를 바꾼다.
)
'DL Study' 카테고리의 다른 글
[파이썬에서 살아남는법 제 6장] range 보다는 enumerate 를 사용하라 (0) | 2023.01.12 |
---|---|
[파이썬에서 살아남는법 제 5장] 인덱스를 사용하는 대신 대입을 사용해 데이터를 언패킹 하라 (0) | 2023.01.12 |
[파이썬에서 살아남는법 제 4장] 복잡한 식을 쓰는 대신 도우미 함수를 작성하라 (0) | 2023.01.11 |
[파이썬에서 살아남는법 제 3장] C 스타일 형식 문자열을 str.format 과 쓰기보다는 f-문자열을 통한 인풀레이션을 사용하라 (0) | 2023.01.11 |
[파이썬에서 살아남는법 제 2장] bytes 와 str 차이를 알아두라 (0) | 2023.01.10 |