
多模态大模型Qwen-VL和MiniCPM-Llama3-V-2_5初体验
QwenVL这个号称是国内最好的多模态大模型,阿里通义千问系列多模态大模型之一。QwenVL系列有3个大模型,分别是Qwen-VL-Chat&Qwen-VL-Plus & Qwen-VL-Max,其中Qwen-VL-Chat开源了代码以及模型权重,而Qwen-VL-Plus & Qwen-VL-Max这两个效果更加的模型,并未开源,但是可以通过🤗🤖网页端APP和API访问,而我们重点关注的是开
目录
一、QwenVL和MiniCPM-Llama3-V-2_5模型简介
二、QwenVL和MiniCPM-Llama3-V-2_5原理简析
三、QwenVL和MiniCPM-Llama3-V-2_5功能测试和效果
借着公司做视觉方向的业务,学习和体验了一下多模态大模型。主要是对Qwen-VL和MiniCPM-Llama3-V-2_5做预研,了解视觉特征是怎么融合进LLM大模型的,同时验证一下上述两个模型在OCR能力上有那些惊艳的效果。本篇博客对模型进行简介、探究了一下视觉特征和LLM结合方式、以及微调实验的一些结论。
一、QwenVL和MiniCPM-Llama3-V-2_5模型简介
Qwen-VL-Chat
QwenVL这个号称是国内最好的多模态大模型,阿里通义千问系列多模态大模型之一。QwenVL系列有3个大模型,分别是Qwen-VL-Chat&Qwen-VL-Plus & Qwen-VL-Max,其中Qwen-VL-Chat开源了代码以及模型权重,而Qwen-VL-Plus & Qwen-VL-Max这两个效果更加的模型,并未开源,但是可以通过🤗、🤖、网页端、APP 和 API访问,而我们重点关注的是开源的Qwen-VL-Chat,后文简称QwenVL。首先模型架构层面,LLM来自QWen-7B,visionTransformer采用s ViT-bigG,总体参数大约9.5B,fp16保持的权重文件包总计19G。它具备很多视觉方面的能力,docVqa(文档理解),图片问答、图片理解、OCR多图、多图问答、多轮问答以及图片box理解定位等,可以看论文中的举例(Qwen-VL: A Versatile Vision-Language Model for Understanding, Localization, Text Reading, and Beyond):
训练过程如下:
第一阶段:冻结QwenLLM模型权重,训练Q_Transformer模型权重以及Vit模型权重
第二阶段:所有的模型权重都放开进行多任务预训练
第三阶段:冻结Vit模型权重,其他的部分权重进行有监督的下游任务训练。
注意的是他们的训练数据情况:
中文的只占了1.4B中的23.24%,其中还有220M约67.7%的私有数据,其实从后续效果体验来看,QwenVl在中文的OCR效果并不惊艳,从网页上体验的Qwen-VL-Plus & Qwen-VL-Max效果确实很好,猜测是训练数据和模型规模有优化。
MiniCPM-Llama3-V-2_5
这个模型地清华智谱和面壁智能合作出品的多模态大模型,它的文本大模型是基于llama3模型架构,visionTransformer是siglip-so400m模型,模型参数总量共计8B,采用fp16存储的模型参数占用16G空间。并没有找到官方论文,下面就贴一下其演示的多模态的能力:
文档理解
表格理解
以及多轮对话等等。
二、QwenVL和MiniCPM-Llama3-V-2_5原理简析
主流视觉特征和LLM结合
首先需要理解一下文本大模型LLM和视觉模型是怎么结合在一起的,目前主流的方法是:图像特征和文本特征融合,然后输入到LLM大模型中进行训练和推理。具体到融合的策略而言,一种是采用Query-Transformer,一种是MLP。本文中涉及到的Qwen-VL和MiniCPM-Llama3-V-2_5都是基于Query-Transformer来进行特征融合的。现阶段之前的大模型大多都是采用Query-Transformer来压缩图像信息,减少图片tokens占用的数量,往后应该是MLP越来越受欢迎的,具体的分析可以看知乎上的专业解析多模态大语言模型(MLLM)为什么最近的工作中用BLIP2中Q-Former结构的变少了,这里我就不分析了,没有做过实验也分析不了哈。融合原理如下图(引自知乎文章多模态大模型:视觉模型与LLM的结合之路(四))
从图中可以得出图像和文本大模型的结合方式
1、图片经过视觉Encoder模型得到图片的视觉特征img_emb
2、img_emb经过一个压缩变换层adapter,把img_emb的维度和文本特征prompt emb对齐(矩阵的最后一维相等),同时为了减少对齐后img_emb占用太多的token位置,把img_emb压缩到一个固定的token数量
3、img_emb和prompt emb对齐后直接concat起来,输入到LLM
Qwen-VL模型原理
首先上模型结构图
Qwen-VL模型LLM使用Qwen-7B预训练模型,32层QWenBlock;视觉编码器使用openclip的ViT-bigG预训练模型,48层TransformerBlock;adapter中含有256个query node,一层attention,同时添加了2D的位置编码,注意到q和k的位置编码是不一样的,直接原因是q(256)和k(1024)的长度不一样。
模型的输入中是直接把图片路径img_path和文本prompt作为一个整体输入的,后续的流程中从输入中把img_path提取出来,然后读取图片得到img,输入到visual模块中(VisionTransformer)。visual模块包含了提取img特征的TransformerBlock以及作为图片特征压缩和变换的resampler模块(也就是adapter),img经过TransformerBlock后得到图片的高阶特征img_feature,再经过resampler就把img_feature和text_embedding维度做了对齐,同时也压缩到固定的token上。text_embedding和img_embedding直接concat就得到text和img融合后的特征,如上图所示中的,粉红色方块有256个代表的是一张img被压缩和变换的token_embedding。整体上理解了这个流程也就差不多理解了Qwen-VL是怎么把视觉特征和文本特征融合在一起然后输入到LLM模块中进行处理的,至于ViT-bigG预训练模型、Qwen-7B以及adapter更多细节本文不予讨论。
text_embeding和img_embedding(对齐后的)融合代码:
hidden_states = self.drop(hidden_states).clone()
if fake_images is not None:
hidden_states = hidden_states + images.mean()*0
elif images is not None:
for idx, (i, a, b) in enumerate(img_pos):
hidden_states[i][a + 1 : b] = images[idx]
直接把text_embedding中分配给img特征的位置直接赋值成相应的img_embedding。
为了更加清晰的看清楚整个text_img和img_融合过程,下面展示一下模型输入后的中间产物以及模块的矩阵维度变化。
文本经过tokenizer后的产物
<img></img>
<img>/AI_TEAM/yanghuang/workspace/project/Qwen-VL/datas/image_with_text.jpg</img>
<img>img_path</img>占用258个token位置,其中img_path占用256个token位置,后续这些位置的embedding直接按照上述代码中的方式直接替换为对齐后的img_embedding。
输入到embedding的示意图如下
把原始输入tokenize后得到的input_ids,从中decode得到img_path,读取图片,经过resize、conv2d to patches一系列操作,把img tokenize化,可以输入到visionTransformer预训练模型中提取img_feature,最后经过adpater中的线性层把维度和文本特征维度对齐,才通过crossattention把img_feature压缩到固定token数量上,减少img特征对token的占用。
MiniCPM-Llama3-V-2_5模型原理
模型结构和Qwen-VL相比较大差不差,多模态的融合思想都是一样的,都是把img通过变换压缩到固定的token数,而且都是采用Query-Transformer来融合的。不同点是LLM模块MiniCPM-Llama3-V-2_5采用的llama3;visionTransformer模块MiniCPM-Llama3-V-2_5采用的是siglip-so400m预训练模型;adapter几乎相同都是img特征经过Linear层和文本特征对齐后,在经过cross_attn模块压缩。
如上图,模型结构大致一致,resampler中的cross_attn计算的时候,只有K加了位置编码。一个比较大的不同是MiniCPM-Llama3-V-2_5对输入图片根据尺寸大小做动态调整,每张原始输入会变换成多张新图,每张新图占用96个tokens,而不是像Qwen-VL统一的把各种不同尺寸的图片全部插值为448*448(scale_resolution*scale_resolution)的大小然后压缩为固定的256个token。其处理流程如下:
1、如果原始图片面积小于448*448(scale_resolution*scale_resolution),按照如下规则进行上采样,保持采样后的图片高宽比和原始图片的高宽比保持不变,并且长宽要能整除patch_size(=14),得到原始图片扩大的图source_upsample_image,把source_upsample_image为输入图片的中间表达输入到模型中进行处理。
2、如果原始图片面积大于448*448(scale_resolution*scale_resolution),按照上述同样的规则进行下采样,得到原始图片缩小的图source_downsample_image;为了保留更多的img信息,按照一定的规律,再次对原始图片进行变换(主要是扩大),然后从变换后的img中完整的分割出一定数量的patches子图,把这些patches子图和source_downsample_image一起作为输入图片的中间表达输入到模型中进行处理,source_downsample_image和patche子图可能尺寸不一样,需要补充padding。切分patches需要遵循一下原则(具体的实现需要去看模型源码):
a、原始图片变换后,能完整的切成m个patches,其中2<=m<=9,m=x*y,x>=1,y>=1(x和y表示子图的行列布局);并且x/y的比例和原始图片的高宽之比变化最小;
b、patch子图高和宽都能被patch_size(=14)整除,并且patch的子图heigth = round(scale_resolution * height/w )
以上的设计我猜想是为了尽可能的保留原始输入图片的信息,扩大后的图片不要扭曲高宽比几乎不变,有尽量少占用token数,注意一张输入图片经过上述梳理后变为1+n个小图,每个小图最终都压缩为96个token。
为了更加清晰的看清楚整个text_img和img_融合过程,同样看一下输入后的中间产物
文本经过tokenizer后的产物
图片占位符
prompt:请识别图片中的全部文本img:500*500
最终输入到tokenizer中的文本就是下图的图片占位符+中文prompt组成的
img则被处理为448*448的source_downsample_image和2张630*322的patch
同样的输入到模型后,中间变量的维度变化如下:
推理代码
简单的给出一个推理代码,如下
QwenVL
图片分别是:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
from transformers import AutoModelForCausalLM, AutoTokenizer
import base64
def img2base64(file_name):
with open(file_name, 'rb') as f:
encoded_string = base64.b64encode(f.read())
return encoded_string
if __name__ == '__main__':
model_path = "/AI_TEAM/yanghuang/pretrain_models/torch/Qwen-VL-Chat"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="cuda", bf16=True,
trust_remote_code=True).eval()
img_path = "/AI_TEAM/yanghuang/workspace/project/Qwen-VL/datas/test.jpg"
prompt = "请输出图片中的文字"
img_base64_list = []
# 第一轮对话
query = tokenizer.from_list_format([
{'image': img_path}, # Either a local path or an url
{'text': f'{prompt}'},
])
print(f"query---: {[query]}")
response, history = model.chat(tokenizer, query=query, history=None)
print("response---",[response])
print("history---", [history])
print("*"*100)
query = tokenizer.from_list_format([
{'image': "/AI_TEAM/yanghuang/workspace/project/Qwen-VL/datas/image_with_text.jpg"}, # Either a local path or an url
{'text': f'好的,那新的图片中是什么内容'},
])
print(f"query---: {[query]}")
response, history = model.chat(tokenizer, query=query, history=history)
print("response---", [response])
print("history---", [history])
结果如下
MiniCPM-Llama3-V-2_5
图片:
代码:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch
import json
from PIL import Image
import base64
import io
from transformers import AutoTokenizer, AutoModel
from peft import AutoPeftModelForCausalLM
class MiniCPMV2_5:
def __init__(self, model_path, adapter_path=None) -> None:
if adapter_path:
self.model = AutoPeftModelForCausalLM.from_pretrained(adapter_path, trust_remote_code=True).to(dtype=torch.bfloat16)
vpm_resampler_embedtokens_weight = torch.load(f"{adapter_path}/vpm_resampler_embedtokens.pt")
self.model.load_state_dict(vpm_resampler_embedtokens_weight, strict=False)
else:
self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(dtype=torch.bfloat16)
self.model.eval().cuda()
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
def chat(self, input, img_base64_list=None):
try:
image = Image.open(io.BytesIO(base64.b64decode(input['image']))).convert('RGB')
except Exception as e:
return "Image decode error"
msgs = json.loads(input['question'])
answer = self.model.chat(
image=image,
msgs=msgs,
tokenizer=self.tokenizer,
sampling=True,
temperature=0.7
)
return answer
def img2base64(file_name):
with open(file_name, 'rb') as f:
encoded_string = base64.b64encode(f.read())
return encoded_string
if __name__ == '__main__':
model_path = '/AI_TEAM/yanghuang/pretrain_models/torch/MiniCPM-Llama3-V-2_5'
adapter_path = "/AI_TEAM/yanghuang/workspace/project/MiniCPM-V/output/solid_bg_imgs_20240621/checkpoint-72000"
model = MiniCPMV2_5(model_path=model_path)
img_path = "./500_500_test.jpg"
im_64 = img2base64(img_path)
msgs = [{"role": "user", "content": "请识别图片中的全部文本"}]
inputs = {"image": im_64, "question": json.dumps(msgs)}
response = model.chat(inputs)
print(response)
三、QwenVL和MiniCPM-Llama3-V-2_5功能测试和效果
又到激动人心的实验测试环节了,这里我们主要是关注图片的OCR结果,为此,我这边生成了纯色背景的各色图片。图片生成代码如下:
from PIL import Image, ImageDraw, ImageFont
import glob
import os
import pandas as pd
import random
import uuid
import json
import tqdm
import math
random.seed(100)
def do_create(font_paths, font_sizes, positions, texts, img_color, text_colors, width, height, save_path):
image = Image.new('RGB', (width, height), img_color)
# 创建一个可以在图片上写字的对象
draw = ImageDraw.Draw(image)
for font_path, position, text, text_color, font_size in zip(font_paths, positions, texts, text_colors, font_sizes):
font = ImageFont.truetype(font_path, font_size)
draw.text(position, text, font=font, fill=text_color)
image.save(save_path)
def get_text_and_img_color(white_img_p = 0.75):
# 图片颜色
if random.random() < white_img_p:
# 纯白色
img_color = (255, 255, 255)
text_color = (random.choice(range(255)), random.choice(range(255)), random.choice(range(255)))
else:
img_color = (random.choice(range(256)), random.choice(range(256)), random.choice(range(256)))
text_color = (random.choice(range(255)), random.choice(range(255)), random.choice(range(255)))
while text_color == img_color:
text_color = (random.choice(range(255)), random.choice(range(255)), random.choice(range(255)))
if text_color != img_color:
break
return img_color, text_color
def compute_single_row_img_infos(text, width_max = 1344, height_max = 1344):
font_size_list = list(range(10, 25))
# 图片颜色
img_color, text_color = get_text_and_img_color(white_img_p = 0.75)
# 字体大小和图片页面布局
ft_size = random.choice(font_size_list)
# text 不做切断
height = random.choice(range(200, 500, 10))
width = ft_size * len(text) + random.choice(range(50, 100))
# 0.4的概率截断左右两边的字
if random.random() < 0.4:
if int(len(text) * 0.5) > 1:
candidate_cut = random.choice(range(1, int(len(text) * 0.5)))
if random.random() < 0.5:
text = text[0:-candidate_cut]
else:
text = text[candidate_cut:]
assert width < width_max, "width > width_max"
assert len(text) < 2048
pos_x = random.choice(range(2, width - ft_size * len(text)))
pos_y = random.choice(range(0, height - ft_size))
position = (pos_x, pos_y)
return text, ft_size, width, height, position, img_color, text_color
def create_single_row_words_img():
path = f"/data02/yanghuang/workspace/MiniCPM-V/datas/solid_bg_imgs/single_row"
if not os.path.exists(path):
os.makedirs(path)
font_paths = glob.glob("/data02/yanghuang/workspace/Qwen-VL/datas/fonts/*")
with open('./one_row_content.jsonl','r', encoding='utf-8') as f:
lines = f.readlines()
total_count = 0
with open('./solid_bg_imgs/single_row.json', 'w', encoding='utf-8') as f:
for line in tqdm.tqdm(lines[0:], desc="create_single_row_words_img"):
try:
text = json.loads(line)['content']
ft_ps = [random.choice(font_paths)]
text, ft_size, width, height, position, img_color, text_color = compute_single_row_img_infos(text)
save_path = os.path.join(path, f"{total_count}.jpg")
ft_sizes = [ft_size]
positions = [position]
texts = [text]
text_colors = [text_color]
do_create(ft_ps, ft_sizes, positions, texts, img_color, text_colors, width, height, save_path)
temp = {
"context": text,
"path": save_path
}
f.write(json.dumps(temp, ensure_ascii=False) + '\n')
total_count += 1
except Exception as e:
print(e)
print(f"create_single_row_words_img count {total_count}")
def multi_0_thousand_word():
def compute_imgs_text_label(ft_se, poisitons, texts, width):
label = ""
new_texts = []
for ft, poi, text in zip(ft_se, poisitons, texts):
# 计算字的个数
word_count = int((width - poi[0]) / ft)
text = text[0:word_count]
new_texts.append(text)
label += text + "\n\n"
label = label.strip('\n\n')
return label, texts
width = 500
height = 500
font_size_list = list(range(15, 25))
path = f"/data02/yanghuang/workspace/MiniCPM-V/datas/solid_bg_imgs/multi_rows/thousand_word"
if not os.path.exists(path):
os.makedirs(path)
df = pd.read_csv("./classification_train_dataset_75W_1130.tsv", sep='\t')
sentences1 = df['sentence1'].values.tolist()
sentences2 = df['sentence2'].values.tolist()
sentences = sentences1 + sentences2
sentences = [ str(ele) for ele in sentences]
sentences = list(set(sentences))
sentences.sort()
print("len(sentences)", len(sentences))
targets = []
text_path = "/data02/yanghuang/workspace/Qwen-VL/datas/海尔热线报装报修_20240327_179672.jsonl"
with open(text_path, 'r',
encoding='utf-8') as reader:
lines = reader.readlines()
for line in lines:
target = json.loads(line)['target']
targets.append(target)
targets = list(set(targets))
targets.sort()
print(f"len(targets) {len(targets)}")
# font_paths = glob.glob("/AI_TEAM/yanghuang/workspace/project/Qwen-VL/datas/fonts/*")
font_paths = glob.glob("/data02/yanghuang/workspace/Qwen-VL/datas/fonts/*")
rows = list(range(2, 7))
row_spaces = list(range(20, 40))
with open('./solid_bg_imgs/multi_rows/thousand_word.json', 'w', encoding='utf-8') as f:
total_count = 0
for _ in tqdm.tqdm(range(25000), desc="create_multi_rows_words_img thousand_word"):
row_count = random.choice(rows)
# 不放回采样
ft_ps = random.sample(font_paths, k=row_count)
# 放回采样
ft_se = random.choices(font_size_list, k=row_count)
texts = random.sample(sentences, k=row_count)
poisitons = []
first_row_height = 10
for i in range(row_count):
if i == 0:
poisitons.append((random.choice(range(0, 200)), first_row_height))
else:
poisitons.append(
(random.choice(range(0, 200)), poisitons[i - 1][1] + ft_se[i - 1] + random.choice(row_spaces)))
label, texts = compute_imgs_text_label(ft_se, poisitons, texts, width)
save_path = os.path.join(path, f"{total_count}.jpg")
img_color = (255, 255, 255)
text_colors = [(random.choice(range(255)), random.choice(range(255)), random.choice(range(255))) for _ in
range(row_count)]
temp = {
"context": label,
"path": save_path,
"textsource": "thousand_word"
}
f.write(json.dumps(temp, ensure_ascii=False) + '\n')
do_create(ft_ps, ft_se, poisitons, texts, img_color=img_color, text_colors=text_colors, width=width,
height=height, save_path=save_path)
total_count +=1
# 海尔热线报装报修的生成
path = f"/data02/yanghuang/workspace/MiniCPM-V/datas/solid_bg_imgs/multi_rows/海尔热线报装报修"
if not os.path.exists(path):
os.makedirs(path)
for index in tqdm.tqdm(range(25000), desc=f"create_multi_rows_words_img 海尔热线报装报修"):
target = targets[index].split('\n\n')
row_count = len(target)
# 不放回采样
ft_ps = [random.choice(font_paths)] * row_count
# 放回采样
ft_se = [random.choice(font_size_list)] * row_count
texts = target
poisitons = []
first_row_height = 10
posi_x = random.choice(range(0, 200))
space = random.choice(row_spaces)
for i in range(row_count):
if i == 0:
poisitons.append((posi_x, first_row_height))
else:
poisitons.append(
(random.choice(range(0, 200)), poisitons[i - 1][1] + ft_se[i - 1] + random.choice(row_spaces)))
label, texts = compute_imgs_text_label(ft_se, poisitons, texts, width)
save_path = os.path.join(path, f"{total_count}.jpg")
img_color = (255, 255, 255)
text_colors = [(random.choice(range(255)), random.choice(range(255)), random.choice(range(255))) for _ in
range(row_count)]
temp = {
"context": label,
"path": save_path,
"textsource": "海尔热线报装报修"
}
f.write(json.dumps(temp, ensure_ascii=False) + '\n')
do_create(ft_ps, ft_se, poisitons, texts, img_color=img_color, text_colors=text_colors, width=width,
height=height, save_path=save_path)
total_count += 1
def multi_1_poet():
width_max = 600
height_max = 800
font_size_list = list(range(10, 35))
font_paths = glob.glob("/data02/yanghuang/workspace/Qwen-VL/datas/fonts/*")
path = f"/data02/yanghuang/workspace/MiniCPM-V/datas/solid_bg_imgs/multi_rows/poet"
if not os.path.exists(path):
os.makedirs(path)
poets_map = {}
poets_index = []
poet_paths = glob.glob("/data02/yanghuang/datasets/LLM_datasets/chinese-poetry/全唐诗/简体/*.json")
index = 0
total = 0
for poet_path in tqdm.tqdm(poet_paths, desc="get_poets"):
with open(poet_path, 'r', encoding='utf-8') as f:
poets = json.load(fp=f)
for poet in poets:
paragraphs = poet['paragraphs']
total += 1
if 2 <= len(paragraphs) <= 4 and poet['title'] != "":
poets_map[index] = poet
poets_index.append(index)
index += 1
print(total)
print(len(poets_index))
def compute_poets_img_infos(poets, font_size_list, font_paths, width, height):
#行间距
ratio_sinle = random.choice([1.0,1.1,1.2,1.3,1.4,1.5])
ratio_multi = random.choice([3,4,5])
ft_ps = []
ft_se = []
poisitons = []
texts = []
text_colors = []
label = ""
if random.random() < 0.6:
# 每首诗 字体、颜色、布局统一
ft_p = random.choice(font_paths)
ft_s = random.choice(font_size_list)
# 图片颜色
img_color, text_color = get_text_and_img_color(white_img_p=0.75)
counts = [ len(poet['paragraphs'][0]) for poet in poets]
counts.extend([len(poet['title'])+len(poet['author'])+len("——") for poet in poets])
word_count_max = max(counts)
posi_x = random.choice(range(0, width - word_count_max * ft_s-10))
for index, poet in enumerate(poets):
author = poet['author']
paragraphs = poet['paragraphs']
title = poet['title']
text = f"{title}——{author}"
if index == 0:
position = (posi_x, 10)
else:
position = (posi_x, poisitons[-1][1] + ft_s*ratio_multi)
ft_ps.append(ft_p)
ft_se.append(ft_s)
poisitons.append(position)
texts.append(text)
text_colors.append(text_color)
label += text + '\n'
for sen in paragraphs:
text = sen
position = (posi_x, poisitons[-1][1]+int(ft_s*ratio_sinle))
ft_ps.append(ft_p)
ft_se.append(ft_s)
poisitons.append(position)
text_colors.append(text_color)
texts.append(text)
label += text + '\n'
label = label.strip('\n') + '\n\n'
label = label.strip("\n\n")
else:
for index, poet in enumerate(poets):
author = poet['author']
paragraphs = poet['paragraphs']
title = poet['title']
text = f"{title}——{author}"
# 字体和尺寸
ft_p = random.choice(font_paths)
ft_s = random.choice(font_size_list)
# 图片颜色
img_color, text_color = get_text_and_img_color(white_img_p=0.75)
word_count_max = max([len(poet['paragraphs'][0]), len(poet['title'])+len(poet['author'])+len("——")])
posi_x = random.choice(range(0, width - word_count_max * ft_s - 10))
if index == 0:
position = (posi_x, 10)
else:
position = (posi_x, poisitons[-1][1] + ft_se[-1]*ratio_multi)
ft_ps.append(ft_p)
ft_se.append(ft_s)
poisitons.append(position)
texts.append(text)
text_colors.append(text_color)
label += text + '\n'
for sen in paragraphs:
text = sen
position = (posi_x, poisitons[-1][1]+ int(ft_s*ratio_sinle))
ft_ps.append(ft_p)
ft_se.append(ft_s)
poisitons.append(position)
text_colors.append(text_color)
texts.append(text)
label += text + '\n'
label = label.strip('\n') + '\n\n'
label = label.strip("\n\n")
assert poisitons[-1][1] < height, "beyond img height"
height = poisitons[-1][1] + ft_se[-1] + random.choice(range(20,150))
return ft_ps, ft_se, poisitons, texts, img_color, text_colors, label, width, height
total_count = 0
with open('./solid_bg_imgs/multi_rows/poet.json', 'w', encoding='utf-8') as f:
while len(poets_index) > 0:
if random.random() < 0.3:
choiced_indexs = random.sample(poets_index, k = 1)
elif random.random() <0.6:
choiced_indexs = random.sample(poets_index, k = 2 if len(poets_index) >= 2 else 1)
else:
choiced_indexs = random.sample(poets_index, k = 3 if len(poets_index) >= 3 else len(poets_index))
poets_index = list(set(poets_index) - set(choiced_indexs))
poets_index.sort()
try:
poets = [poets_map[ele] for ele in choiced_indexs]
ft_ps, ft_se, poisitons, texts, img_color, text_colors, label, width, height = compute_poets_img_infos(poets,
font_size_list,
font_paths,
width_max, height_max)
save_path = os.path.join(path, f"{total_count}.jpg")
do_create(ft_ps, ft_se, poisitons, texts, img_color=img_color, text_colors=text_colors, width=width,
height=height, save_path=save_path)
total_count += 1
temp = {
"context": label,
"path": save_path,
"textsource": "poet"
}
print(f"\rtotal_count: {total_count}", end="")
f.write(json.dumps(temp, ensure_ascii=False) + '\n')
except Exception as e:
continue
if len(poets_index) == 0:
break
print("")
print(f"total_count: {total_count}")
def multi_2_dialog():
paths = [
"/data02/yanghuang/workspace/llm_platform/trainer/dataset/factor/达能_20240403_80508_2_aug_20240403_1414_80508.jsonl",
"/data02/yanghuang/workspace/llm_platform/trainer/dataset/factor/美素工单小结_20240329_2687_aug_20240329_1710_2687.jsonl",
"/data02/yanghuang/workspace/llm_platform/trainer/dataset/factor/太平寿险工单总结_20240605_60422_aug_20240605_1129_60422.jsonl",
"/data02/yanghuang/workspace/llm_platform/trainer/dataset/general/圆通总结摘要_20240614_14036.jsonl",
]
dialogs = []
for path in paths:
with open(path, 'r', encoding='utf-8') as f:
datas = f.readlines()
print(f"{path}--{len(datas)}")
for data in datas:
dialog = json.loads(data)['context'].split("\n\n")[0]
if 2048 >= len(dialog) >= 1 and dialog != "":
dialogs.append(dialog)
print(f"len(dialogs) {len(dialogs)}")
dialogs = list(set(dialogs))
dialogs.sort()
print(f"len(dialogs) {len(dialogs)}")
width_max = 1300
height_max = 1300
font_size_list = list(range(10, 35))
font_paths = glob.glob("/data02/yanghuang/workspace/Qwen-VL/datas/fonts/*")
path = f"/data02/yanghuang/workspace/MiniCPM-V/datas/solid_bg_imgs/multi_rows/business_dialog"
if not os.path.exists(path):
os.makedirs(path)
def compute_dialog_img_infos(dialog, font_size_list, font_paths, width, height):
ratio_sinle = random.choice([1.0, 1.1, 1.2, 1.3, 1.4, 1.5])
ft_ps = []
ft_se = []
poisitons = []
texts = []
text_colors = []
label = ""
dialog = dialog.split('\n')
# 每首诗 字体、颜色、布局统一
ft_p = random.choice(font_paths)
ft_s = random.choice(font_size_list)
# 图片颜色
img_color, text_color = get_text_and_img_color(white_img_p=0.75)
posi_x = random.choice(range(10, 50))
row_count = 1100//int(ft_s*ratio_sinle)
for index, text in enumerate(dialog[0:row_count]):
ft_ps.append(ft_p)
ft_se.append(ft_s)
text_colors.append(text_color)
if index == 0:
position = (posi_x, 10)
else:
position = (posi_x, poisitons[-1][1] + int(ft_s*ratio_sinle))
poisitons.append(position)
text = text[: int((width - posi_x)/ft_s)]
assert width >= posi_x + len(text) * ft_s ,"sentence too long"
texts.append(text)
label += text +'\n'
assert poisitons[-1][1] < height, "turns too long"
height = poisitons[-1][1] + ft_se[-1] + random.choice(range(20, 100))
return ft_ps, ft_se, poisitons, texts, img_color, text_colors, label, width, height
total_count = 0
with open('./solid_bg_imgs/multi_rows/business_dialog.json', 'w', encoding='utf-8') as f:
for dialog in tqdm.tqdm(dialogs[0:], desc="multi_2_dialog"):
try:
ft_ps, ft_se, poisitons, texts, img_color, text_colors, label, width, height = compute_dialog_img_infos(
dialog, font_size_list, font_paths, width_max, height_max)
save_path = os.path.join(path, f"{total_count}.jpg")
do_create(ft_ps, ft_se, poisitons, texts, img_color=img_color, text_colors=text_colors, width=width,
height=height, save_path=save_path)
total_count += 1
temp = {
"context": label,
"path": save_path,
"textsource": "business_dialog"
}
f.write(json.dumps(temp, ensure_ascii=False) + '\n')
except Exception as e:
print(e)
continue
print(f"\ntotal_count: {total_count}")
def compute_multi_paragraph_img_infos(content, font_size_list, font_paths, width_max, height_max):
# 行间距
ratio_sinle = random.choice([1.0, 1.1, 1.2, 1.3, 1.4, 1.5])
ratio_multi = random.choice([3, 4, 5])
token_size_max = int(math.sqrt(width_max*height_max/2.1/len(content)))
assert token_size_max > font_size_list[0], "font size too small"
ft_s = random.choice(font_size_list)
while ft_s > token_size_max:
ft_s = random.choice(font_size_list)
if ft_s <= token_size_max:
break
ft_ps = []
ft_se = []
poisitons = []
texts = []
text_colors = []
label = ""
row_index = 0
if random.random() < 0.65:
# 字体种类、颜色统一
ft_p = random.choice(font_paths)
img_color, text_color = get_text_and_img_color(white_img_p=0.75)
temps = content.split('\n\n')
for temp in temps:
temp = temp.split('\n')
for ele in temp:
each_row_token_count = (width_max - 10-2*ft_s) // ft_s
rows = 1 + (len(ele)+2)//each_row_token_count
new_paragraph = True
for row in range(rows):
start = row * each_row_token_count
end = (row + 1) * each_row_token_count
text = ele[start:end]
label += text + "\n"
if row_index == 0:
position = (10+2*ft_s, 10)
else:
if new_paragraph:
position = (10+2*ft_s, poisitons[-1][1] + ft_s * ratio_multi)
else:
position = (10, poisitons[-1][1] + int(ft_s * ratio_sinle))
row_index += 1
ft_se.append(ft_s)
ft_ps.append(ft_p)
text_colors.append(text_color)
texts.append(text)
poisitons.append(position)
new_paragraph = False
label = label.strip("\n") + "\n\n"
label = label.strip('\n\n')
else:
temps = content.split('\n\n')
for temp in temps:
temp = temp.split('\n')
for ele in temp:
#每个段落一种字体和颜色
ft_p = random.choice(font_paths)
img_color, text_color = get_text_and_img_color(white_img_p=0.75)
each_row_token_count = (width_max - 10-2*ft_s) // ft_s
rows = 1 + (len(ele)+2)//each_row_token_count
new_paragraph = True
for row in range(rows):
start = row * each_row_token_count
end = (row + 1) * each_row_token_count
text = ele[start:end]
label += text + "\n"
if row_index == 0:
position = (10+2*ft_s, 10)
else:
if new_paragraph:
position = (10+2*ft_s, poisitons[-1][1] + ft_s * ratio_multi)
else:
position = (10, poisitons[-1][1] + int(ft_s * ratio_sinle))
row_index += 1
ft_se.append(ft_s)
ft_ps.append(ft_p)
text_colors.append(text_color)
texts.append(text)
poisitons.append(position)
new_paragraph = False
label = label.strip("\n") + "\n\n"
label = label.strip('\n\n')
assert poisitons[-1][1] < height_max, "beyond img height"
height = poisitons[-1][1] + ft_se[-1] + random.choice(range(20, 150))
return ft_ps, ft_se, poisitons, texts, img_color, text_colors, label, width_max, height
def multi_3_gaokao_comprehension():
width_max = 1300
height_max = 1300
font_size_list = list(range(5, 35))
font_paths = glob.glob("/data02/yanghuang/workspace/Qwen-VL/datas/fonts/*")
paths = glob.glob('/data02/yanghuang/datasets/LLM_datasets/VGaokao-阅读理解/data/*/*.json')
contents = []
for path in paths:
with open(path, 'r', encoding='utf-8') as f:
datas = json.load(fp=f)['data']
for data in datas:
if "context" in data:
content = data['context']
if 100 <= len(content) <= 2048:
contents.append(content)
path = f"/data02/yanghuang/workspace/MiniCPM-V/datas/solid_bg_imgs/multi_rows/gaokao_comprehension"
if not os.path.exists(path):
os.makedirs(path)
total_count = 0
with open('./solid_bg_imgs/multi_rows/gaokao_comprehension.json', 'w', encoding='utf-8') as f:
for dialog in tqdm.tqdm(contents[0:], desc="multi_3_gaokao_comprehension",ncols=90):
try:
ft_ps, ft_se, poisitons, texts, img_color, text_colors, label, width, height = compute_multi_paragraph_img_infos(
dialog, font_size_list, font_paths, width_max, height_max)
save_path = os.path.join(path, f"{total_count}.jpg")
do_create(ft_ps, ft_se, poisitons, texts, img_color=img_color, text_colors=text_colors, width=width,
height=height, save_path=save_path)
total_count += 1
temp = {
"context": label,
"path": save_path,
"textsource": "gaokao_comprehension"
}
f.write(json.dumps(temp, ensure_ascii=False) + '\n')
except Exception as e:
print(e)
continue
print(f"\ntotal_count: {total_count}")
def multi_4_glm_common():
contents = []
general_paths = glob.glob("/data02/yanghuang/workspace/llm_platform/trainer/dataset/general/glm2*.jsonl")
for path in general_paths:
with open(path,'r',encoding='utf-8') as f:
lines = f.readlines()
desc = path.split('/')[-1].replace(".jsonl","")
for line in tqdm.tqdm(lines, desc=f"{desc}"):
line = json.loads(line)
content = line['context'] + '\n'+ line['target']
if 100 <= len(content) <= 2048:
contents.append(content)
contents = random.sample(contents,k=650000)
path = f"/data02/yanghuang/workspace/MiniCPM-V/datas/solid_bg_imgs/multi_rows/glm_common"
if not os.path.exists(path):
os.makedirs(path)
width_max = 1300
height_max = 1300
font_size_list = list(range(5, 35))
font_paths = glob.glob("/data02/yanghuang/workspace/Qwen-VL/datas/fonts/*")
total_count = 0
with open('./solid_bg_imgs/multi_rows/glm_common.json', 'w', encoding='utf-8') as f:
for dialog in tqdm.tqdm(contents[0:], desc="glm_common", ncols=90):
try:
ft_ps, ft_se, poisitons, texts, img_color, text_colors, label, width, height = compute_multi_paragraph_img_infos(
dialog, font_size_list, font_paths, width_max, height_max)
save_path = os.path.join(path, f"{total_count}.jpg")
do_create(ft_ps, ft_se, poisitons, texts, img_color=img_color, text_colors=text_colors, width=width,
height=height, save_path=save_path)
total_count += 1
temp = {
"context": label,
"path": save_path,
"textsource": "glm_common"
}
f.write(json.dumps(temp, ensure_ascii=False) + '\n')
except Exception as e:
print(e)
continue
print(f"\ntotal_count: {total_count}")
def create_multi_rows_words_img():
multi_0_thousand_word()
multi_1_poet()
multi_2_dialog()
multi_3_gaokao_comprehension()
multi_4_glm_common()
def main():
create_single_row_words_img()
create_multi_rows_words_img()
if __name__ == '__main__':
main()
主要是针对不同的文本设计合适的版面来生成多行或者单行的图片(上面的代码中一些文本文件和字体文件需要替换为自己本地路径才能运行成功),示例如下:
单行文本图片
多行文本图片1
多行文本图片2
对上述生成的图片使用没有微调的QwenVL和MiniCPM-Llama3-V-2_5进行推理,计算样本准确率,样本字符编辑距离,准确率、错误率等。
import Levenshtein
def wer_ccor_compute(ref,pre):
substitution = 0
deletion = 0
insertion = 0
results = Levenshtein.editops(ref, pre)
for r in results:
if "replace" in r:
substitution += 1
elif "delete" in r:
deletion += 1
else:
insertion += 1
wer = (deletion+insertion+insertion) / len(ref)
ccor = (len(ref)-deletion-substitution)/len(ref)
ld = substitution + deletion + insertion
return substitution, deletion, insertion, ld, wer, ccor
上述代码使用Levenshtein库来计算ref和pre字符串的差异,替换、删除、插入的数量,wer、ccor和ld(编辑距离越小越好)
未微调的效果
微调后的效果
可以看到微调前两个模型在纯色背景图片上的识别效果不太好,说明OCR能力有待提高;MiniCPM-Llama3-V-2_5的效果微调前或者后都要由于QwenVL。微调后MiniCPM-Llama3-V-2_5样本准确率从0.419提升到0.81,QwenVL从0.15提升到0.75;多行文本的图片角度来看,MiniCPM-Llama3-V-2_5从0.0125提升到0.63,提升巨大,QwenVL多行文本图片准确率也从0.0095提升到了0.3845,但是要远低于前者。
demo展示
前端页面vue3实现,后端采用aiohttp实现流式推理,后端代码:
import os
import pytomlpp as toml
config = toml.load('config.toml')
os.environ['CUDA_VISIBLE_DEVICES'] = config['common']['device']
import asyncio
from aiohttp import web
import time
import socket
import logging
import pandas as pd
import json
from aiohttp_cors import setup, ResourceOptions
class LLMVLInfer(object):
def __init__(self, config):
self.model_type = config['model']['model_type']
self.model_path = config['model']['model_path']
self.adapter_path = config['model']['adapter_path']
self.model = globals()[self.model_type](self.model_path, self.adapter_path)
self.logger = self.create_logger()
self.df = pd.read_csv(config['data']['test_file'])
def inference(self, prompt, img_base64_list):
response = self.model.chat(prompt, img_base64_list=img_base64_list)
return response
async def stream_generator(self,prompt, img_base64_list):
stream_gen = self.model.chat(prompt, img_base64_list=img_base64_list, stream=True)
for ele in stream_gen:
yield ele
def component_prompt(self, content, img):
img_base64_list = None
if self.model_type == "MiniCPMV2_5":
msgs = [{"role": "user", "content": content}]
prompt = {"image": img, "question": json.dumps(msgs)}
else:
img_path = ""
prompt = self.model.tokenizer.from_list_format([
{'image': img_path}, # Either a local path or an url
{'text': f'{content}'},
])
img_base64_list = [img]
return prompt, img_base64_list
async def post(self, request:web.Request):
req = await request.json()
id = req['id']
img = req['params']['data']['image']
content = req['params']['data']['prompt']
prompt, img_base64_list = self.component_prompt(content, img)
start = time.time()
try:
# result = await asyncio.get_event_loop().run_in_executor(None, self.inference, prompt)
result = await asyncio.to_thread(self.inference, prompt, img_base64_list)
end = time.time()
except Exception as e:
self.logger.info(f'id: {id} inference fail: {e}')
return web.json_response(self.build_fail_resp(id_=id, code=-1, msg=f"{e}"))
tokens = len(self.model.tokenizer.encode(result))
cost_time = (end-start)*1000
speed = tokens/(end-start)
send_data = self.build_resp_success(id, result, tokens, cost_time, speed)
self.logger.info(json.dumps(send_data,ensure_ascii=False))
return web.json_response(send_data)
async def get_current_page_datas(self,request:web.Request):
req = await request.json()
self.logger.info(f"{req}")
current_page = req['params']['currentPage']
page_size = req['params']['pageSize']
start = (current_page -1) * page_size
end = current_page*page_size
df = self.df[start:end]
items = []
for _, row in df.iterrows():
temp = {'imgurl': None, 'imgpath': row['path'], 'percentage': 0, 'result': [row['输入文本'], row['识别文本']],
'levenshtein_distance': row['levenshtein_distance'], 'isright': row['是否正确'],"imgshow":False}
items.append(temp)
result = {
"itemtotal":len(self.df),
"items":items
}
self.logger.info(f"send current page [page-{current_page}] datas {result}")
return web.json_response({
"status":0,
"result":result
})
async def get_img(self, request:web.Request):
req = await request.json()
self.logger.info(f"{req}")
imgpath = req['params']['imgpath']
with open(imgpath, 'rb') as f:
image_data = f.read()
headers = {'Content-Type': 'image/jpeg'} # 根据图片类型更改这里的 MIME 类型
return web.Response(body=image_data, headers=headers)
async def getmertics(self, request:web.Request):
df = self.df
acc = len(df[df['是否正确'] == True]) / len(df)
total_ld_count = 0
total_word_count = 0
deletions = 0
substitutions = 0
for _, row in df.iterrows():
total_word_count += len(row['输入文本'])
total_ld_count += int(row['levenshtein_distance'])
deletions +=int(row['D'])
substitutions += int(row['S'])
wer = total_ld_count / total_word_count
war = (total_word_count-deletions-substitutions)/total_word_count
result = {
"acc":acc,
"wer":wer,
"war":war
}
self.logger.info(f"{result}")
return web.json_response(result)
async def stream_infer(self, request:web.Request):
ws = web.WebSocketResponse()
await ws.prepare(request)
# 接收客户端的消息
async for msg in ws:
if msg.type == web.WSMsgType.TEXT:
data = msg.json()
self.logger.info(f"ws request: {data}")
# 解析文本和图片的Base64编码
content = data.get('text', '')
img = data.get('image_base64_str', '')
prompt, img_base64_list = self.component_prompt(content, img)
try:
stream_generator = self.stream_generator(prompt, img_base64_list)
async for token in stream_generator:
if token != "":
await ws.send_str(token)
finally:
self.logger.info(f"ws close")
await ws.close()
return ws
async def handle(self, request:web.Request):
return web.FileResponse('./dist/index.html')
def build_fail_resp(self, id_: int, code: int, msg: str):
return web.json_response({
'id': id_,
'jsonrpc': '2.0',
'ret': code,
'result': {
"error_info": msg
}
})
def build_resp_success(self, id, answer, tokens, cost_time, speed):
rsp = {
"id": id,
"jsonrpc": "2.0",
"ret": 0,
"result": {
"chatInfo": {
"answer": answer,
"elements": []
},
"tokens": tokens,
"cost_time": f"{cost_time} ms",
"speed": f"{speed} tokens/s"
}
}
return rsp
def create_logger(self):
log_level = config["log"]["log_level"]
log_path = "./logs/server.log"
logger = logging.getLogger(__name__)
logger.setLevel(level=log_level)
formatter = logging.Formatter("%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s")
# 创建一个handler,用于写入日志文件,按大小覆盖
# file_handler = logging.handlers.RotatingFileHandler(filename=log_path, maxBytes=838860800, backupCount=20, encoding='utf-8')
# 按日期覆盖
file_handler = logging.handlers.TimedRotatingFileHandler(filename=log_path, when='D', interval=1,
encoding='utf-8')
file_handler.setFormatter(formatter)
file_handler.setLevel(level=log_level)
logger.addHandler(file_handler)
# 创建一个handler,用于将日志输出到控制台
console = logging.StreamHandler()
console.setLevel(level=log_level)
console.setFormatter(formatter)
logger.addHandler(console)
return logger
async def init_app():
llmvl_infer = LLMVLInfer(config)
app = web.Application()
app.add_routes([
web.post('/nlp',llmvl_infer.post),
web.post('/CurrentPageDatas', llmvl_infer.get_current_page_datas),
web.post('/fetchimg', llmvl_infer.get_img),
web.get('/getmertics', llmvl_infer.getmertics),
web.get("/ws",llmvl_infer.stream_infer),
web.get("/", llmvl_infer.handle),
web.static('/', path='./dist/', name='static')
])
cors = setup(app, defaults={
"*": ResourceOptions(
allow_credentials=True,
expose_headers="*",
allow_headers="*",
)
})
for route in list(app.router.routes()):
cors.add(route)
return app
if __name__ == '__main__':
if not os.path.exists("./logs"):
os.makedirs("./logs")
bind_socket = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0)
bind_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
bind_socket.bind(('0.0.0.0', 2222))
web.run_app(init_app(), sock=bind_socket)
页面截图如下:
四、思考
MiniCPM-Llama3-V-2_5参数量比QwenVL的参数量还要少一点,而效果再实际体验中却好很多,我认为预训练数据的影响最大,猜测MiniCPM-Llama3-V-2_5训练的图片和文本数据更多质量更高;另外一个就是MiniCPM-Llama3-V-2_5对输入图片的前置处理和压缩方式保留的图片信息更多,img占用的token数量也更多,这个也是效果比较好的原因之一;缺陷呢我觉得就是MiniCPM-Llama3-V-2_5对图片tokenize占用过多的token,处理多图甚至多轮多图的能力天然的并不如QwenVL系列那么灵活和低成本。
还有一个就是QwenVL的缺陷,它把图片路径和文本prompt拼接一起输入,然后再解码出来,也不是解耦的那种方式,对模型部署的性能有一定的影响。客户请求传入的是一个img,QwenVL需要把图片保存在本地得到一个img_path,才能进行推理。推理过程中有再读了一次图片,明显这样的设计不合理。需要优化,可以直接把图片地址id虚拟化,同时把图片和id映射后直接输入图片进去,后端不保存img获取img_path,而是根据虚拟化的图片地址id和id映射字典直接获取图片进行推理。
参考文章
Qwen-VL: A Versatile Vision-Language Model for Understanding, Localization, Text Reading, and Beyond
更多推荐
所有评论(0)