Below is an image segmentation code example based on PyTorch:

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

# Define the dataset
class ImageSegmentationDataset(Dataset):
    def __init__(self, image_dir, label_dir, transform=None, target_transform=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transform = transform
        self.target_transform = target_transform
        # Sort so images and labels pair up by filename order
        self.images = sorted(os.listdir(image_dir))
        self.labels = sorted(os.listdir(label_dir))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.images[idx])
        label_path = os.path.join(self.label_dir, self.labels[idx])

        image = Image.open(image_path).convert('RGB')
        label = Image.open(label_path)

        # Separate transforms: the image gets photometric augmentation
        # and normalization; the label mask must keep integer class indices
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)

        return image, label
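
Note: random geometric augmentations such as the horizontal/vertical flips in the original transform have to be applied identically to the image and its mask, otherwise the two fall out of alignment. A minimal sketch of a joint flip using torchvision's functional API (the helper name joint_random_hflip is hypothetical, not part of the original code):

import random
import torchvision.transforms.functional as TF

def joint_random_hflip(image, label, p=0.5):
    # Draw one coin flip and apply the same decision to both,
    # so the image and its mask stay aligned
    if random.random() < p:
        image = TF.hflip(image)
        label = TF.hflip(label)
    return image, label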

# Define the model (a simplified U-Net; decoder input channels account
# for the skip-connection concatenations)
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        # Encoder
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv4 = nn.Conv2d(256, 512, 3, padding=1)
        self.conv5 = nn.Conv2d(512, 1024, 3, padding=1)
        self.max_pool = nn.MaxPool2d(2)
        # Decoder: every upsampling layer after the first takes the
        # concatenation of the previous output and a skip connection,
        # hence the doubled input channels
        self.up_conv1 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.up_conv2 = nn.ConvTranspose2d(1024, 256, 2, stride=2)
        self.up_conv3 = nn.ConvTranspose2d(512, 128, 2, stride=2)
        self.up_conv4 = nn.ConvTranspose2d(256, 64, 2, stride=2)
        self.out_conv = nn.Conv2d(64, 2, 1)  # 2 output classes

    def forward(self, x):
        # Encoder path
        x1 = nn.functional.relu(self.conv1(x))                   # 64 x H x W
        x2 = nn.functional.relu(self.conv2(self.max_pool(x1)))   # 128 x H/2 x W/2
        x3 = nn.functional.relu(self.conv3(self.max_pool(x2)))   # 256 x H/4 x W/4
        x4 = nn.functional.relu(self.conv4(self.max_pool(x3)))   # 512 x H/8 x W/8
        x5 = nn.functional.relu(self.conv5(self.max_pool(x4)))   # 1024 x H/16 x W/16
        # Decoder path with skip connections
        y = self.up_conv1(x5)                     # 512 x H/8 x W/8
        y = torch.cat([y, x4], dim=1)             # 1024 channels
        y = nn.functional.relu(self.up_conv2(y))  # 256 x H/4 x W/4
        y = torch.cat([y, x3], dim=1)             # 512 channels
        y = nn.functional.relu(self.up_conv3(y))  # 128 x H/2 x W/2
        y = torch.cat([y, x2], dim=1)             # 256 channels
        y = nn.functional.relu(self.up_conv4(y))  # 64 x H x W
        y = self.out_conv(y)                      # 2 x H x W (class logits)
        return y
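
A quick way to verify the decoder channel arithmetic above is a dummy forward pass (a sanity-check sketch, assuming a 256x256 RGB input):

model = UNet()
x = torch.randn(1, 3, 256, 256)  # dummy batch of one RGB image
print(model(x).shape)            # expected: torch.Size([1, 2, 256, 256])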

# Preprocessing: photometric augmentation applies to images only;
# the label mask uses nearest-neighbor resizing and stays integer-valued
image_transform = transforms.Compose([
    transforms.Resize(size=(256, 256)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
])
label_transform = transforms.Compose([
    transforms.Resize(size=(256, 256), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.PILToTensor(),  # shape (1, H, W), integer class indices
])

# Define the training and test sets
train_image_dir = 'train/images/'
train_label_dir = 'train/labels/'
train_dataset = ImageSegmentationDataset(train_image_dir, train_label_dir,
                                         transform=image_transform,
                                         target_transform=label_transform)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

test_image_dir = 'test/images/'
test_label_dir = 'test/labels/'
test_dataset = ImageSegmentationDataset(test_image_dir, test_label_dir,
                                        transform=image_transform,
                                        target_transform=label_transform)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)
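
Before training, it is worth checking one sample (assuming the directories above exist; the expected shapes and dtypes are shown as comments):

sample_image, sample_label = train_dataset[0]
print(sample_image.shape, sample_image.dtype)  # torch.Size([3, 256, 256]) torch.float32
print(sample_label.shape, sample_label.dtype)  # torch.Size([1, 256, 256]) torch.uint8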

# Instantiate the model, loss, and optimizer
model = UNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)  # 0.01 is typically too high for Adam
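
To train on a GPU when one is available, move the model and each batch to the same device (an optional sketch; as written, the script below runs on the CPU):

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
# inside the loops below, also move each batch:
# inputs, labels = inputs.to(device), labels.to(device)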

# Train the model
for epoch in range(10):
    model.train()
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = model(inputs)
        # CrossEntropyLoss expects class indices of shape (N, H, W) as int64
        loss = criterion(outputs, labels.squeeze(1).long())
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print('[%d] train loss: %.3f' % (epoch + 1, running_loss / len(train_loader)))

# Evaluate the model
model.eval()
test_loss = 0.0
with torch.no_grad():
    for i, data in enumerate(test_loader):
        inputs, labels = data
        outputs = model(inputs)
        loss = criterion(outputs, labels.squeeze(1).long())
        test_loss += loss.item()
print('test loss: %.3f' % (test_loss / len(test_loader)))

# Save the model
torch.save(model.state_dict(), 'model.pt')
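
To use the saved weights for prediction, reload them into a fresh UNet and take the argmax over the class dimension (a minimal inference sketch, reusing the preprocessed sample_image tensor from the sanity check above):

model = UNet()
model.load_state_dict(torch.load('model.pt'))
model.eval()
with torch.no_grad():
    logits = model(sample_image.unsqueeze(0))    # (1, 2, 256, 256)
    pred_mask = logits.argmax(dim=1).squeeze(0)  # (256, 256) class indices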

The code above defines a simplified U-Net model for a two-class image segmentation task. It builds training and test datasets with photometric augmentation, trains with cross-entropy loss, evaluates on the test set, and finally saves the learned weights to a .pt file with torch.save().
