OmniSearch Qwen2+VL+7B
【代码】OmniSearch Qwen2+VL+7B。
·
配环境
Qwen网页需要安装的
pip install qwen-vl-utils -i https://mirrors.aliyun.com/pypi/simple/
pip install transformers -i https://mirrors.aliyun.com/pypi/simple/
- 注意python 必须是3.11.9以上,否则 serpapi 报错找不到goolge-search
- 虚拟环境,进去后pip install 自己需要安装的包,直接用python命令,就是用的python3.11
python3.11 -m venv myenv
source myenv/bin/activate
deactivate
pip install requests==2.32.3 -i https://mirrors.aliyun.com/pypi/simple/ # 加速安装
集成代码部分,OmniSearch main.py
import threading
from concurrent.futures import ThreadPoolExecutor
import os
import json
import argparse
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
# 加载 Qwen2-VL-7B 模型和处理器
model = Qwen2VLForConditionalGeneration.from_pretrained(
"/dfs/share-read-write/zhangyao22/Qwen2-VL-7B-Instruct",
torch_dtype="auto",
device_map="auto", # 自动选择设备并加载模型
)
processor = AutoProcessor.from_pretrained(
"/dfs/share-read-write/zhangyao22/Qwen2-VL-7B-Instruct"
) # 加载处理器,用于处理输入的文本和图像
# 定义线程锁,确保多线程写文件时不会出现竞态条件
write_lock = threading.Lock()
# 线程安全的写入文件函数
def safe_write(file_path, data):
with write_lock: # 使用锁保护文件写入
with open(file_path, "a", encoding="utf-8") as f:
f.write(
json.dumps(data, ensure_ascii=False) + "\n"
) # 将数据以 JSON 格式写入文件
# 处理数据集中的每一项
def process_item(item, conversation_manager, meta_save_path, dataset_name):
input_question = item["question"] # 获取问题
idx = item["question_id"] # 获取问题 ID
image_url = item["image_url"] # 获取图片 URL
# 准备模型的输入消息,包含文本和图像
messages = [
{
"role": "user", # 设置角色为用户
"content": [
{"type": "image", "image": image_url}, # 图像内容
{"type": "text", "text": input_question}, # 问题文本
],
}
]
# 使用处理器对输入进行预处理
text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
) # 应用聊天模板
image_inputs, video_inputs = process_vision_info(messages) # 处理图像和视频信息
inputs = processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt", # 将数据转为 PyTorch 张量
)
# 将输入数据移到 GPU
inputs = inputs.to("cuda")
# 使用模型生成回答
generated_ids = model.generate(
**inputs, max_new_tokens=128
) # 生成最大 128 个新 token
generated_ids_trimmed = [
out_ids[len(in_ids) :]
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
] # 修剪生成的 ID,移除输入部分的 token
output_text = processor.batch_decode(
generated_ids_trimmed,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
) # 解码生成的输出文本
# 将生成的答案保存到 item 中
item["prediction"] = output_text[0]
# 保存结果到指定路径
output_path = os.path.join(meta_save_path, dataset_name, "output_from_gpt4v.jsonl")
safe_write(output_path, item)
# 主函数,负责运行数据集处理任务
def main(test_dataset, dataset_name, meta_save_path):
# 读取数据集文件
with open(test_dataset, "r", encoding="utf-8") as f:
datas = [
json.loads(line) for line in f.readlines()
] # 将每一行 JSON 数据加载为字典
# 如果输出文件已存在,过滤掉已处理的数据项
output_path = os.path.join(meta_save_path, dataset_name, "output_from_gpt4v.jsonl")
if os.path.exists(output_path):
with open(output_path, "r") as fin:
done_id = [
json.loads(data)["question_id"] for data in fin.readlines()
] # 读取已处理的 question_id
datas = [
data for data in datas if data["question_id"] not in done_id
] # 过滤掉已处理的项
# 创建保存结果的文件夹
save_path = os.path.join(meta_save_path, dataset_name, "search_images_gpt4v")
os.makedirs(save_path, exist_ok=True)
# 创建 conversation_manager(假设你已经定义了一个 `ConversationManager` 类,未显示在代码中)
conversation_manager = ConversationManager(
qa_agent=None, dataset_name=dataset_name, save_path=save_path
)
# 使用 ThreadPoolExecutor 进行并行处理
with ThreadPoolExecutor(max_workers=1) as executor:
futures = [
executor.submit(
process_item, item, conversation_manager, meta_save_path, dataset_name
)
for item in datas # 为每个数据项创建一个异步任务
]
for future in futures:
future.result() # 等待所有任务完成
# 解析命令行参数并运行主函数
if __name__ == "__main__":
# 设置命令行参数
parser = argparse.ArgumentParser(description="Run dataset")
parser.add_argument("--test_dataset", type=str, required=True, help="数据集路径")
parser.add_argument("--dataset_name", type=str, required=True, help="数据集名称")
parser.add_argument("--meta_save_path", type=str, required=True, help="保存路径")
# 解析参数
args = parser.parse_args()
# 调用主函数并传递参数
main(args.test_dataset, args.dataset_name, args.meta_save_path)
# 最后执行代码
python main.py --test_dataset /dfs/data/OmniSearch-main/dataset/DynVQA_zh/DynVQA_zh.202406.jsonl --dataset_name MyDataset --meta_save_path /dfs/data/OmniSearch-main/src
更多推荐
已为社区贡献1条内容
所有评论(0)