一、导读
今日,智谱AI发布了最新的代码模型CodeGeeX2-6B(https://mp.weixin.qq.com/s/qw31ThM4AjG6RrjNwsfZwg),并已在魔搭社区开源。
CodeGeeX2是多语言代码生成模型CodeGeeX的第二代模型,基于ChatGLM2架构并加入代码预训练实现,具有更强大的代码能力、更优秀的模型特性、更全面的AI编程助手和更开放的协议等特性。
本文提供了CodeGeeX2的微调教程,希望更多开发者基于开源和数据集微调CodeGeeX2,共同创造AI生态。期待通过这一开源,让CodeGeeX2能成为每一位程序员的编程助手。
魔搭开源链接:https://modelscope.cn/models/ZhipuAI/codegeex2-6b/summary
二、环境配置与安装
本文在ModelScope Notebook的免费环境中进行测试,要求Python>=3.8。
升级ModelScope环境:
ModelScope需要升级到GitHub上最新的master分支(正式发布版本预计8月1日推出)。进入Notebook的Terminal环境:
更新ModelScope版本:
# Install ModelScope from the latest master branch (needed until the next release ships).
git clone https://github.com/modelscope/modelscope.git
cd modelscope
# Build and install the checked-out source into the current Python environment.
pip install .
三、模型链接及下载
CodeGeeX2-6B
模型链接:https://modelscope.cn/models/ZhipuAI/codegeex2-6b/summary
使用notebook进行模型weights下载(飞一样的速度,可以达到百兆每秒):
from modelscope.hub.snapshot_download import snapshot_download

# Fetch the CodeGeeX2-6B weights from ModelScope, pinned to revision v1.0.0.
# model_dir presumably receives the local directory of the downloaded files —
# confirm against the ModelScope hub documentation.
repo_id, pinned_revision = 'ZhipuAI/codegeex2-6b', 'v1.0.0'
model_dir = snapshot_download(repo_id, revision=pinned_revision)
四、模型推理
CodeGeeX2-6B推理代码如下。在ModelScope新版本正式发布前,以下代码需要在Notebook的Terminal中执行:
import torch
from modelscope import AutoModel, AutoTokenizer

# Load the CodeGeeX2-6B tokenizer and model; trust_remote_code lets the
# repository's own modeling code run. Weights are loaded in bfloat16 onto the
# first GPU.
model_id = 'ZhipuAI/codegeex2-6b'
load_kwargs = dict(
    device_map={'': 'cuda:0'},  # or device_map='auto'
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
# nn.Module.eval() returns the module itself, so it can be chained here.
model = AutoModel.from_pretrained(model_id, **load_kwargs).eval()

# remember adding a language tag for better performance
prompt = "# language: python\n# write a bubble sort function\n"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
generated = model.generate(input_ids, max_length=256)
response = tokenizer.decode(generated[0])
print(response)
推理所需显存:约13GB
五、效果体验
体验了一下用Python解决八皇后问题,效果还是不错的!
>>> prompt = "# language: python\n# solve eight queen problem\n"  # language tag first, then the task written as a comment
>>> inputs = tokenizer.encode(prompt, return_tensors="pt").to(model.device)  # token ids as a PyTorch tensor on the model's device
>>> outputs = model.generate(inputs, max_length=512)  # larger length budget than the 256 used in the script above
>>> response = tokenizer.decode(outputs[0])  # decode the first generated sequence back to text
>>> print(response)
# language: python
# solve eight queen problem
def conflict(state, nextX):
    """Return True if a queen placed in column *nextX* on the next row is attacked.

    *state* holds the column of the queen already placed on each earlier row,
    so the new queen goes on row ``len(state)``.
    """
    nextY = len(state)
    for i in range(nextY):
        # Column distance 0 means same column; column distance equal to the
        # row distance means same diagonal. Either way the queens attack.
        if abs(state[i] - nextX) in (0, nextY - i):
            return True
    return False


def queens(num=8, state=()):
    """Yield every placement of *num* non-attacking queens on a num x num board.

    With the default empty *state*, each yielded tuple has one entry per row:
    the column of the queen on that row. *state* is the partial placement of
    the rows above and is used by the recursion. When it is non-empty, only
    the columns for the remaining rows are yielded.
    """
    for pos in range(num):
        if not conflict(state, pos):
            if len(state) == num - 1:
                # Last row: this column completes a full placement.
                yield (pos,)
            else:
                for result in queens(num, state + (pos,)):
                    yield (pos,) + result


if __name__ == "__main__":
    print(list(queens(8)))
使用C++实现快速排序:
>>> prompt = "// language: C++\n// write a quick sort function\n"  # for C++ the language tag and task use // comment markers
>>> inputs = tokenizer.encode(prompt, return_tensors="pt").to(model.device)  # token ids as a PyTorch tensor on the model's device
>>> outputs = model.generate(inputs, max_length=512)
>>> response = tokenizer.decode(outputs[0])  # decode the first generated sequence back to text
>>> print(response)
// language: C++
// write a quick sort function
#include <iostream>
#include <vector>
#include <utility>  // std::swap is declared here; <vector> is not guaranteed to provide it
using namespace std;

// Sort arr[start..end] (both bounds inclusive) in place, in ascending order.
// arr[start] is used as the pivot; a range with start >= end is left untouched.
void quickSort(vector<int> &arr, int start, int end) {
    if (start >= end) {
        return;
    }
    int pivot = arr[start];
    int left = start;
    int right = end;
    while (left < right) {
        // Scan from the right first so that, when the two indices meet,
        // arr[left] is <= pivot and may be swapped into the pivot's slot.
        while (left < right && arr[right] >= pivot) {
            right--;
        }
        while (left < right && arr[left] <= pivot) {
            left++;
        }
        if (left < right) {
            swap(arr[left], arr[right]);
        }
    }
    // Put the pivot in its final position, then sort each side of it.
    swap(arr[left], arr[start]);
    quickSort(arr, start, left - 1);
    quickSort(arr, left + 1, end);
}
// Demo: sort a fixed sample vector and print it space-separated on one line.
int main() {
    vector<int> arr = {5, 3, 4, 1, 2, 8, 7, 9, 6, 0};
    // Cast before subtracting: arr.size() is unsigned, so size() - 1 would
    // wrap around for an empty vector instead of giving -1.
    quickSort(arr, 0, static_cast<int>(arr.size()) - 1);
    // Range-for avoids the signed/unsigned comparison of `int i < arr.size()`.
    for (int x : arr) {
        cout << x << " ";
    }
    cout << '\n';  // terminate the output line
    return 0;
}
所有评论(0)