计算 VLM 模型生成文本的 PPL(困惑度)损失。
代码:
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
import torch
from torch.nn import CrossEntropyLoss
from PIL import Image

# Configuration
DEVICE = "cuda:0"
MODEL_NAME = "/data1/chenjun/huf/Qwen2-VL-2B-Instruct"
IMAGE_SIZE = 384


def resize_image(path, max_side=384):
    """Resize an image so its longer side equals ``max_side``, keeping aspect ratio.

    Returns a single-element list because the processor's ``images=`` argument
    expects a list of PIL images.
    """
    image = Image.open(path).convert("RGB")
    width, height = image.size
    if width > height:
        new_width = max_side
        new_height = int(height * (max_side / width))
    else:
        new_height = max_side
        new_width = int(width * (max_side / height))
    return [image.resize((new_width, new_height), Image.Resampling.LANCZOS)]


def main():
    """Generate a caption for one image, then compute the CE loss / PPL of the
    generated text under the same model."""
    # Load model and processor
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        MODEL_NAME, dtype=torch.float32, device_map=DEVICE
    )
    processor = AutoProcessor.from_pretrained(MODEL_NAME)

    # Build the chat message: one image plus a text prompt
    file = 'outputs/ppl_vlm_qwen3-vl-2b-axera-384/vit/0000.png'
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": file},
                {"type": "text", "text": "描述这张图片"},
            ],
        }
    ]

    # Apply the chat template (text only; the image is preprocessed separately)
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Preprocess the image and tokenize the prompt
    image_inputs = resize_image(file, IMAGE_SIZE)
    inputs = processor(text=[text], images=image_inputs, return_tensors="pt").to(DEVICE)
    # Token index where generation starts (prompt length, incl. image tokens)
    gen_idx = inputs['input_ids'].shape[1]

    # Generate the response
    generated_ids = model.generate(**inputs, max_new_tokens=256)
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]

    # Compute PPL: re-encode prompt + generated response and run one forward pass.
    text_with_response = text + output_text
    image_inputs = resize_image(file, IMAGE_SIZE)
    inputs2 = processor(
        text=[text_with_response], images=image_inputs, return_tensors="pt"
    ).to(DEVICE)
    with torch.no_grad():
        # BUGFIX: `max_new_tokens` is a generate() argument, not a forward()
        # argument; passing it to model(...) raises a TypeError.
        outputs = model(**inputs2)
    logits = outputs.logits

    # Cross-entropy over the generated tokens only. Causal LM shift: logits at
    # position i predict token i+1, so logits[gen_idx-1:-1] predict the labels
    # input_ids[gen_idx:]. (The original sliced from gen_idx / gen_idx+1, which
    # silently dropped the first generated token from the loss.)
    # NOTE(review): this assumes inputs2's prompt portion tokenizes to the same
    # gen_idx tokens as inputs — holds here since prompt text and image are
    # identical, but verify if the preprocessing ever changes.
    shift_labels = inputs2['input_ids'][..., gen_idx:].contiguous().to(DEVICE)
    shift_logits = logits[..., gen_idx - 1:-1, :].contiguous().to(dtype=torch.float32)
    loss_fct = CrossEntropyLoss()
    ce_loss = loss_fct(
        shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
    )
    # BUGFIX: the original f-string contained a literal newline inside the
    # string literal and would not parse; PPL = exp(mean cross-entropy).
    print(f"ce_loss:{ce_loss:.3f}, ppl:{ce_loss.exp():.3f}")


if __name__ == "__main__":
    main()