inference.py篇
admin
2024-02-01 11:15:21

inference.py 篇

目录:

  • 前言
  • 思考自己需要载入的超参
  • 书写代码
  • 函数手册

前言

在该模块中加载训练好的模型,对测试集的image进行推理。

思考自己需要载入的超参

该模块的书写,是train的简约版,例如你可能需要设置和train相同的batch_sizedevicedataloader等信息,但是这次你不需要设置epoch等信息,对模型的参数进行优化等。

书写代码

书写顺序如下:

argparse()方法收集需要传递的所有参数,传入main函数中(可选)。

main函数中思路如下:

  1. 写路径等信息
  2. 书写dataloder。设置transformsdatasetdataloaderbatch_size等参数,因为dataloader中要用到。
  3. 设置其余超参,如device等,这次你必须要加载train中产生的预训练权重。
  4. 对测试集进行推理

下以AlexNet中的inference.py为例:

# add path
import os, sys
root_path = os.path.dirname(os.path.dirname(__file__))
project_path = os.path.dirname(__file__)
sys.path.append(project_path)
# add module
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json
import torch
import numpy as np
from model import AlexNetdef parse_args():"""get your args"""def convert_image(image_path:str = ""):"""transform png to jpg"""def main():# 路径root_path       = os.path.dirname(os.path.dirname(__file__))project_path    = os.path.dirname(__file__)weight_path     = os.path.join(root_path, "weight", "AlexNet_2.pth")image_path      = "/home/yingmuzhi/AlexNet/daisy.jpg"# 加载预测图片img             = Nonedata_transform  = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])img = Image.open(image_path)print(np.array(img).shape)img = data_transform(img)   # 只接受[height, width, channel=3]的图片, 即RGB的jpgimg = torch.unsqueeze(img, dim = 0) # 传入网络需要[batch, channel, height, width]# 加载json文件try:json_file = open(project_path + "/class_indices.json","r")class_indict = json.load(json_file)except Exception as e:print(e)exit(-1)# 测试参数net = AlexNet(num_classes=2)net.load_state_dict(torch.load(weight_path))net.eval()    # 关闭dropout层并且不会梯度回传with torch.no_grad():# predict classoutput = net(img)# print(output.shape)output = torch.squeeze(output)# print(output.shape)predict = torch.softmax(output, dim = 0)# print(predict.shape)predict_cla = torch.argmax(predict).numpy()print(class_indict[str(predict_cla)], predict[predict_cla].item())if __name__ == "__main__":args = parse_args()main(args)

函数手册

相关内容

热门资讯

原创 夏... 夏天湿热重、脾胃易虚寒,这4道汤健脾祛湿、暖胃护胃、清热不伤阳,适合连续两个月常喝,步骤清晰、做法简...
明日四月十六,记得“吃4样,做... 明日农历四月十六,记得“吃4样,做1事”五谷丰登迎福气,老传统别丢! 时光如梭,转眼间来到了农历四月...
今年目标全国销售网点突破200... 5月16日下午6点,贵阳市吾茶白·贵茶潮饮烘焙概念店里排起小队。 “就要这款,上次喝完一直惦记着。”...
原创 淄... 很多人认识淄博只靠烧烤但真正撑起淄博饮食底蕴的从来不是网红热度而是一代代扎根老城的老字号烟火。这些老...
原创 夏... “赤日炎炎似火烧”,这话一到夏天,可算是说到大家心坎里去了。天热起来,不光人没精神,连胃口也跟着变差...