Basic principles of reverse image search:

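Reverse image search is a content-based image retrieval (CBIR) technique². It can understand what an image contains without any keywords and relies mainly on AI algorithms; some of the best-ranked image classification algorithms today reach 99% top-5 accuracy³. This article uses an AI model to extract image feature vectors and then performs reverse image search by computing on those vectors.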
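To make "computing on feature vectors" concrete, here is a minimal brute-force sketch. It assumes the feature vectors have already been extracted into a NumPy array; the random 2048-dimensional data below is a stand-in for real ResNet50 features. The sketch normalizes every vector to unit length and ranks the whole library by Euclidean (L2) distance to the query.

```python
import numpy as np


def l2_normalize(x: np.ndarray) -> np.ndarray:
    """Scale each row to unit length so distances depend only on direction."""
    norms = np.linalg.norm(x, axis=-1, keepdims=True)
    return x / np.maximum(norms, 1e-12)


def top_k_l2(query: np.ndarray, library: np.ndarray, k: int = 5):
    """Return (indices, distances) of the k library rows closest to query in L2 distance."""
    dists = np.linalg.norm(library - query, axis=1)  # one distance per library vector
    order = np.argsort(dists)[:k]                    # smallest distance first
    return order, dists[order]


# Stand-in "image features": 1,000 random 2048-dim vectors (assumed ResNet50-sized).
rng = np.random.default_rng(0)
library = l2_normalize(rng.normal(size=(1000, 2048)).astype(np.float32))

# A query that is a slightly perturbed copy of library item 42.
query = l2_normalize(library[42] + 0.01 * rng.normal(size=2048).astype(np.float32))

idx, dist = top_k_l2(query, library, k=5)
print(idx, dist)
```

A real system swaps the linear scan for an index, which is the job Milvus takes on below.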

I. Towhee & Milvus

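Towhee (http://github.com/towhee-io/towhee) provides out-of-the-box embedding pipelines that turn any unstructured data (images, video, audio, and so on) into feature vectors. With Towhee, running a single pipeline is enough to get them.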

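Milvus (http://github.com/milvus-io/milvus) is an open-source vector database. It supports a wide range of vector index algorithms and similarity metrics and handles similarity search over millions, billions, or even trillions of vectors. It is designed to be flexible, stable, reliable, and fast to query.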

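Combining Towhee and Milvus gives an end-to-end pipeline for analyzing images and other unstructured data. Towhee first extracts feature vectors from the unstructured data. Milvus then stores and searches those vectors. Finally, the results most similar to the query are retrieved and displayed.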
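The tiny in-memory stand-in below illustrates the division of labour just described: an embedding step produces vectors, and a vector store keeps them and answers similarity queries. The class InMemoryVectorStore and its insert/search methods are hypothetical illustrations, not the Milvus or pymilvus API. Its search is a brute-force scan rather than an indexed lookup.

```python
import numpy as np


class InMemoryVectorStore:
    """Hypothetical toy vector store: keeps (id, vector) pairs and does brute-force L2 search."""

    def __init__(self, dim: int):
        self.dim = dim
        self.ids = []
        self.vectors = np.empty((0, dim), dtype=np.float32)

    def insert(self, ids, vectors) -> None:
        vectors = np.asarray(vectors, dtype=np.float32).reshape(-1, self.dim)
        if len(ids) != len(vectors):
            raise ValueError("ids and vectors must have the same length")
        self.ids.extend(int(i) for i in ids)
        self.vectors = np.vstack([self.vectors, vectors])

    def search(self, query, limit: int = 5):
        """Return up to `limit` (id, distance) pairs, nearest first."""
        query = np.asarray(query, dtype=np.float32).reshape(self.dim)
        dists = np.linalg.norm(self.vectors - query, axis=1)
        order = np.argsort(dists)[:limit]
        return [(self.ids[i], float(dists[i])) for i in order]


store = InMemoryVectorStore(dim=3)
store.insert([10, 11, 12], [[1, 0, 0], [0, 1, 0], [0.9, 0.1, 0]])
print(store.search([1, 0, 0], limit=2))
```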

Installing Towhee and Milvus:

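Note: Milvus can be installed standalone or as a cluster. This article installs standalone Milvus with docker-compose (http://milvus.io/docs/v2.0.x/install_standalone-docker.md). Before you start, check that your machine meets the hardware and software requirements (http://milvus.io/docs/v2.0.x/prerequisite-docker.md).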

# Install Towhee

$ pip install towhee

# Install standalone Milvus
$ wget http://github.com/milvus-io/milvus/releases/download/v2.0.2/milvus-standalone-docker-compose.yml -O docker-compose.yml
$ docker-compose up -d

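Towhee provides feature-extraction methods for unstructured data such as image embedding, audio embedding, and video embedding. These are called Towhee operators (Operator), and an operator is a single node in a pipeline (Pipeline). For example, you can build an image feature-extraction pipeline by chaining the image_decode (http://towhee.io/image-decode/cv2) operator and the image_embedding.timm (http://towhee.io/image-embedding/timm) operator. Given model_name="resnet50", the embedding operator generates feature vectors with the ResNet50 model.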

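import towhee

# Decode one image and extract its ResNet50 embedding
(towhee.glob['path']('./test/lion/n02129165_13728.JPEG')  # read the image path
    .image_decode['path', 'img']()  # decode the file into a Towhee image
    .image_embedding.timm['img', 'vec'](model_name='resnet50')  # extract the feature vector
    .select['img', 'vec']()  # keep only the image and vector columns
    .show()
)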
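You can sketch the operator/pipeline idea without Towhee. Each operator reads some named columns of a row and writes its result to another column, and a pipeline simply applies the operators in order. The Op and run_pipeline helpers below are hypothetical and only mirror the ['input', 'output'] column notation used above. The toy decode and embed functions stand in for image_decode and image_embedding.timm and do no real image processing.

```python
from typing import Any, Callable, Dict, List


class Op:
    """Hypothetical pipeline node: apply func to the input column(s), store the result in output."""

    def __init__(self, inputs, output: str, func: Callable):
        self.inputs = [inputs] if isinstance(inputs, str) else list(inputs)
        self.output = output
        self.func = func

    def __call__(self, row: Dict[str, Any]) -> Dict[str, Any]:
        row = dict(row)  # do not mutate the caller's row
        row[self.output] = self.func(*(row[c] for c in self.inputs))
        return row


def run_pipeline(rows: List[Dict[str, Any]], ops: List[Op]) -> List[Dict[str, Any]]:
    """Push every row through every operator in order."""
    out = []
    for row in rows:
        for op in ops:
            row = op(row)
        out.append(row)
    return out


def fake_decode(path: str) -> List[int]:
    """Toy stand-in for decoding: derive two numbers from the path."""
    return [len(path), path.count("/")]


def fake_embed(img: List[int]) -> List[float]:
    """Toy stand-in for embedding: normalize the numbers so they sum to 1."""
    total = sum(img)
    return [v / total for v in img]


pipeline = [Op("path", "img", fake_decode), Op("img", "vec", fake_embed)]
result = run_pipeline([{"path": "./test/lion/a.JPEG"}], pipeline)
print(result[0]["vec"])
```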

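Next, create a collection (Collection) in Milvus. The collection has two fields, id and embedding, with id as the primary key. The vector dimension is 2048, which matches the ResNet50 embedding. We also build an IVF_FLAT (http://milvus.io/docs/v2.0.x/index.md#IVF_FLAT) index on embedding. This is a quantization-based index; here it uses nlist=2048 and the "L2" (Euclidean distance) metric: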

from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility


def create_milvus_collection(collection_name, dim):
    connections.connect(host='127.0.0.1', port='19530')
    # Start from a clean collection if one with the same name already exists
    if utility.has_collection(collection_name):
        utility.drop_collection(collection_name)
    fields = [
        FieldSchema(name='id', dtype=DataType.INT64, description='ids', is_primary=True, auto_id=False),
        FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, description='embedding vectors', dim=dim)
    ]
    schema = CollectionSchema(fields=fields, description='reverse image search')
    collection = Collection(name=collection_name, schema=schema)
    # create IVF_FLAT index for collection.
    index_params = {
        'metric_type': 'L2',
        'index_type': "IVF_FLAT",
        'params': {"nlist": 2048}
    }
    collection.create_index(field_name="embedding", index_params=index_params)
    return collection


# 2048 is the dimension of the ResNet50 embedding
collection = create_milvus_collection('reverse_image_search', 2048)
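Here is a simplified NumPy sketch of an inverted-file (IVF) index, to give some intuition for what IVF_FLAT with nlist does. A few naive k-means iterations cluster the vectors into nlist buckets, and a query then scans only the nprobe buckets whose centroids are closest. nprobe is the search-time counterpart of nlist, as in the search_params used later in this article. This is a teaching approximation on tiny random data, not Milvus's implementation.

```python
import numpy as np


def build_ivf(data: np.ndarray, nlist: int, iters: int = 10, seed: int = 0):
    """Cluster data into nlist inverted lists with plain k-means; return (centroids, lists)."""
    rng = np.random.default_rng(seed)
    centroids = data[rng.choice(len(data), size=nlist, replace=False)]  # fancy indexing copies
    for _ in range(iters):
        # Assign each vector to its nearest centroid (squared L2 distance).
        d = ((data[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        assign = d.argmin(axis=1)
        for c in range(nlist):
            members = data[assign == c]
            if len(members):  # keep the old centroid if a cluster becomes empty
                centroids[c] = members.mean(axis=0)
    d = ((data[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    assign = d.argmin(axis=1)
    lists = [np.flatnonzero(assign == c) for c in range(nlist)]
    return centroids, lists


def search_ivf(query, data, centroids, lists, nprobe: int, k: int):
    """Scan only the nprobe closest inverted lists and return the ids of the k nearest vectors."""
    probe = np.argsort(((centroids - query) ** 2).sum(axis=1))[:nprobe]
    candidates = np.concatenate([lists[c] for c in probe])
    dists = ((data[candidates] - query) ** 2).sum(axis=1)
    return candidates[np.argsort(dists)[:k]]


rng = np.random.default_rng(1)
data = rng.normal(size=(2000, 16)).astype(np.float32)
centroids, lists = build_ivf(data, nlist=32)
print(search_ivf(data[7], data, centroids, lists, nprobe=4, k=3))
```

A larger nprobe scans more buckets. Recall goes up and the speedup over a full scan goes down.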

Loading image data into Milvus

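Besides its many operators for processing unstructured data, Towhee offers simple interfaces for handling all kinds of data and integrates basic Milvus operations. Chaining these operators and interfaces in a pipeline makes loading images into Milvus very simple.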

import towhee

dc = (
    towhee.read_csv('reverse_image_search.csv')  # read the CSV table, which has id, path and label columns
    .runas_op['id', 'id'](func=lambda x: int(x))  # convert each row's id from str to int
    .image_decode['path', 'img']()  # load the image at each row's path and decode it into Towhee's image format
    .image_embedding.timm['img', 'vec'](model_name='resnet50')  # extract the feature vector
    .tensor_normalize['vec', 'vec']()  # normalize the vector
    .to_milvus['id', 'vec'](collection=collection, batch=100)  # insert id and vec into the Milvus collection, 100 rows per batch
)
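The runas_op cast and the batch=100 argument together mean: parse each CSV row, convert id to int, and send the rows to Milvus in groups of 100. The standard-library sketch below shows that batching step only. It assumes a CSV with the id,path,label header described above, and the sample rows are made up.

```python
import csv
import io
from itertools import islice


def read_rows(csv_text: str):
    """Yield CSV rows as dicts, with id converted from str to int."""
    for row in csv.DictReader(io.StringIO(csv_text)):
        row["id"] = int(row["id"])
        yield row


def batched(iterable, size: int):
    """Yield lists of at most `size` items (a stand-in for batch=100)."""
    it = iter(iterable)
    while True:
        chunk = list(islice(it, size))
        if not chunk:
            return
        yield chunk


sample = "id,path,label\n" + "\n".join(f"{i},./train/img_{i}.JPEG,cat" for i in range(250))
batches = list(batched(read_rows(sample), 100))
print([len(b) for b in batches])  # -> [100, 100, 50]
```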

Querying images and displaying the results

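Querying an image uses the same image-processing operators as before: image_decode, image_embedding.timm, and tensor_normalize. To analyze the search results at the end, we use the read_images function defined in the data preparation part. It is added to the Towhee pipeline as the func argument of runas_op.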

(towhee.glob['path']('./test/w*/*.JPEG')  # read every image file matching the pattern as path
    .image_decode['path', 'img']()  # load the image at each row's path and decode it into Towhee's image format
    .image_embedding.timm['img', 'vec'](model_name='resnet50')  # extract the feature vector
    .tensor_normalize['vec', 'vec']()  # normalize the vector
    .milvus_search['vec', 'result'](collection=collection, limit=5)  # search the Milvus collection with the vector and return the results
    .runas_op['result', 'result_img'](func=read_images)  # turn the Milvus results into images for display
    .select['img', 'result_img']()  # keep only the specified columns
    .show()
)
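read_images itself is not shown in this article. The hypothetical sketch below shows what such a helper has to do at minimum: map the ids returned by a search back to image paths, using the same id,path,label CSV. It returns paths rather than decoded images. It assumes each search result carries an id attribute; a namedtuple stands in for real hits here.

```python
import csv
import io
from collections import namedtuple


def load_id_to_path(csv_text: str) -> dict:
    """Build an {id: path} lookup from a CSV with id,path,label columns."""
    return {int(r["id"]): r["path"] for r in csv.DictReader(io.StringIO(csv_text))}


def result_paths(results, id_to_path: dict) -> list:
    """Hypothetical read_images-style helper: turn search hits into image paths, skipping unknown ids."""
    return [id_to_path[hit.id] for hit in results if hit.id in id_to_path]


Hit = namedtuple("Hit", ["id", "distance"])  # stand-in for a search hit
lookup = load_id_to_path("id,path,label\n1,./train/a.JPEG,lion\n2,./train/b.JPEG,tiger\n")
print(result_paths([Hit(2, 0.1), Hit(1, 0.3), Hit(9, 0.5)], lookup))  # -> ['./train/b.JPEG', './train/a.JPEG']
```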

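1. Use a ResNet network to extract image features.

2. Create a Milvus collection to store the image features. Each feature is tied one-to-one to an image through a unique ID (called milvus_id here). Create a SQL table with milvus_id as a unique index to store the image's other information.

3. Add images asynchronously and search images synchronously. The volume of images to add is usually large, so the system loads image features into Milvus in asynchronous batches. The image-add service stores each request in SQL, and a dedicated script periodically batch-loads the image features into Milvus. Because loading is asynchronous, the same image may be loaded more than once, so Redis is used for deduplication. Search requests are usually far fewer than add requests, so image search returns its results synchronously.

(Summary: three tables are needed. Milvus table 1 stores the image features. SQL table 2 stores the image information, with rows corresponding one-to-one to Milvus table 1. SQL table 3 stores the image-add requests and drives the asynchronous batch loading of image features into Milvus.)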
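The two SQL tables in the summary can be sketched as follows. The sketch uses SQLite only so the example runs anywhere; the article itself uses MySQL. t_image_search_image_add_log and its id, image_path, info, create_time, is_load columns appear in the code later in this article. The table name t_image_info and the remaining column layout are hypothetical.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    -- Table 2 (hypothetical name): image info, one row per Milvus entity via a unique milvus_id.
    CREATE TABLE t_image_info (
        id          INTEGER PRIMARY KEY AUTOINCREMENT,
        milvus_id   INTEGER NOT NULL UNIQUE,
        image_path  TEXT NOT NULL,
        info        TEXT
    );
    -- Table 3: image-add requests; is_load becomes 1 once the features are in Milvus.
    CREATE TABLE t_image_search_image_add_log (
        id          INTEGER PRIMARY KEY AUTOINCREMENT,
        image_path  TEXT NOT NULL,
        info        TEXT,
        create_time TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
        is_load     INTEGER NOT NULL DEFAULT 0
    );
    """
)

# The add service only records the request...
conn.execute("INSERT INTO t_image_search_image_add_log(image_path, info) VALUES (?, ?)", ("/data/1.jpg", "demo"))
# ...and the batch loader later picks up the rows that are not loaded yet.
pending = conn.execute("SELECT id, image_path FROM t_image_search_image_add_log WHERE is_load = 0").fetchall()
print(pending)  # -> [(1, '/data/1.jpg')]
```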

Image vectorization

Purpose: convert an image into a feature vector.
from keras.applications.resnet50 import ResNet50, preprocess_input
from keras.preprocessing import image
import numpy as np
from numpy import linalg as LA
import time

# ResNet50 with its ImageNet classification head: model.predict returns a (1, 1000) array
model = ResNet50(weights='imagenet')
# model.summary()


def img2feature(img_path, input_dim=224):  # image path -> normalized feature vector
    img = image.load_img(img_path, target_size=(input_dim, input_dim))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)  # add the batch dimension: (1, 224, 224, 3)
    x = preprocess_input(x)
    x = model.predict(x)
    x = x / LA.norm(x)  # scale to unit length
    return x


def main():
    img_path = '1.jpg'
    t0 = time.time()
    res = img2feature(img_path)
    print(time.time() - t0, res.shape)
    # print(res, type(res), res.shape)


if __name__ == "__main__":
    main()
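Dividing by the norm (x / LA.norm(x)) is what makes the L2 metric used in Milvus behave like cosine similarity. For unit-length vectors a and b separated by an angle θ:

```latex
\|\mathbf{a}-\mathbf{b}\|_2^2
  = \|\mathbf{a}\|_2^2 + \|\mathbf{b}\|_2^2 - 2\,\mathbf{a}\cdot\mathbf{b}
  = 2 - 2\cos\theta
  \qquad \text{when } \|\mathbf{a}\|_2 = \|\mathbf{b}\|_2 = 1
```

So ranking results by ascending L2 distance gives the same order as ranking them by descending cosine similarity.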

Milvus collection operations

# coding:utf-8
import time

from img2feature import img2feature
from pymilvus import (
    connections,
    FieldSchema, CollectionSchema, DataType,
    Collection, utility
)

collection_name = 'image_feature'
host = '***.***.***.***'
port = '19530'
dim = 1000  # ResNet50 with the ImageNet head outputs 1000 values
default_fields = [
    FieldSchema(name="milvus_id", dtype=DataType.INT64, is_primary=True),
    FieldSchema(name="feature", dtype=DataType.FLOAT_VECTOR, dim=dim),
    FieldSchema(name="create_time", dtype=DataType.INT64)
]


# create_table
def create_table():
    connections.connect(host=host, port=port)
    # create collection
    default_schema = CollectionSchema(fields=default_fields, description="test collection")
    print("\nCreate collection...")
    collection = Collection(name=collection_name, schema=default_schema)
    print("\nCreate index...")
    # FLAT is brute-force search; nlist only matters for IVF indexes
    default_index = {"index_type": "FLAT", "params": {"nlist": 128}, "metric_type": "L2"}
    collection.create_index(field_name="feature", index_params=default_index)
    print("\nCreate index... done")
    collection.load()  # load into memory so the collection can be searched


# insert data
def insert_data():
    connections.connect(host=host, port=port)
    default_schema = CollectionSchema(fields=default_fields, description="test collection")
    collection = Collection(name=collection_name, schema=default_schema)
    vectors = img2feature('1.jpg').tolist()[0]
    print(type(vectors), len(vectors))
    # Column-based insert: one list per field, in schema order
    data1 = [
        [123],
        [vectors],
        [int(time.time())]
    ]
    collection.insert(data1)
    print('insert complete')


# search data
def search_data():
    print('search')
    connections.connect(host=host, port=port)
    collection = Collection(name=collection_name)
    print('connected')
    # Before the first search, create the index and call load()
    # default_index = {"index_type": "FLAT", "params": {"nlist": 128}, "metric_type": "L2"}
    # print("\nCreate index...")
    # collection.create_index(field_name="feature", index_params=default_index)
    # print("\nCreate index... done")
    # collection.load()
    # exit()
    vectors = img2feature('1.jpg').tolist()[0]
    topK = 10
    search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
    res = collection.search(
        [vectors],
        "feature",
        search_params,
        topK,
        "create_time > {}".format(0),  # boolean filter expression on a scalar field
        output_fields=["milvus_id"]
    )
    print('>>>', res)
    for hits in res:
        print(len(hits))
        for hit in hits:
            print(hit)
    print('search finished')


def show_nums():
    connections.connect(host=host, port=port)
    collection = Collection(name=collection_name)
    print('ok')
    print(collection.num_entities)


# delete data
def delete_table():
    connections.connect(host=host, port=port)
    default_schema = CollectionSchema(fields=default_fields, description="test collection")
    collection = Collection(name=collection_name, schema=default_schema)
    print('>>>', utility.has_collection(collection_name))
    collection.drop()
    print('>>>', utility.has_collection(collection_name))


if __name__ == "__main__":
    t1 = time.time()
    # create_table()
    # insert_data()
    # search_data()
    show_nums()
    # delete_table()
    print('time cost: {}'.format(time.time() - t1))
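collection.insert in the code above takes column-oriented data: one list per field, in schema order (milvus_id, feature, create_time). When records arrive row by row, as they do in the batch loader later, they have to be transposed first. The helper below does that. It assumes each record is a dict with milvus_id, feature and create_time keys; rows_to_columns is a hypothetical name.

```python
def rows_to_columns(records, field_order=("milvus_id", "feature", "create_time")):
    """Turn [{field: value, ...}, ...] into [[values of field 1], [values of field 2], ...]."""
    return [[rec[name] for rec in records] for name in field_order]


records = [
    {"milvus_id": 1, "feature": [0.1, 0.2], "create_time": 1700000000},
    {"milvus_id": 2, "feature": [0.3, 0.4], "create_time": 1700000060},
]
columns = rows_to_columns(records)
print(columns[0], columns[2])  # -> [1, 2] [1700000000, 1700000060]
# collection.insert(columns)  # same layout as data1 in insert_data()
```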

Image add and search services

from rest_framework.views import APIView as View
from kpdjango.response import SucessAPIResponse, ErrorAPIResponse
from kpmysql.base import Kpmysql
# Import under an alias: the view class below is also named search_image and would shadow it
from core import search_image as search_image_core
import kplog
import logging

log = logging.getLogger("console")


class add_image(View):
    def post(self, request):
        try:
            db = Kpmysql.connect("db168")
            cur = db.cursor()
            image_info = request.POST.get('image_info')
            image_path = request.POST.get('image_path')
            # Only record the request; the features are loaded into Milvus later in batches
            sql = "INSERT INTO t_image_search_image_add_log(image_path, info) VALUES(%s, %s)"
            cur.execute(sql, (image_path, image_info))
            db.commit()
            log.info('Image added: {}-{}'.format(image_path, image_info))
            return SucessAPIResponse(msg="Success")
        except Exception as e:
            log.error('Failed to add image: {}'.format(e))
            return ErrorAPIResponse(msg="Fail")


class search_image(View):
    def post(self, request):
        try:
            image_path = request.POST.get('image_path')
            res = search_image_core(image_path)  # synchronous search
            log.info('Image search succeeded: {}-{}'.format(image_path, res))
            return SucessAPIResponse(msg="Success", data={"data": res})
        except Exception as e:
            log.error('Image search failed: {}'.format(e))
            return ErrorAPIResponse(msg="Fail")

Asynchronous batch loading of images

import time, datetime
from kpmysql.base import Kpmysql
from core import insert_data_many
from concurrent.futures import ThreadPoolExecutor
import redis
from conf.setting import REDIS
from core import str2time
import kplog
import logging

log = logging.getLogger("console")
log_addimgs = logging.getLogger("console_addimgs")


def worker(datas):
    try:
        redis_cli = redis.Redis(host=REDIS.get('host'), port=REDIS.get('port'), password=REDIS.get('password'),
                                db=REDIS.get('db'))
        dics = []
        ids = []
        for data in datas:
            # Redis-based deduplication: skip rows that are already recorded (a score of 0 still counts)
            if redis_cli.zscore('image_search', str(data[0])) is not None:
                continue
            dics.append({'image_path': data[1], 'create_time': data[2]})
            ids.append((data[0],))  # one-element tuple per row for executemany
            redis_cli.zadd('image_search', {str(data[0]): str2time(data[2])})
        if not dics:
            return
        # Insert the data into Milvus
        insert_data_many(dics)
        # Mark the rows as loaded: set is_load=1 in t_image_search_image_add_log
        sql_update = """UPDATE t_image_search_image_add_log SET is_load=1 WHERE id=%s"""
        db168 = Kpmysql.connect("db168")
        cur168 = db168.cursor()
        cur168.executemany(sql_update, ids)
        db168.commit()
    except Exception:
        log.exception('Batch load failed')


def main():
    max_workers = 20  # maximum number of threads
    pool = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix='Thread')
    task_list = []
    # Start 13 hours in the past so that the first loop iteration loads immediately
    init_time = datetime.datetime.now() - datetime.timedelta(hours=13)
    create_time_init = '2020-2-22 00:00:00'
    while True:
        now = datetime.datetime.now()
        diff = now - init_time
        if diff.total_seconds() > 3600:  # load at most once per hour
            # Load rows from t_image_search_image_add_log where is_load=0
            db168 = Kpmysql.connect("db168")
            cur168 = db168.cursor()
            sql = """SELECT id, image_path, create_time FROM t_image_search_image_add_log WHERE is_load=0 and create_time >= %s ORDER BY create_time"""
            cur168.execute(sql, (create_time_init,))
            datas = cur168.fetchall()
            if datas:
                create_time_init = datas[-1][2]
                # Wait until fewer than 90% of the workers are busy
                while True:
                    task_list = [t for t in task_list if not t.done()]
                    if len(task_list) < int(max_workers * 0.9):
                        break
                    time.sleep(1)
                task_list.append(pool.submit(worker, datas))
            init_time = now
        time.sleep(600)


if __name__ == "__main__":
    main()
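The deduplication in worker uses a Redis sorted set as a "seen" registry. A row id is skipped if it already has a score; otherwise it is recorded with its timestamp as the score. Below is the same logic with an in-memory dict standing in for Redis, so it runs without a server. SeenRegistry is a hypothetical name, and unlike Redis it is not shared across processes. The check and the write happen in one locked step, which avoids the window between zscore and zadd in which two threads could both accept the same id.

```python
import threading


class SeenRegistry:
    """In-memory stand-in for the Redis sorted set 'image_search' (member -> score)."""

    def __init__(self):
        self._scores = {}
        self._lock = threading.Lock()  # worker threads share one registry

    def claim(self, member: str, score: float) -> bool:
        """Atomically record member; return False if it was already present."""
        with self._lock:
            if member in self._scores:
                return False
            self._scores[member] = score
            return True


def dedupe(rows, registry: SeenRegistry):
    """Keep only rows (id, image_path, create_ts) whose id has not been claimed before."""
    return [row for row in rows if registry.claim(str(row[0]), row[2])]


registry = SeenRegistry()
first = dedupe([(1, "a.jpg", 100.0), (2, "b.jpg", 101.0)], registry)
second = dedupe([(2, "b.jpg", 101.0), (3, "c.jpg", 102.0)], registry)
print([r[0] for r in first], [r[0] for r in second])  # -> [1, 2] [3]
```

With redis-py, redis_cli.zadd('image_search', {key: score}, nx=True) gives the same one-step check-and-set. It returns the number of newly added members, so a return value of 0 means the id was already present.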

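1. Keras is less convenient than PyTorch when using the GPU with multiple threads, and PyTorch uses less GPU memory.

2. Replacing the periodic database polling with a Kafka producer/consumer model would make the code more concise and the logic simpler.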
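The shape of the producer/consumer design in point 2 can be sketched with the standard library, using queue.Queue as a stand-in for a Kafka topic. No Kafka client is used here. With Kafka, the add service would produce one message per request and a consumer would do the loading. In the sketch, the producer enqueues add requests as they arrive and a consumer thread drains them in batches.

```python
import queue
import threading

STOP = object()  # sentinel that tells the consumer to finish


def consumer(q: "queue.Queue", batch_size: int, loaded: list) -> None:
    """Collect messages into batches and 'load' each batch (here: append it to a list)."""
    batch = []
    while True:
        msg = q.get()
        if msg is STOP:
            break
        batch.append(msg)
        if len(batch) >= batch_size:
            loaded.append(batch)  # in the real system: extract features and insert into Milvus
            batch = []
    if batch:
        loaded.append(batch)  # flush the final partial batch


q = queue.Queue()
loaded = []
t = threading.Thread(target=consumer, args=(q, 3, loaded))
t.start()
for i in range(7):  # producer: the add-image service would send one message per request
    q.put({"image_path": f"/data/{i}.jpg"})
q.put(STOP)
t.join()
print([len(b) for b in loaded])  # -> [3, 3, 1]
```

Compared with hourly polling, the consumer works on new requests as soon as they arrive. The "is_load" bookkeeping shrinks to committing the consumer's position after each batch.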