1. 项目概述
- 目标:在国产信创硬件上训练长文本分类模型,并部署 API 提供推理服务
- 任务类型:多类别/二分类 NLP 问题
- 输入数据:长文本(如 2000+ token)
- 输出:文本类别预测
- 硬件环境:
- 2 × Ascend 910B2 NPU
- 鲲鹏 ARM64 CPU
- 昆仑信创操作系统(如 openEuler / 麒麟)
- 软件环境:
Python >= 3.9
PyTorch 2.2.1(Ascend 镜像):
pipinstalltorch==2.2.1 -f https://ascend-pytorch-mirror.huawei.com/whl/torch/Transformers
NumPy, pandas, scikit-learn
2. 数据处理
2.1 文本切分
- 长文本超过 BERT 最大长度(如 512)时,使用BERT Split:
- 将文本按句子或固定长度切分为多个片段
- 每个片段通过 BERT 编码
- 拼接或平均片段的 hidden states 作为文本表示
- 可选:文本重叠切分,保证上下文连续性
2.2 数据集示例
importpandasaspdfromsklearn.model_selectionimporttrain_test_split df=pd.read_csv('long_text_dataset.csv')# columns: text, labeltrain_texts,val_texts,train_labels,val_labels=train_test_split(df['text'].tolist(),df['label'].tolist(),test_size=0.1,random_state=42)2.3 Tokenizer
fromtransformersimportBertTokenizer tokenizer=BertTokenizer.from_pretrained("bert-base-chinese")defencode_texts(texts,max_len=512):encoded_list=[]fortextintexts:# 分段处理segments=[text[i:i+max_len]foriinrange(0,len(text),max_len)]encoded_segments=[tokenizer(s,padding='max_length',truncation=True,return_tensors='pt')forsinsegments]encoded_list.append(encoded_segments)returnencoded_list3. 模型设计:BERTSplitLSTM
3.1 结构说明
BERT Encoder:
- 每个文本片段使用 BERT 编码
- 输出
[CLS]或最后隐藏层
片段合并:
- 将片段向量按顺序拼接或送入 LSTM
LSTM:
- 捕捉跨片段的长文本上下文
- 双向 LSTM 可选
分类层:
- 全连接 + softmax
- 输出文本类别
3.2 PyTorch 示例
importtorchimporttorch.nnasnnfromtransformersimportBertModelclassBERTSplitLSTM(nn.Module):def__init__(self,bert_model_name='bert-base-chinese',lstm_hidden=256,num_classes=10):super().__init__()self.bert=BertModel.from_pretrained(bert_model_name)self.lstm=nn.LSTM(input_size=self.bert.config.hidden_size,hidden_size=lstm_hidden,num_layers=1,batch_first=True,bidirectional=True)self.fc=nn.Linear(2*lstm_hidden,num_classes)defforward(self,segments_batch):# segments_batch: list of segments tensors, shape [batch, seg_len, hidden_size]segment_outputs=[]forsegmentsinsegments_batch:seg_embs=[]forseginsegments:output=self.bert(**seg).last_hidden_state[:,0,:]# CLS tokenseg_embs.append(output)seg_embs=torch.stack(seg_embs,dim=1)# [batch, n_segments, hidden_size]lstm_out,_=self.lstm(seg_embs)final_output=lstm_out[:,-1,:]segment_outputs.append(final_output)returnself.fc(torch.cat(segment_outputs,dim=0))4. 训练配置
损失函数:
CrossEntropyLoss优化器:
AdamW(带权重衰减)学习率策略:线性 warmup + decay
批大小:根据显存,双卡 910B2 可尝试 batch=4~8
梯度累积:长文本可使用梯度累积降低显存占用
混合精度训练:
scaler=torch.cuda.amp.GradScaler()
4.1 训练示例
fromtorch.utils.dataimportDataLoader train_loader=DataLoader(train_dataset,batch_size=2,shuffle=True)forepochinrange(epochs):forbatchintrain_loader:optimizer.zero_grad()withtorch.cuda.amp.autocast():outputs=model(batch['segments'])loss=criterion(outputs,batch['labels'])scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()5. 模型部署
5.1 模型保存
torch.save(model.state_dict(),"bert_split_lstm_finetune.pt")5.2 转换 OM(Ascend)
# 导出 ONNXpython export_to_onnx.py --model_path bert_split_lstm_finetune.pt --output bert_split_lstm.onnx# ONNX → OMatc --model=bert_split_lstm.onnx --framework=5--output=bert_split_lstm.om --soc_version=Ascend910B2 --input_shape="input_ids:1,512"5.3 API 部署
- 方法:
- 使用 FastAPI
- 支持多进程 + 多线程 + 批量请求
fromfastapiimportFastAPIimporttorch app=FastAPI()model=load_om_model("bert_split_lstm.om",device='ascend',card_ids=[0,1])@app.post("/predict")asyncdefpredict(text:str):segments=encode_texts([text])pred=model(segments)return{"label":pred.argmax(dim=-1).item()}6. 性能优化
- 多卡并行:910B2 ×2 NPU
- 批量推理:增加吞吐
- 多线程/异步:利用 CPU 做数据预处理
- 量化/半精度训练:降低显存,提升速度
- 预热模型:推理前跑几次 batch
7. 验证与上线
- 小规模文本测试模型准确性
- 大批量文本测试吞吐和延迟
- 监控 NPU 显存、CPU、推理延迟