An Image Segmentation Code Example Based on the Python Deep Learning Framework PyTorch
Below is an image segmentation code example based on PyTorch:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
# Define the dataset
class ImageSegmentationDataset(Dataset):
    def __init__(self, image_dir, label_dir, transform=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transform = transform
        # Sort both listings so that images and labels pair up by filename
        self.images = sorted(os.listdir(image_dir))
        self.labels = sorted(os.listdir(label_dir))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.images[idx])
        label_path = os.path.join(self.label_dir, self.labels[idx])
        image = Image.open(image_path).convert('RGB')
        label = Image.open(label_path)
        if self.transform:
            # Replay the same RNG seed for both calls so that random
            # augmentations (the flips below) hit image and label identically
            seed = torch.seed()
            image = self.transform(image)
            torch.manual_seed(seed)
            label = self.transform(label)
        return image, label
# Define the model: a simplified U-Net
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        # Encoder: each conv doubles the channel count
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv4 = nn.Conv2d(256, 512, 3, padding=1)
        self.conv5 = nn.Conv2d(512, 1024, 3, padding=1)
        self.max_pool = nn.MaxPool2d(2)
        # Decoder: each skip concatenation doubles the channel count again,
        # so every transposed conv after the first takes the concatenated
        # width as its input
        self.up_conv1 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.up_conv2 = nn.ConvTranspose2d(1024, 256, 2, stride=2)
        self.up_conv3 = nn.ConvTranspose2d(512, 128, 2, stride=2)
        self.up_conv4 = nn.ConvTranspose2d(256, 64, 2, stride=2)
        self.out_conv = nn.Conv2d(64, 2, 1)  # 2 output classes
    def forward(self, x):
        # Encoder path with max-pool downsampling
        x1 = nn.functional.relu(self.conv1(x))
        x2 = nn.functional.relu(self.conv2(self.max_pool(x1)))
        x3 = nn.functional.relu(self.conv3(self.max_pool(x2)))
        x4 = nn.functional.relu(self.conv4(self.max_pool(x3)))
        x5 = nn.functional.relu(self.conv5(self.max_pool(x4)))
        # Decoder path with skip connections from the encoder
        y = nn.functional.relu(self.up_conv1(x5))
        y = torch.cat([y, x4], dim=1)
        y = nn.functional.relu(self.up_conv2(y))
        y = torch.cat([y, x3], dim=1)
        y = nn.functional.relu(self.up_conv3(y))
        y = torch.cat([y, x2], dim=1)
        y = nn.functional.relu(self.up_conv4(y))
        y = self.out_conv(y)  # raw logits, shape (N, 2, H, W)
        return y
# Data augmentation
# The same pipeline is applied to image and label; __getitem__ replays the
# RNG seed so the random flips land on both identically. NEAREST
# interpolation keeps the mask's class values intact under resizing.
# ColorJitter is deliberately left out of this shared pipeline: photometric
# jitter must not touch the label mask (see the image-only variant in the
# sketch just below).
transform = transforms.Compose([
    transforms.Resize(size=(256, 256),
                      interpolation=transforms.InterpolationMode.NEAREST),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
])
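# --- Optional sketch (not in the original post) --------------------------
# A more explicit alternative to the seed-replay trick is a joint transform
# that makes each random decision once and applies it to both inputs, with
# photometric jitter applied to the image alone.
import random
import numpy as np
import torchvision.transforms.functional as TF

def joint_transform(image, label):
    # Resize both; NEAREST keeps the mask's class values intact
    image = TF.resize(image, [256, 256])
    label = TF.resize(label, [256, 256],
                      interpolation=transforms.InterpolationMode.NEAREST)
    # Make each random flip decision once, apply it to both
    if random.random() < 0.5:
        image, label = TF.hflip(image), TF.hflip(label)
    if random.random() < 0.5:
        image, label = TF.vflip(image), TF.vflip(label)
    # Photometric jitter on the image only
    image = transforms.ColorJitter(brightness=0.2, contrast=0.2,
                                   saturation=0.2)(image)
    image = TF.to_tensor(image)
    label = torch.as_tensor(np.array(label), dtype=torch.long)
    return image, label
# --------------------------------------------------------------------------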
# Define the training and test sets
train_image_dir = 'train/images/'
train_label_dir = 'train/labels/'
train_dataset = ImageSegmentationDataset(train_image_dir, train_label_dir, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_image_dir = 'test/images/'
test_label_dir = 'test/labels/'
test_dataset = ImageSegmentationDataset(test_image_dir, test_label_dir, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)
# Instantiate the model, loss, and optimizer
model = UNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)  # Adam's default step size
# Train the model
for epoch in range(10):
    model.train()
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = model(inputs)
        # CrossEntropyLoss expects integer class indices: a binary 0/255
        # mask becomes 0.0/1.0 after ToTensor, so .long() yields labels {0, 1}
        loss = criterion(outputs, labels[:, 0, :, :].long())
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print('[%d] train loss: %.3f' % (epoch + 1, running_loss / len(train_loader)))
# Evaluate the model on the test set
model.eval()
test_loss = 0.0
with torch.no_grad():
    for i, data in enumerate(test_loader):
        inputs, labels = data
        outputs = model(inputs)
        loss = criterion(outputs, labels[:, 0, :, :].long())
        test_loss += loss.item()
print('test loss: %.3f' % (test_loss / len(test_loader)))
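# --- Optional sketch (not in the original post) --------------------------
# A loss value alone says little about segmentation quality; per-pixel
# accuracy is an easy sanity metric that could be accumulated inside the
# test loop above in the same way as test_loss.
def pixel_accuracy(outputs, labels):
    # outputs: (N, 2, H, W) logits; labels: (N, H, W) class indices
    preds = outputs.argmax(dim=1)
    return (preds == labels).float().mean().item()
# --------------------------------------------------------------------------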
# Save the model's weights (its state_dict)
torch.save(model.state_dict(), 'model.pt')
The code above defines a simplified U-Net model for image segmentation. Data augmentation is applied during training, and the training and test sets are built through a custom Dataset class. After training finishes, torch.save() writes the model's weights to a .pt file.
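To reuse the saved weights, one would rebuild the network and load the state_dict before switching to eval mode. The following is a minimal inference sketch, not part of the original post: the file model.pt matches the save call above, test/images/example.png is a hypothetical input path, and the preprocessing mirrors the deterministic part of the training pipeline:

model = UNet()
model.load_state_dict(torch.load('model.pt'))
model.eval()
# Deterministic preprocessing only: random flips are training-time augmentation
preprocess = transforms.Compose([
    transforms.Resize(size=(256, 256)),
    transforms.ToTensor(),
])
# 'test/images/example.png' is a hypothetical path used for illustration
image = preprocess(Image.open('test/images/example.png').convert('RGB'))
with torch.no_grad():
    logits = model(image.unsqueeze(0))      # (1, 2, 256, 256) raw scores
    mask = logits.argmax(dim=1).squeeze(0)  # (256, 256) predicted class ids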