亚欧色一区w666天堂,色情一区二区三区免费看,少妇特黄A片一区二区三区,亚洲人成网站999久久久综合,国产av熟女一区二区三区

  • 發布文章
  • 消息中心
點贊
收藏
評論
分享
原創

pytorch中凍結訓練(上)

2024-11-15 09:17:46
6
0
import torch.nn as nn
import torch.optim as optim
import torch


# 定義一個簡單的網絡

class MyNet(nn.Module):
def __init__(self, num_class=5):
super(MyNet, self).__init__()
self.fc1 = nn.Linear(8, 4)
self.fc1.weight = nn.Parameter(torch.ones((4,8), dtype=torch.float))
self.fc2 = nn.Linear(4, num_class)
self.fc2.weight = nn.Parameter(torch.ones((num_class,4), dtype=torch.float))

def forward(self, x):
return self.fc2(self.fc1(x))


model = MyNet()

loss_fn = nn.CrossEntropyLoss()

choice = 3
if choice == 1:
# 情況一:不凍結參數時
optimizer = optim.SGD(model.parameters(), lr=1e-2) # 傳入的是所有的參數

if choice == 2:
# 情況二:采用方式一凍結fc1層時
for name, param in model.named_parameters():
if "fc1" in name:
param.requires_grad = False
optimizer = optim.SGD(model.parameters(), lr=1e-2) # 優化器傳入的是所有的參數

if choice == 3:
# 情況三:采用方式二凍結fc1層時, 優化器只傳入fc2的參數
optimizer = optim.SGD(model.fc2.parameters(), lr=1e-2)

if choice == 4:
# 情況4: 最優做法是將不更新的參數的requires_grad設置為False,同時不將該參數傳入optimizer
# 凍結fc1層的參數
for name, param in model.named_parameters():
if "fc1" in name:
param.requires_grad = False

# 定義一個 filter ,只傳入requires_grad=True的模型參數
optimizer = optim.SGD(filter(lambda p : p.requires_grad, model.parameters()), lr=1e-2)


# 訓練前的模型參數
print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)

for epoch in range(10):
x = torch.randn((3, 8))
label = torch.randint(0, 5, [3]).long()
output = model(x)

loss = loss_fn(output, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# 訓練后的模型參數
print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)

"""
結論
1. 最優寫法能夠節省顯存和提升速度:
2. 節省顯存:不將不更新的參數傳入optimizer
3. 提升速度:將不更新的參數的requires_grad設置為False,節省了計算這部分參數梯度的時間

# element_size返回單個元素的字節大小,nelement返回元素個數
import torch
a = torch.zeros([128, 128])
print(a.element_size() * a.nelement())

# pytorch查看模型的參數總量、占用顯存量以及flops
from torchstat import stat
import torchvision.models as models
model = models.resnet152()
stat(model, (3, 224, 224))
使用DNN_printer
"""
0條評論
作者已關閉評論
Top123
32文章數
3粉絲數
Top123
32 文章 | 3 粉絲
Top123
32文章數
3粉絲數
Top123
32 文章 | 3 粉絲
原創

pytorch中凍結訓練(上)

2024-11-15 09:17:46
6
0
import torch.nn as nn
import torch.optim as optim
import torch


# 定義一個簡單的網絡

class MyNet(nn.Module):
def __init__(self, num_class=5):
super(MyNet, self).__init__()
self.fc1 = nn.Linear(8, 4)
self.fc1.weight = nn.Parameter(torch.ones((4,8), dtype=torch.float))
self.fc2 = nn.Linear(4, num_class)
self.fc2.weight = nn.Parameter(torch.ones((num_class,4), dtype=torch.float))

def forward(self, x):
return self.fc2(self.fc1(x))


model = MyNet()

loss_fn = nn.CrossEntropyLoss()

choice = 3
if choice == 1:
# 情況一:不凍結參數時
optimizer = optim.SGD(model.parameters(), lr=1e-2) # 傳入的是所有的參數

if choice == 2:
# 情況二:采用方式一凍結fc1層時
for name, param in model.named_parameters():
if "fc1" in name:
param.requires_grad = False
optimizer = optim.SGD(model.parameters(), lr=1e-2) # 優化器傳入的是所有的參數

if choice == 3:
# 情況三:采用方式二凍結fc1層時, 優化器只傳入fc2的參數
optimizer = optim.SGD(model.fc2.parameters(), lr=1e-2)

if choice == 4:
# 情況4: 最優做法是將不更新的參數的requires_grad設置為False,同時不將該參數傳入optimizer
# 凍結fc1層的參數
for name, param in model.named_parameters():
if "fc1" in name:
param.requires_grad = False

# 定義一個 filter ,只傳入requires_grad=True的模型參數
optimizer = optim.SGD(filter(lambda p : p.requires_grad, model.parameters()), lr=1e-2)


# 訓練前的模型參數
print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)

for epoch in range(10):
x = torch.randn((3, 8))
label = torch.randint(0, 5, [3]).long()
output = model(x)

loss = loss_fn(output, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# 訓練后的模型參數
print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)

"""
結論
1. 最優寫法能夠節省顯存和提升速度:
2. 節省顯存:不將不更新的參數傳入optimizer
3. 提升速度:將不更新的參數的requires_grad設置為False,節省了計算這部分參數梯度的時間

# element_size返回單個元素的字節大小,nelement返回元素個數
import torch
a = torch.zeros([128, 128])
print(a.element_size() * a.nelement())

# pytorch查看模型的參數總量、占用顯存量以及flops
from torchstat import stat
import torchvision.models as models
model = models.resnet152()
stat(model, (3, 224, 224))
使用DNN_printer
"""
文章來自個人專欄
文章 | 訂閱
0條評論
作者已關閉評論
作者已關閉評論
0
0