Using Transfer Learning for Multi-Label Image Classification

This article applies a pretrained VGG16 model, via transfer learning, to an image multi-label classification problem. The project data come from Kaggle, and each image can belong to several labels at once. Model accuracy is quantified with the F score, based on the counts defined in the table below:

                           Predicted Positive (1)    Predicted Negative (0)
Ground truth Positive (1)  TP                        FN
Ground truth Negative (0)  FP                        TN

For example, if the true labels are (1,0,1,1,0,0) and the predicted labels are (1,1,0,1,1,0), then TP=2, FN=1, FP=2, TN=1. $$Precision=\frac{TP}{TP+FP},\quad Recall=\frac{TP}{TP+FN},\quad F\_score=\frac{(1+\beta^2)\cdot Precision\cdot Recall}{Recall+\beta^2\cdot Precision}$$The smaller $\beta$ is, the more weight Precision carries in the F score; when $\beta$ equals 0 the F score reduces to Precision. The larger $\beta$ is, the more weight Recall carries; as $\beta$ tends to infinity the F score approaches Recall. The metric can be defined as a custom function in Keras (y_pred denotes the predicted probabilities):

from tensorflow.keras import backend
 
# calculate fbeta score for multi-label classification
def fbeta(y_true, y_pred, beta=2):
    # clip predictions
    y_pred = backend.clip(y_pred, 0, 1)
    # calculate elements for each sample
    tp = backend.sum(backend.round(backend.clip(y_true * y_pred, 0, 1)), axis=1)
    fp = backend.sum(backend.round(backend.clip(y_pred - y_true, 0, 1)), axis=1)
    fn = backend.sum(backend.round(backend.clip(y_true - y_pred, 0, 1)), axis=1)
    # calculate precision
    p = tp / (tp + fp + backend.epsilon())
    # calculate recall
    r = tp / (tp + fn + backend.epsilon())
    # calculate fbeta, averaged across samples
    bb = beta ** 2
    fbeta_score = backend.mean((1 + bb) * (p * r) / (bb * p + r + backend.epsilon()))
    return fbeta_score
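
As a quick sanity check (a hypothetical snippet, not part of the original post), fbeta applied to the worked example above matches the hand calculation: p = 0.5, r = 2/3, so F2 = (1+4)·p·r/(4·p+r) = 0.625.

import tensorflow as tf

# hard 0/1 "probabilities" reproducing the worked example (TP=2, FP=2, FN=1)
y_true = tf.constant([[1., 0., 1., 1., 0., 0.]])
y_pred = tf.constant([[1., 1., 0., 1., 1., 0.]])
print(float(fbeta(y_true, y_pred)))  # ≈ 0.625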

The loss function also differs between multi-label and multi-class classification: multi-label classification uses binary_crossentropy. Suppose a sample's true labels are (1,0,1,1,0,0) and the predicted probabilities are (0.2, 0.3, 0.4, 0.7, 0.9, 0.2); each label contributes $\ln p$ when its true value is 1 and $\ln(1-p)$ when it is 0, so $$binary\_crossentropy\ loss=-(\ln 0.2 + \ln 0.7 + \ln 0.4 + \ln 0.7 + \ln 0.1 + \ln 0.8)/6 \approx 0.96$$In addition, the output layer of a multi-label classifier uses the sigmoid activation rather than softmax.
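
A quick numpy check of this arithmetic (a hypothetical verification snippet, not part of the original post):

import numpy as np

y_true = np.array([1., 0., 1., 1., 0., 0.])
y_pred = np.array([0.2, 0.3, 0.4, 0.7, 0.9, 0.2])
# ln(p) where the true label is 1, ln(1-p) where it is 0, averaged over the 6 labels
loss = -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
print(round(loss, 2))  # 0.96

The model architecture is shown below: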

from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.models import Model

def define_model(in_shape=(128, 128, 3), out_shape=17):
    # load model
    base_model = VGG16(weights='imagenet', include_top=False, input_shape=in_shape)
    # mark loaded layers as not trainable
    for layer in base_model.layers: layer.trainable = False
    # make the last block trainable
    tune_layers = [layer.name for layer in base_model.layers if layer.name.startswith('block5_')]
    for layer_name in tune_layers: base_model.get_layer(layer_name).trainable = True
    # add new classifier layers
    flat1  = Flatten()(base_model.layers[-1].output)
    class1 = Dense(128, activation='relu', kernel_initializer='he_uniform')(flat1)
    output = Dense(out_shape, activation='sigmoid')(class1)
    # define new model
    model = Model(inputs=base_model.input, outputs=output)
    # compile model
    opt = Adam(learning_rate=1e-3)
    model.compile(optimizer=opt, loss='binary_crossentropy', metrics=[fbeta])
    model.summary()
    return model
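
To confirm the freezing logic, one can list which layers still hold trainable weights; a hypothetical check (only the block5 convolutions and the two new Dense layers should appear):

model = define_model()
for layer in model.layers:
    if layer.trainable and layer.weights:
        print(layer.name)  # block5_conv1..3 plus the two new Dense layers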

Download the data from the Kaggle website, unzip it, and process it into a format the model can read:

from os import listdir
from numpy import zeros, asarray, savez_compressed
from pandas import read_csv
from tensorflow.keras.preprocessing.image import load_img, img_to_array

# create a mapping of tags to integers given the loaded mapping file
def create_tag_mapping(mapping_csv):
    labels = set() # create a set of all known tags
    for i in range(len(mapping_csv)):
        tags = mapping_csv['tags'][i].split(' ') # convert spaced separated tags into an array of tags
        labels.update(tags) # add tags to the set of known labels
    labels = sorted(list(labels)) # convert set of labels to a sorted list 
    # dict that maps labels to integers, and the reverse
    labels_map = {labels[i]:i for i in range(len(labels))}
    inv_labels_map = {i:labels[i] for i in range(len(labels))}
    return labels_map, inv_labels_map

# create a mapping of filename to a list of tags
def create_file_mapping(mapping_csv):
    mapping = dict()
    for i in range(len(mapping_csv)):
        name, tags = mapping_csv['image_name'][i], mapping_csv['tags'][i]
        mapping[name] = tags.split(' ')
    return mapping

# create a one hot encoding for one list of tags
def one_hot_encode(tags, mapping):
    encoding = zeros(len(mapping), dtype='uint8') # create empty vector
    # mark 1 for each tag in the vector
    for tag in tags: encoding[mapping[tag]] = 1
    return encoding

# load all images into memory
def load_dataset(path, file_mapping, tag_mapping):
    photos, targets = list(), list()
    # enumerate files in the directory
    for filename in listdir(path):
        photo = load_img(path + filename, target_size=(128,128)) # load image
        photo = img_to_array(photo, dtype='uint8') # convert to numpy array
        tags = file_mapping[filename[:-4]] # get tags
        target = one_hot_encode(tags, tag_mapping) # one hot encode tags
        photos.append(photo)
        targets.append(target)
    X = asarray(photos, dtype='uint8')
    y = asarray(targets, dtype='uint8')
    return X, y

filename = 'train_v2.csv' # load the target file
mapping_csv = read_csv(filename)
tag_mapping, _ = create_tag_mapping(mapping_csv) # create a mapping of tags to integers
file_mapping = create_file_mapping(mapping_csv) # create a mapping of filenames to tag lists
folder = 'train-jpg/' # load the jpeg images
X, y = load_dataset(folder, file_mapping, tag_mapping)
print(X.shape, y.shape)
savez_compressed('planet_data.npz', X, y) # save both arrays to one file in compressed format
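
To make the mapping helpers concrete, here is a toy illustration with hypothetical data (not from the actual Kaggle CSV):

from pandas import DataFrame

demo_csv = DataFrame({'image_name': ['train_0', 'train_1'],
                      'tags': ['agriculture clear', 'clear water']})
labels_map, inv_labels_map = create_tag_mapping(demo_csv)
print(labels_map)                                      # {'agriculture': 0, 'clear': 1, 'water': 2}
print(one_hot_encode(['clear', 'water'], labels_map))  # [0 1 1]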

Next, define two helper functions: the first splits the data into training and validation sets, and the second plots the model's learning curves during training:

import numpy as np
from matplotlib import pyplot
from sklearn.model_selection import train_test_split

# load train and test dataset
def load_dataset():
    # load dataset
    data = np.load('planet_data.npz')
    X, y = data['arr_0'], data['arr_1']
    # separate into train and test datasets
    trainX, testX, trainY, testY = train_test_split(X, y, test_size=0.3, random_state=1)
    print(trainX.shape, trainY.shape, testX.shape, testY.shape)
    return trainX, trainY, testX, testY

# plot diagnostic learning curves
def summarize_diagnostics(history):
    # plot loss
    pyplot.subplot(121)
    pyplot.title('Cross Entropy Loss')
    pyplot.plot(history.history['loss'], color='blue', label='train')
    pyplot.plot(history.history['val_loss'], color='orange', label='test')
    pyplot.legend()
    # plot fbeta
    pyplot.subplot(122)
    pyplot.title('Fbeta')
    pyplot.plot(history.history['fbeta'], color='blue', label='train')
    pyplot.plot(history.history['val_fbeta'], color='orange', label='test')
    pyplot.legend()
    pyplot.show()

Train the model using data augmentation:

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.vgg16 import preprocess_input
from tensorflow.keras.callbacks import ModelCheckpoint

trainX, trainY, testX, testY = load_dataset() # load dataset
# create data generator using augmentation
# vertical flip is reasonable since the pictures are satellite images
train_datagen = ImageDataGenerator(horizontal_flip=True, vertical_flip=True, rotation_range=90, preprocessing_function=preprocess_input)
test_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)
# prepare generators
train_it = train_datagen.flow(trainX, trainY, batch_size=128)
test_it = test_datagen.flow(testX, testY, batch_size=128)
# define model
model = define_model()
# fit model
# When one epoch ends, the validation generator will yield validation_steps batches, then average the evaluation results of all batches
checkpointer = ModelCheckpoint(filepath='./weights.best.vgg16.hdf5', verbose=1, save_best_only=True)
history = model.fit(train_it, steps_per_epoch=len(train_it), validation_data=test_it, validation_steps=len(test_it),
                    epochs=15, callbacks=[checkpointer], verbose=0)
# evaluate optimal model
# For simplicity, the validation set is used to test the model here. In fact an entirely new test set should have been used. 
model.load_weights('./weights.best.vgg16.hdf5') #load stored optimal coefficients
loss, fbeta_val = model.evaluate(test_it, steps=len(test_it), verbose=0)  # avoid shadowing the fbeta function
print('> loss=%.3f, fbeta=%.3f' % (loss, fbeta_val)) # loss=0.108, fbeta=0.884
model.save('final_model.h5')
# learning curves
summarize_diagnostics(history)
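
Finally, a sketch of how the saved model could be used for inference; this is a hypothetical snippet (the image path and the 0.5 threshold are illustrative, and inv_labels_map is the inverse mapping returned by create_tag_mapping, which would need to be kept from the preprocessing step):

import numpy as np
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.applications.vgg16 import preprocess_input

# the custom fbeta metric must be supplied again when loading the model
final_model = load_model('final_model.h5', custom_objects={'fbeta': fbeta})
img = img_to_array(load_img('train-jpg/train_0.jpg', target_size=(128, 128)))
probs = final_model.predict(preprocess_input(np.expand_dims(img, axis=0)))[0]
# assign every label whose predicted probability exceeds the 0.5 threshold
print([inv_labels_map[i] for i, p in enumerate(probs) if p >= 0.5])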

[Figure: learning curves of cross-entropy loss and Fbeta over the training epochs]

The blue lines are the training set and the orange lines the validation set.
