LibTorch in Action 6: Deploying U2-Net <3>

Abstract:
DUTS-TR is currently the largest commonly used training dataset for salient object detection. ECSSD contains 1000 structurally complex images, many of them with large foreground objects. In instance segmentation there is usually more than one category (a, b, c, ...); IoU is computed per category and then averaged to obtain mIoU. (Note: what U2-Net does here is not instance segmentation!)
  • Introduction
  • 1. Data Annotation
  • 2. Model Evaluation
  • 3. Source Code Walkthrough
  • 4. LibTorch Deployment
  • 5. Performance Analysis
  • 6. Troubleshooting

Introduction

U2-Net comes in two variants:

  • U2NET --- 173.6 MB (about 40 million parameters)
  • U2NETP --- 4.7 MB (about 1 million parameters)
(For scale: YOLOv5s has about 7 million parameters, VGG-16 about 138 million, ResNet-152 about 60 million.)

Project: https://github.com/xuebinqin/U-2-Net

1. Human segmentation model: u2net_human_seg.pth. Download it into ./saved_models/u2net_human_seg/ (create the folder if it does not exist).
2. Copy your test images into ./test_data/test_human_images/.
3. Run python u2net_human_seg_test.py; the result images are saved automatically to ./test_data/u2net_test_human_images_results/.
(Note: this model was trained with a few improvements over stock U2Net, such as extra data augmentation. Its training annotations are not especially precise, yet on people it still beats the official U2Net trained on DUTS-TR. As a general-purpose human detection/segmentation model it works remarkably well. It was pretrained on the Supervisely Person Dataset: 5711 images with 6884 high-quality annotated person instances.)
  People have put U2-Net to many creative uses: portrait drawing [1], sketching, background removal and so on. See the GitHub page for the rest.
Here we only discuss semantic segmentation, not instance segmentation.


1. Data Annotation

Annotate with labelImg; the annotations come out as JSON files, and you have to implement the JSON -> mask-image conversion yourself.
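A minimal sketch of that conversion, assuming labelme-style JSON where each shape carries a polygon point list (the key names imageHeight/imageWidth/shapes/points are assumptions; adapt them to your actual annotation format):

import json
import cv2
import numpy as np

def json_to_mask(json_path, out_path):
    """Rasterize polygon annotations into a binary mask PNG."""
    with open(json_path, 'r', encoding='utf-8') as f:
        ann = json.load(f)
    h, w = ann['imageHeight'], ann['imageWidth']   # assumed keys
    mask = np.zeros((h, w), dtype=np.uint8)
    for shape in ann['shapes']:                    # assumed key
        pts = np.array(shape['points'], dtype=np.int32)
        cv2.fillPoly(mask, [pts], 255)             # foreground = 255
    cv2.imwrite(out_path, mask)

json_to_mask('img_0001.json', 'img_0001.png')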

  U2-Net is evaluated on several datasets:

Training set: the network is trained on DUTS-TR, a subset of the DUTS dataset containing 10553 images in total. It is currently the largest commonly used training dataset for salient object detection. Before training, the dataset is augmented by horizontal flipping, giving 21106 images.

Evaluation sets: six commonly used benchmark datasets are used to test the model: DUT-OMRON, DUTS-TE, HKU-IS, ECSSD, PASCAL-S and SOD.

  DUT-OMRON: 5168 images, most of which contain one or more foreground objects.

  DUTS: consists of two parts, DUTS-TR (training set) and DUTS-TE (test set). DUTS-TE has 5019 images and is used for testing.

  HKU-IS: 4447 images, many of them with multiple ground-truth foreground objects. ECSSD: 1000 structurally complex images, many with large foreground objects.

  PASCAL-S: 850 images with complex foreground objects and cluttered backgrounds.

  SOD: only 300 images, but a very challenging set. It was originally designed for image segmentation; many of its images have low contrast or contain complex foreground objects overlapping the image boundary.

2. Model Evaluation (see the original paper)

2.1 Loss Function

  Let us first look at how the loss is computed for semantic segmentation: it is simply a per-pixel cross entropy (binary here, since each pixel is either foreground or background; multi-class segmentation uses the multi-class variant).

[Figure 1: per-pixel weighted binary cross entropy]
loss(p, y) = -[ pos_weight · y · log(p) + (1 - y) · log(1 - p) ], averaged over all pixels

    In the formula above, the weight term pos_weight balances the positive/negative sample imbalance; this came up in the YOLOv1 objective already, so no more on it here. Below we only discuss semantic segmentation, not instance segmentation.
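In PyTorch this weighting is built into nn.BCEWithLogitsLoss; a minimal usage sketch (the weight value 3.0 is made up):

import torch
import torch.nn as nn

# pos_weight > 1 boosts the loss on the (rarer) foreground pixels
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([3.0]))

logits = torch.randn(1, 1, 288, 288)                      # raw network output (no sigmoid)
target = torch.randint(0, 2, (1, 1, 288, 288)).float()    # binary mask
loss = criterion(logits, target)
print(loss.item())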

  The loss is computed pixel by pixel as a binary classification. For a given region, pixels outside the boundary (background) and inside it (foreground) are usually easy to classify; the hard pixels are the ones right on the boundary. How do we deal with that? This is what focal-loss-style losses, shown below, address.

  In the formula below, γ (gamma) is usually set to 2. For example, when a positive pixel is predicted with probability 0.95, the modulating factor (1 - p)^γ scales its contribution down to 0.0025; the intent is that such easy-to-classify pixels should not dominate training. For a pixel predicted at 0.5, the same factor gives 0.25, far larger than 0.0025, so the network pays much more attention to these hard, ambiguous pixels.

  In the last formula of the figure below, α = (number of negative samples) / (number of positive samples).

[Figure 2: focal loss]
FL(p_t) = -α · (1 - p_t)^γ · log(p_t)
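A minimal PyTorch sketch of a binary focal loss along these lines (note that U2-Net itself trains with plain BCE, see Section 3.2; the alpha/gamma defaults here are the commonly used ones, not values from this post):

import torch

def focal_loss(pred, target, alpha=0.25, gamma=2.0, eps=1e-8):
    """Binary focal loss; pred holds probabilities in [0, 1]."""
    pred = pred.clamp(eps, 1.0 - eps)
    # p_t: probability the model assigns to the true class of each pixel
    p_t = torch.where(target > 0.5, pred, 1.0 - pred)
    alpha_t = torch.where(target > 0.5, torch.full_like(pred, alpha),
                          torch.full_like(pred, 1.0 - alpha))
    return (-alpha_t * (1.0 - p_t) ** gamma * torch.log(p_t)).mean()

pred = torch.rand(1, 1, 288, 288)
target = (torch.rand(1, 1, 288, 288) > 0.5).float()
print(focal_loss(pred, target).item())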

2.2 Evaluation Metrics

IoU: in the figure on the right, the Y axis is the ground-truth class and the X axis the predicted class; the numbers in the matrix are pixel counts per (true, predicted) pair. For example, the green box marks all pixels of the annotated region ROI1 (call it true_dog) and the yellow box all pixels of the predicted region ROI2 (call it predict_dog), so iou_dog is computed as:

[Figure 3: confusion matrix]
iou_dog = |true_dog ∩ predict_dog| / |true_dog ∪ predict_dog|

   As shown below, the left image is the annotated person region and the right one is the model's predicted region. [Figure 4: ground-truth mask vs. predicted mask]

   The figure below shows the intersection and union of the two masks above.

[Figure 5: intersection and union of the two masks]

   Generally, in instance segmentation with multiple categories a, b, c, ..., IoU is computed per category and then averaged to obtain mIoU.
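A minimal sketch of that computation over integer label masks (with U2-Net's binary masks there is just a background class and one foreground class):

import numpy as np

def miou(gt, pred, num_classes):
    """Mean IoU over all classes present in ground truth or prediction."""
    ious = []
    for c in range(num_classes):
        inter = np.logical_and(gt == c, pred == c).sum()
        union = np.logical_or(gt == c, pred == c).sum()
        if union > 0:
            ious.append(inter / union)
    return float(np.mean(ious))

gt = np.random.randint(0, 2, (320, 320))
pred = np.random.randint(0, 2, (320, 320))
print(miou(gt, pred, num_classes=2))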

  U2-Net (note: again, this is not instance segmentation!) reports the following evaluation metrics:

PR curve: by comparing the predicted mask against the ground-truth mask at varying thresholds, compute precision (TP/(TP + FP)) and recall (TP/(TP + FN)).

MAE: Mean Absolute Error between the predicted map and the ground truth.

[Figure 6: MAE formula] MAE = (1/(H·W)) · Σ |P(r, c) - G(r, c)|, summed over all pixels. There are a few more metrics (F-measure and friends) that I won't go into here; see the paper.
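A minimal sketch of MAE plus thresholded precision/recall for one image (pred is a probability map, gt a binary mask; the 0.5 threshold is arbitrary, the PR curve sweeps it):

import numpy as np

def mae(pred, gt):
    """Mean absolute error between prediction and ground truth, both in [0, 1]."""
    return float(np.abs(pred - gt).mean())

def precision_recall(pred, gt, thresh=0.5):
    """Precision/recall of the thresholded prediction mask."""
    p = pred >= thresh
    g = gt >= 0.5
    tp = np.logical_and(p, g).sum()
    fp = np.logical_and(p, ~g).sum()
    fn = np.logical_and(~p, g).sum()
    return tp / max(tp + fp, 1), tp / max(tp + fn, 1)

pred = np.random.rand(320, 320)
gt = (np.random.rand(320, 320) > 0.5).astype(np.float32)
print(mae(pred, gt), precision_recall(pred, gt))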


3. Source Code Walkthrough

3.0 Environment: pytorch 1.7.1 + CUDA 11.0 (the same environment as in the YOLOv5 v4 posts, reused as-is)

3.1 Preparation

Download the source code: git clone https://github.com/NathanUA/U-2-Net.git

Download the pretrained models: u2net.pth (176.3 MB) or u2netp.pth (4.7 MB), and put them under './saved_models/u2net/' and './saved_models/u2netp/' respectively (create the folders if they do not exist).

Training and testing: python u2net_train.py or python u2net_test.py

3.2 Training Code Walkthrough

u2net_train.py (if you hit errors, see Section 6, Troubleshooting; the fixes below are already applied):

import os

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'  # workaround for "OMP: Error"
import torch
from torch.autograd import Variable
import torch.nn as nn

from torch.utils.data import DataLoader
from torchvision import transforms
import torch.optim as optim

import glob

from data_loader import RescaleT
from data_loader import RandomCrop
from data_loader import ToTensorLab
from data_loader import SalObjDataset

from model import U2NET
from model import U2NETP

# ------- 1. define loss function --------

bce_loss = nn.BCELoss(size_average=True)  # size_average is deprecated; reduction='mean' on newer PyTorch

# loss1-6: losses of the six side-output maps (upsampled from intermediate layers)
# loss0: loss of the final fused map
def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
    loss0 = bce_loss(d0, labels_v)
    loss1 = bce_loss(d1, labels_v)
    loss2 = bce_loss(d2, labels_v)
    loss3 = bce_loss(d3, labels_v)
    loss4 = bce_loss(d4, labels_v)
    loss5 = bce_loss(d5, labels_v)
    loss6 = bce_loss(d6, labels_v)

    loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
    print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n" % (
        loss0.data.item(), loss1.data.item(), loss2.data.item(), loss3.data.item(), loss4.data.item(),
        loss5.data.item(), loss6.data.item()))

    return loss0, loss


# ------- 2. set the directory of training dataset --------

model_name = 'u2net'  # 'u2netp'

data_dir = os.path.join(os.getcwd(), 'train_data' + os.sep)
# tra_image_dir = os.path.join('DUTS', 'DUTS-TR', 'DUTS-TR', 'im_aug' + os.sep)
# tra_label_dir = os.path.join('DUTS', 'DUTS-TR', 'DUTS-TR', 'gt_aug' + os.sep)

tra_image_dir = os.path.join('APDrawingGAN_test', 'im' + os.sep)
tra_label_dir = os.path.join('APDrawingGAN_test', 'gt' + os.sep)

image_ext = '.jpg'
label_ext = '.png'

model_dir = os.path.join(os.getcwd(), 'saved_models', model_name + os.sep)

# epoch_num = 100000
# batch_size_train = 12  # error: RuntimeError: CUDA out of memory.
epoch_num = 4000
batch_size_train = 4  # 8 GB of VRAM is barely enough
batch_size_val = 1
train_num = 0
val_num = 0

# note: images are globbed with label_ext (.png) here; use image_ext if your images are .jpg
tra_img_name_list = glob.glob(data_dir + tra_image_dir + '*' + label_ext)

tra_lbl_name_list = []
for img_path in tra_img_name_list:
    img_name = img_path.split(os.sep)[-1]

    aaa = img_name.split(".")
    bbb = aaa[0:-1]
    imidx = bbb[0]
    for i in range(1, len(bbb)):
        imidx = imidx + "." + bbb[i]

    tra_lbl_name_list.append(data_dir + tra_label_dir + imidx + label_ext)

print("---")
print("train images: ", len(tra_img_name_list))
print("train labels: ", len(tra_lbl_name_list))
print("---")

train_num = len(tra_img_name_list)

# data preprocessing
salobj_dataset = SalObjDataset(
    img_name_list=tra_img_name_list,
    lbl_name_list=tra_lbl_name_list,
    transform=transforms.Compose([
        RescaleT(320),    # rescale the input image to 320*320
        RandomCrop(288),  # random-crop a 288*288 patch from the 320*320 image
        ToTensorLab(flag=0)]))
# dataloader
salobj_dataloader = DataLoader(salobj_dataset, batch_size=batch_size_train, shuffle=True, num_workers=1)

# ------- 3. define model --------
# define the net
if (model_name == 'u2net'):
    net = U2NET(3, 1)
elif (model_name == 'u2netp'):
    net = U2NETP(3, 1)

if torch.cuda.is_available():
    net.cuda()

# ------- 4. define optimizer --------
print("---define optimizer...")
# keep the learning rate small; beta1 = 0.9 (momentum term), beta2 = 0.999 (RMSprop term), eps = 1e-8, weight decay = 0
optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

# ------- 5. training process --------
print("---start training...")
ite_num = 0
running_loss = 0.0
running_tar_loss = 0.0
ite_num4val = 0
save_frq = 2000  # save the model every 2000 iterations

if __name__ == '__main__':  # guard needed on Windows, otherwise the multiprocessing "freeze_support()" error is raised
    for epoch in range(0, epoch_num):
        net.train()

        for i, data in enumerate(salobj_dataloader):
            ite_num = ite_num + 1
            ite_num4val = ite_num4val + 1

            inputs, labels = data['image'], data['label']

            inputs = inputs.type(torch.FloatTensor)
            labels = labels.type(torch.FloatTensor)

            # wrap them in Variable
            if torch.cuda.is_available():
                inputs_v, labels_v = Variable(inputs.cuda(), requires_grad=False), Variable(labels.cuda(), requires_grad=False)
            else:
                inputs_v, labels_v = Variable(inputs, requires_grad=False), Variable(labels, requires_grad=False)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            d0, d1, d2, d3, d4, d5, d6 = net(inputs_v)
            # note: all 7 masks are compared against the same label map with cross entropy
            # loss2: loss of the final fused mask
            # loss: sum of the losses of all 7 output masks
            loss2, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v)

            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.data.item()
            running_tar_loss += loss2.data.item()

            # del temporary outputs and loss
            del d0, d1, d2, d3, d4, d5, d6, loss2, loss

            print("[epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f " % (
                epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val,
                running_tar_loss / ite_num4val))

            if ite_num % save_frq == 0:
                torch.save(net.state_dict(), model_dir + model_name + "_bce_itr_%d_train_%3f_tar_%3f.pth" % (
                    ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val))
                running_loss = 0.0
                running_tar_loss = 0.0
                net.train()  # resume training
                ite_num4val = 0

3.3 Test Code Walkthrough

u2net_test.py 

import os
from skimage import io, transform
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms  # , utils
# import torch.optim as optim

import numpy as np
from PIL import Image
import glob

from data_loader import RescaleT
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset

from model import U2NET  # full size version 173.6 MB
from model import U2NETP  # small version u2net 4.7 MB


# normalize the predicted SOD probability map
def normPRED(d):
    ma = torch.max(d)
    mi = torch.min(d)

    dn = (d - mi) / (ma - mi)

    return dn


def save_output(image_name, pred, d_dir):
    predict = pred
    predict = predict.squeeze()
    predict_np = predict.cpu().data.numpy()

    im = Image.fromarray(predict_np * 255).convert('RGB')
    img_name = image_name.split(os.sep)[-1]
    image = io.imread(image_name)
    imo = im.resize((image.shape[1], image.shape[0]), resample=Image.BILINEAR)

    pb_np = np.array(imo)

    aaa = img_name.split(".")
    bbb = aaa[0:-1]
    imidx = bbb[0]
    for i in range(1, len(bbb)):
        imidx = imidx + "." + bbb[i]

    imo.save(d_dir + imidx + '.png')


def main():
    # --------- 1. get image path and name ---------
    model_name = 'u2net'  # u2netp

    image_dir = os.path.join(os.getcwd(), 'test_data', 'test_images')
    prediction_dir = os.path.join(os.getcwd(), 'test_data', model_name + '_results' + os.sep)
    model_dir = os.path.join(os.getcwd(), 'saved_models', model_name, model_name + '.pth')

    img_name_list = glob.glob(image_dir + os.sep + '*')
    print(img_name_list)

    # --------- 2. dataloader ---------
    test_salobj_dataset = SalObjDataset(img_name_list=img_name_list,
                                        lbl_name_list=[],
                                        transform=transforms.Compose([RescaleT(320),  # rescale to 320*320
                                                                      ToTensorLab(flag=0)])
                                        )
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)
    # --------- 3. model define ---------
    if (model_name == 'u2net'):
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif (model_name == 'u2netp'):
        print("...load U2NETP---4.7 MB")
        net = U2NETP(3, 1)

    if torch.cuda.is_available():
        net.load_state_dict(torch.load(model_dir))
        net.cuda()
    else:
        net.load_state_dict(torch.load(model_dir, map_location='cpu'))
    net.eval()

    # count the number of parameters (by shiruiyu)
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print("numbers of parameters: ", num_params / 1e6, "million")

    # --------- 4. inference for each image ---------
    for i_test, data_test in enumerate(test_salobj_dataloader):

        print("inferencing:", img_name_list[i_test].split(os.sep)[-1])

        inputs_test = data_test['image']
        inputs_test = inputs_test.type(torch.FloatTensor)

        if torch.cuda.is_available():
            inputs_test = Variable(inputs_test.cuda())
        else:
            inputs_test = Variable(inputs_test)

        d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

        # normalization
        pred = d1[:, 0, :, :]  # inference only needs the final fused map
        pred = normPRED(pred)

        # save results to test_results folder
        if not os.path.exists(prediction_dir):
            os.makedirs(prediction_dir, exist_ok=True)
        save_output(img_name_list[i_test], pred, prediction_dir)

        del d1, d2, d3, d4, d5, d6, d7


if __name__ == "__main__":
    main()

3.4 Network Architecture Walkthrough

Keep Figure 4 (the architecture diagram) at hand while reading u2net.py.

First, an overview of what the file defines (the original screenshots listed the classes):

REBNCONV, _upsample_like, RSU7, RSU6, RSU5, RSU4, RSU4F, U2NET, U2NETP

Read the code side by side with the figure; the comments below are quite detailed.

import torch
import torch.nn as nn
import torch.nn.functional as F

# note: the latest U2Net code interpolates the input directly to 320*320; no crop afterwards
# below, in_ch / mid_ch / out_ch are the input, intermediate and output channel counts
# CBR block: conv + BN + ReLU (possibly with dilated convolution)
class REBNCONV(nn.Module):
    def __init__(self, in_ch=3, out_ch=3, dirate=1):
        super(REBNCONV, self).__init__()
        # dilation: dilated-convolution rate
        self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        hx = x
        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))

        return xout


# upsampling: the channel count stays the same, only W and H are rescaled
# upsample tensor 'src' to have the same spatial size with tensor 'tar'
def _upsample_like(src, tar):
    src = F.upsample(src, size=tar.shape[2:], mode='bilinear')

    return src


# Figure 4, stage 1
### RSU-7 ###
class RSU7(nn.Module):  # UNet07DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU7, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x
        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)

        hx6 = self.rebnconv6(hx)
        # hx7: the right-most, smallest blue block of stage 1 in Figure 4
        hx7 = self.rebnconv7(hx6)
        # the cat operations below correspond to the "+" symbols of stage 1 in Figure 4
        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
        hx6dup = _upsample_like(hx6d, hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)
        # hx1d: the right-most purple block of stage 1 in Figure 4
        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


# Figure 4, stage 2
### RSU-6 ###
class RSU6(nn.Module):  # UNet06DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU6, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)

        hx6 = self.rebnconv6(hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


# Figure 4, stage 3
### RSU-5 ###
class RSU5(nn.Module):  # UNet05DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)

        hx5 = self.rebnconv5(hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


# Figure 4, stage 4
### RSU-4 ###
class RSU4(nn.Module):  # UNet04DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


# Figure 4, stages 5 and 6 (no pooling; dilated convolutions instead)
### RSU-4F ###
class RSU4F(nn.Module):  # UNet04FRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
        hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))

        return hx1d + hxin


# the large model: ~40 million parameters (compared with the small model below,
# its width, i.e. the number of kernels per layer, grows 2x / 4x / 8x with depth;
# no wonder the checkpoint file is so large)
##### U^2-Net ####
class U2NET(nn.Module):
    def __init__(self, in_ch=3, out_ch=1):
        super(U2NET, self).__init__()

        self.stage1 = RSU7(in_ch, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # decoder
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

        self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)

    def forward(self, x):
        hx = x  # torch.Size([1, 3, 320, 320]); note: during training the input is 1*3*288*288, the flow is identical
        # print('hx.shape = ', hx.shape)

        # stage 1 (En_1)
        hx1 = self.stage1(hx)  # torch.Size([1, 64, 320, 320])
        hx = self.pool12(hx1)  # torch.Size([1, 64, 160, 160])

        # stage 2 (En_2)
        hx2 = self.stage2(hx)  # torch.Size([1, 128, 160, 160])
        hx = self.pool23(hx2)  # torch.Size([1, 128, 80, 80])

        # stage 3 (En_3)
        hx3 = self.stage3(hx)  # torch.Size([1, 256, 80, 80])
        hx = self.pool34(hx3)  # torch.Size([1, 256, 40, 40])

        # stage 4 (En_4)
        hx4 = self.stage4(hx)  # torch.Size([1, 512, 40, 40])
        hx = self.pool45(hx4)  # torch.Size([1, 512, 20, 20])

        # stage 5 (En_5)
        hx5 = self.stage5(hx)  # torch.Size([1, 512, 20, 20])
        hx = self.pool56(hx5)  # torch.Size([1, 512, 10, 10])

        # stage 6 (En_6)
        hx6 = self.stage6(hx)  # torch.Size([1, 512, 10, 10])
        hx6up = _upsample_like(hx6, hx5)  # torch.Size([1, 512, 20, 20])

        # -------------------- decoder --------------------
        # De_5
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))  # torch.Size([1, 512, 20, 20])
        hx5dup = _upsample_like(hx5d, hx4)  # torch.Size([1, 512, 40, 40])
        # De_4
        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))  # torch.Size([1, 256, 40, 40])
        hx4dup = _upsample_like(hx4d, hx3)  # torch.Size([1, 256, 80, 80])
        # De_3
        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))  # torch.Size([1, 128, 80, 80])
        hx3dup = _upsample_like(hx3d, hx2)  # torch.Size([1, 128, 160, 160])
        # De_2
        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))  # torch.Size([1, 64, 160, 160])
        hx2dup = _upsample_like(hx2d, hx1)  # torch.Size([1, 64, 320, 320])
        # De_1
        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))  # torch.Size([1, 64, 320, 320])

        # side output
        # no upsampling needed (already full resolution)
        d1 = self.side1(hx1d)  # torch.Size([1, 1, 320, 320])
        # 2x upsampling
        d2 = self.side2(hx2d)  # torch.Size([1, 1, 160, 160])
        d2 = _upsample_like(d2, d1)  # torch.Size([1, 1, 320, 320])
        # 4x upsampling
        d3 = self.side3(hx3d)  # torch.Size([1, 1, 80, 80])
        d3 = _upsample_like(d3, d1)  # torch.Size([1, 1, 320, 320])
        # 8x upsampling
        d4 = self.side4(hx4d)  # torch.Size([1, 1, 40, 40])
        d4 = _upsample_like(d4, d1)  # torch.Size([1, 1, 320, 320])
        # 16x upsampling
        d5 = self.side5(hx5d)  # torch.Size([1, 1, 20, 20])
        d5 = _upsample_like(d5, d1)  # torch.Size([1, 1, 320, 320])
        # 32x upsampling
        d6 = self.side6(hx6)  # torch.Size([1, 1, 10, 10])
        d6 = _upsample_like(d6, d1)  # torch.Size([1, 1, 320, 320])
        # concat + 1x1 convolution
        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))  # torch.Size([1, 1, 320, 320])
        # F.sigmoid is deprecated; torch.sigmoid() works too
        return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)


# the small model: ~1 million parameters
### U^2-Net small ###
class U2NETP(nn.Module):

    def __init__(self, in_ch=3, out_ch=1):
        super(U2NETP, self).__init__()

        self.stage1 = RSU7(in_ch, 16, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 16, 64)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(64, 16, 64)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(64, 16, 64)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(64, 16, 64)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(64, 16, 64)

        # decoder
        self.stage5d = RSU4F(128, 16, 64)
        self.stage4d = RSU4(128, 16, 64)
        self.stage3d = RSU5(128, 16, 64)
        self.stage2d = RSU6(128, 16, 64)
        self.stage1d = RSU7(128, 16, 64)

        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)

        self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)

    def forward(self, x):
        hx = x

        # stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # decoder
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # side output
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, d1)

        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))

        return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)
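As a quick sanity check of the shape annotations above (a minimal sketch; run it from the repo root so that model can be imported):

import torch
from model import U2NET

net = U2NET(3, 1).eval()
with torch.no_grad():
    d0, d1, d2, d3, d4, d5, d6 = net(torch.rand(1, 3, 320, 320))
print(d0.shape)  # torch.Size([1, 1, 320, 320]); all seven outputs share this shape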

4. LibTorch Deployment

Python script for exporting the model:

export_u2net.py

(Only the CPU export is shown here. In LibTorch this CPU-traced model works for both CPU and GPU, since the model and the input tensors can both be moved onto the GPU.)

import os
import torch
from model import U2NET  # full size version 173.6 MB


def main():
    model_name = 'u2net'
    model_dir = os.path.join(os.getcwd(), 'saved_models', model_name + '_human_seg', model_name + '_human_seg.pth')

    if model_name == 'u2net':
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)

    net.load_state_dict(torch.load(model_dir, map_location=torch.device('cpu')))
    net.eval()

    # --------- serialize the model ---------
    # example = torch.zeros(1, 3, 512, 512).to(device='cuda')
    example = torch.zeros(1, 3, 512, 512)
    torch_script_module = torch.jit.trace(net, example)
    torch_script_module.save('human2-cpu.pt')
    print('over')


if __name__ == "__main__":
    main()
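Before moving to C++, it is worth loading the traced module back in Python to verify it (a minimal sketch):

import torch

m = torch.jit.load('human2-cpu.pt')
outs = m(torch.rand(1, 3, 512, 512))
print(len(outs), outs[0].shape)  # 7 outputs, each [1, 1, 512, 512]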

Deployment code:

Configuration file, Config.yaml:

%YAML:1.0
# notes: 1. when editing a file name, keep the quotes ""; plain variables do not need them
#        2. image resolution >
#        3. comments in this file must be on their own line
#        4. all local data the project reads/writes lives under dir: "D://Data//" by default

# data directory
dir: "D://Data//"

# source image
srcImgFile: "img_1589.png"


# ****************************************************************** deep learning ***********************************************************************
# style-transfer model file name (reused here for the U2-Net model)
styleModelFile: "D://U-2-Net-master//human1-gpu.pt"

Configuration class: Config.h and Config.cpp

#ifndef CONFIG_H
#define CONFIG_H

#include<opencv2/opencv.hpp>
#include<iostream>

class Config
{
public:
    Config(const std::string& yamlFile);
    ~Config();

    template<typename T>
    T get(const std::string& key)
    {
        return T(this->m_fileStorage[key]);
    }

private:
    std::string m_yamlFile;
    cv::FileStorage m_fileStorage;
};

#endif // !CONFIG_H
#include "Config.h"

Config::Config(const std::string& yamlFile):
    m_yamlFile(yamlFile)
{
    this->m_fileStorage.open(this->m_yamlFile, cv::FileStorage::READ);
    if (!this->m_fileStorage.isOpened())
    {
        std::cerr << "failed to open " << this->m_yamlFile << "!" << std::endl;
        system("pause");
    }
}

Config::~Config()
{
}

Human semantic segmentation: U2Net_Human.cpp. I hit errors again here; see "A summary of common libtorch errors on Windows": https://i.cnblogs.com/posts/edit-done;postId=14687275

#include<opencv2/opencv.hpp>
#include<torch/torch.h>
#include<torch/script.h>
#include"Config.h"

torch::Tensor normPRED(torch::Tensor d)
{
    at::Tensor ma, mi;
    torch::Tensor dn;
    ma = torch::max(d);
    mi = torch::min(d);
    dn = (d - mi) / (ma - mi);
    return dn;
}

void bgr_u2net(cv::Mat& image_src, cv::Mat& result, torch::jit::Module& model)
{
    auto device = torch::Device("cuda");
    //    auto image_bgr = cv::imread("bg11.png");
    //    auto xt = cv::imread("xt2.jpg");
    cv::Mat image_src1 = image_src.clone();
    cv::resize(image_src, image_src, cv::Size(320, 320));
    cv::cvtColor(image_src, image_src, cv::COLOR_RGB2BGR);
    //    cv::cvtColor(image_src, image_src, cv::COLOR_BGR2RGB);

    torch::Tensor tensor_image_src = torch::from_blob(image_src.data, { image_src.rows, image_src.cols, 3 }, torch::kByte);
    //    torch::Tensor tensor_image_bgr = torch::from_blob(image_bgr.data, { image_bgr.rows, image_bgr.cols, 3 }, torch::kByte);
    torch::Tensor tensor_bgr = torch::from_blob(image_src1.data, { image_src1.rows, image_src1.cols, 3 }, torch::kByte);
    tensor_image_src = tensor_image_src.permute({ 2,0,1 });  // HWC -> CHW
    tensor_image_src = tensor_image_src.toType(torch::kFloat);
    tensor_image_src = tensor_image_src.div(255);
    tensor_image_src = tensor_image_src.unsqueeze(0);  // add a batch dimension
    //    tensor_image_bgr = tensor_image_bgr.permute({ 2,0,1 });
    //    tensor_image_bgr = tensor_image_bgr.toType(torch::kFloat);
    //    tensor_image_bgr = tensor_image_bgr.div(255);
    //    tensor_image_bgr = tensor_image_bgr.unsqueeze(0);
    tensor_bgr = tensor_bgr.permute({ 2,0,1 });
    tensor_bgr = tensor_bgr.toType(torch::kFloat);
    tensor_bgr = tensor_bgr.div(255);
    tensor_bgr = tensor_bgr.unsqueeze(0);

    auto src = tensor_image_src.to(device);
    //    auto bgr = tensor_image_bgr.to(device);
    auto src_copy = tensor_bgr.to(device);

    auto outputs = model.forward({ src }).toTuple()->elements();

    auto pred = outputs[0].toTensor();  // d0: the fused output map

    //    pha = normPRED_(pha);
    //    auto fgr = outputs[1].toTensor();
    //    auto res_tensor = (pred * src + (1 - pred) * torch::ones_like(src));
    //    double endtime = (double)(end - start) / CLOCKS_PER_SEC;
    //    std::cout << "time:" << endtime << std::endl;
    //    auto res_tensor = (pred * src + (1 - pred) * torch::tensor({ 120 / 255, 255 / 255, 155 / 255 }).to(device).view({ 1,3,1,1 }));
    auto res_tensor = (pred * torch::ones_like(src));
    res_tensor = normPRED(res_tensor);
    res_tensor = res_tensor.squeeze(0).detach();
    res_tensor = res_tensor.mul(255).clamp(0, 255).to(torch::kU8);
    res_tensor = res_tensor.to(torch::kCPU);
    //    cv::Mat result(image_bgr.rows, image_bgr.cols, CV_32FC3, fgr.data_ptr());
    cv::Mat resultImg(res_tensor.size(1), res_tensor.size(2), CV_8UC3);
    std::memcpy((void*)resultImg.data, res_tensor.data_ptr(), sizeof(torch::kU8) * res_tensor.numel());  // 1 byte per element
    //    result = resultImg.clone();
    //    cv::cvtColor(result, result, cv::COLOR_BGR2RGB);

    cv::resize(resultImg, resultImg, cv::Size(image_src1.cols, image_src1.rows), 0, 0, cv::INTER_LINEAR);
    //    cv::Mat element = getStructuringElement(cv::MORPH_RECT, cv::Size(15, 15));
    //    cv::dilate(resultImg, resultImg, element);
    //    cv::threshold(resultImg, resultImg, 130, 255, cv::THRESH_BINARY);
    //    cv::imwrite("pha.jpg", resultImg);
    torch::Tensor tensor_result = torch::from_blob(resultImg.data, { resultImg.rows, resultImg.cols, 3 }, torch::kByte);
    tensor_result = tensor_result.permute({ 2,0,1 });
    tensor_result = tensor_result.toType(torch::kFloat);
    tensor_result = tensor_result.div(255);
    tensor_result = tensor_result.unsqueeze(0);
    //    torch::Tensor c = (tensor_result > 220 / 255);
    //    res_tensor = (c * tensor_bgr - c * torch::ones_like(tensor_bgr) + torch::ones_like(tensor_bgr));
    // blend: mask * original image + (1 - mask) * white background
    res_tensor = (tensor_result * tensor_bgr + (1 - tensor_result) * torch::ones_like(tensor_bgr));
    //    res_tensor = (tensor_result * tensor_bgr + (1 - tensor_result) * tensor_image_bgr);
    res_tensor = res_tensor.squeeze(0).detach();
    res_tensor = res_tensor.mul(255).clamp(0, 255).to(torch::kU8);
    res_tensor = res_tensor.to(torch::kCPU);
    cv::Mat resultImg1(res_tensor.size(1), res_tensor.size(2), CV_8UC3);
    std::memcpy((void*)resultImg1.data, res_tensor.data_ptr(), sizeof(torch::kU8) * res_tensor.numel());
    result = resultImg1.clone();
}

int main()
{
    // load srcImg
    Config cfg("Config.yaml");
    cv::Mat srcImg = cv::imread(cfg.get<std::string>("srcImgFile"), -1);
    cv::Mat srcImg_;
    cv::resize(srcImg, srcImg_, cv::Size(512, 512));

    std::string str = cfg.get<std::string>("styleModelFile");

    // load the CPU-traced model; moved to GPU below if one is available
    torch::jit::script::Module styleModule;
    auto device_type = at::kCPU;
    if (torch::cuda::is_available()) {
        std::cout << "gpu" << std::endl;
        device_type = at::kCUDA;
    }
    try
    {
        styleModule = torch::jit::load(str);
        styleModule.to(device_type);
    }
    catch (const c10::Error& e)
    {
        std::cerr << "error code: -2, error loading the model\n";
        return -1;
    }
    cv::Mat dstImg;
    bgr_u2net(srcImg_, dstImg, styleModule);

    cv::imshow("dstImg", dstImg);
    cv::waitKey(0);

    return 1;
}

An updated U2Net_Human.cpp; my libtorch is clearly still not fluent.

#include<opencv2/opencv.hpp>
#include<torch/torch.h>
#include<torch/script.h>
#include"Config.h"

torch::Tensor normPRED(torch::Tensor d)
{
    at::Tensor ma, mi;
    torch::Tensor dn;
    ma = torch::max(d);
    mi = torch::min(d);
    dn = (d - mi) / (ma - mi);
    return dn;
}

void bgr_u2net(cv::Mat& image_src, cv::Mat& result, torch::jit::Module& model)
{
    auto device = torch::Device("cuda");

    cv::Mat image_src1 = image_src.clone();
    cv::resize(image_src, image_src, cv::Size(320, 320));
    //cv::cvtColor(image_src, image_src, cv::COLOR_RGB2BGR);
    cv::cvtColor(image_src, image_src, cv::COLOR_BGR2RGB);

    torch::Tensor tensor_image_src = torch::from_blob(image_src.data, { image_src.rows, image_src.cols, 3 }, torch::kByte);
    //    torch::Tensor tensor_image_bgr = torch::from_blob(image_bgr.data, { image_bgr.rows, image_bgr.cols, 3 }, torch::kByte);
    torch::Tensor tensor_bgr = torch::from_blob(image_src1.data, { image_src1.rows, image_src1.cols, 3 }, torch::kByte);
    tensor_image_src = tensor_image_src.permute({ 2,0,1 });  // HWC -> CHW
    tensor_image_src = tensor_image_src.toType(torch::kFloat);
    tensor_image_src = tensor_image_src.div(255);
    // [3, 320, 320]
    tensor_image_src = tensor_image_src.unsqueeze(0);  // add a batch dimension
    // [1, 3, 320, 320]
    std::cout << tensor_image_src.sizes() << std::endl;

    tensor_bgr = tensor_bgr.permute({ 2,0,1 });
    tensor_bgr = tensor_bgr.toType(torch::kFloat);
    tensor_bgr = tensor_bgr.div(255);
    tensor_bgr = tensor_bgr.unsqueeze(0);

    auto src = tensor_image_src.to(device);
    //    auto bgr = tensor_image_bgr.to(device);
    //auto src_copy = tensor_bgr.to(device);

    auto outputs = model.forward({ src }).toTuple()->elements();

    auto pred = outputs[0].toTensor();  // d0: the fused output map

    auto res_tensor = (pred * torch::ones_like(src));

    std::cout << torch::ones_like(src).sizes() << std::endl;
    std::cout << src.sizes() << std::endl;

    res_tensor = normPRED(res_tensor);
    res_tensor = res_tensor.squeeze(0).detach();
    res_tensor = res_tensor.mul(255).clamp(0, 255).to(torch::kU8);
    res_tensor = res_tensor.to(torch::kCPU);
    //    cv::Mat result(image_bgr.rows, image_bgr.cols, CV_32FC3, fgr.data_ptr());
    cv::Mat resultImg(res_tensor.size(1), res_tensor.size(2), CV_8UC3);
    std::memcpy((void*)resultImg.data, res_tensor.data_ptr(), sizeof(torch::kU8) * res_tensor.numel());  // 1 byte per element
    //    result = resultImg.clone();
    //    cv::cvtColor(result, result, cv::COLOR_BGR2RGB);

    cv::resize(resultImg, resultImg, cv::Size(image_src1.cols, image_src1.rows), 0, 0, cv::INTER_LINEAR);
    //    cv::Mat element = getStructuringElement(cv::MORPH_RECT, cv::Size(15, 15));
    //    cv::dilate(resultImg, resultImg, element);
    //    cv::threshold(resultImg, resultImg, 130, 255, cv::THRESH_BINARY);
    //    cv::imwrite("pha.jpg", resultImg);
    torch::Tensor tensor_result = torch::from_blob(resultImg.data, { resultImg.rows, resultImg.cols, 3 }, torch::kByte);
    tensor_result = tensor_result.permute({ 2,0,1 });
    tensor_result = tensor_result.toType(torch::kFloat);
    tensor_result = tensor_result.div(255);
    tensor_result = tensor_result.unsqueeze(0);
    //    torch::Tensor c = (tensor_result > 220 / 255);
    //    tensor_result[tensor_result >= 200 / 255] = 1;
    //    res_tensor = (c * tensor_bgr - c * torch::ones_like(tensor_bgr) + torch::ones_like(tensor_bgr));
    // blend: mask * original image + (1 - mask) * white background
    res_tensor = (tensor_result * tensor_bgr + (1 - tensor_result) * torch::ones_like(tensor_bgr));
    //    res_tensor = (tensor_result * tensor_bgr + (1 - tensor_result) * tensor_image_bgr);
    res_tensor = res_tensor.squeeze(0).detach();
    res_tensor = res_tensor.mul(255).clamp(0, 255).to(torch::kU8);
    res_tensor = res_tensor.to(torch::kCPU);
    cv::Mat resultImg1(res_tensor.size(1), res_tensor.size(2), CV_8UC3);
    std::memcpy((void*)resultImg1.data, res_tensor.data_ptr(), sizeof(torch::kU8) * res_tensor.numel());
    result = resultImg1.clone();
}

int main()
{
    // load srcImg
    Config cfg("Config.yaml");
    cv::Mat srcImg = cv::imread(cfg.get<std::string>("srcImgFile"), -1);
    cv::Mat srcImg_;
    cv::resize(srcImg, srcImg_, cv::Size(512, 512));
    if (srcImg_.channels() == 4)
    {
        cv::cvtColor(srcImg_, srcImg_, cv::COLOR_BGRA2BGR);
    }

    std::string str = cfg.get<std::string>("styleModelFile");

    // load the CPU-traced model; moved to GPU below if one is available
    torch::jit::script::Module styleModule;
    auto device_type = at::kCPU;
    if (torch::cuda::is_available()) {
        std::cout << "gpu" << std::endl;
        device_type = at::kCUDA;
    }
    try
    {
        styleModule = torch::jit::load(str);
        styleModule.to(device_type);
    }
    catch (const c10::Error& e)
    {
        std::cerr << "error code: -2, error loading the model\n";
        return -1;
    }
    cv::Mat dstImg;
    bgr_u2net(srcImg_, dstImg, styleModule);

    cv::imshow("dstImg", dstImg);
    cv::waitKey(0);

    return 1;
}

5. Performance Analysis

6. Troubleshooting

6.1 Errors when running u2net_train.py:

1. OMP: Error

Fix: add the following at the very top of the file:

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'  # workaround for "OMP: Error"

2. Out of GPU memory: RuntimeError: CUDA out of memory.

batch_size_train = 12  # reduce this (the training script above uses 4; 1 always fits)

3. error: The "freeze_support()" line can be omitted if the program is not going to be frozen

if __name__ == '__main__':  # add this guard line right before the epoch loop
    for epoch in range(0, epoch_num):
        ......
References:
[1] Portrait drawing: https://www.cvpy.net/studio/cv/func/DeepLearning/sketch/sketch/page/
