%%time
# --- Model configuration ---
# ONNX network file and the CSV of class labels that accompanies it.
modelPath = './model/preTrainedImageNetGooglenet.onnx'
labelsPath = './model/labels.csv'
hasHeader = 1  # a value > 0 makes loadmodel skip the first labels.csv row

# --- Input data configuration ---
image_folder = './ImageFolder/'  # directory scanned for images
EXT = ("jfif", "jpg")  # filename suffixes that getData treats as images
CPU times: user 6 µs, sys: 1 µs, total: 7 µs
Wall time: 34.3 µs
%%time
!pip install onnx
!pip install onnxruntime
import onnx
import numpy as np
from PIL import Image
import os as os
import matplotlib.pyplot as plt
from onnxruntime import InferenceSession
import csv
Collecting onnx
Downloading onnx-1.10.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (12.7 MB)
|████████████████████████████████| 12.7 MB 10.5 MB/s
Requirement already satisfied: numpy>=1.16.6 in /shared-libs/python3.7/py/lib/python3.7/site-packages (from onnx) (1.19.5)
Requirement already satisfied: six in /shared-libs/python3.7/py-core/lib/python3.7/site-packages (from onnx) (1.16.0)
Requirement already satisfied: typing-extensions>=3.6.2.1 in /shared-libs/python3.7/py-core/lib/python3.7/site-packages (from onnx) (3.10.0.2)
Requirement already satisfied: protobuf in /shared-libs/python3.7/py/lib/python3.7/site-packages (from onnx) (3.17.3)
Installing collected packages: onnx
Successfully installed onnx-1.10.2
WARNING: You are using pip version 21.2.4; however, version 21.3.1 is available.
You should consider upgrading via the '/root/venv/bin/python -m pip install --upgrade pip' command.
Collecting onnxruntime
Downloading onnxruntime-1.9.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.8 MB)
|████████████████████████████████| 4.8 MB 33.7 MB/s
Requirement already satisfied: flatbuffers in /shared-libs/python3.7/py/lib/python3.7/site-packages (from onnxruntime) (1.12)
Requirement already satisfied: protobuf in /shared-libs/python3.7/py/lib/python3.7/site-packages (from onnxruntime) (3.17.3)
Requirement already satisfied: numpy>=1.16.6 in /shared-libs/python3.7/py/lib/python3.7/site-packages (from onnxruntime) (1.19.5)
Requirement already satisfied: six>=1.9 in /shared-libs/python3.7/py-core/lib/python3.7/site-packages (from protobuf->onnxruntime) (1.16.0)
Installing collected packages: onnxruntime
Successfully installed onnxruntime-1.9.0
WARNING: You are using pip version 21.2.4; however, version 21.3.1 is available.
You should consider upgrading via the '/root/venv/bin/python -m pip install --upgrade pip' command.
CPU times: user 1.13 s, sys: 129 ms, total: 1.26 s
Wall time: 18.6 s
%%time
def loadmodel(modelPath, labelsPath, hasHeader):
    """Load and validate an ONNX model, and read its class-label CSV.

    Args:
        modelPath: path to the ``.onnx`` model file.
        labelsPath: path to a CSV file of labels, one row per class.
        hasHeader: when > 0, the first CSV row is a header and is skipped.

    Returns:
        ``(sess, inname, outname, numChannels, inputSize, labels)``:
        the onnxruntime ``InferenceSession``; the list of input names; the
        list of output names; axis 1 of the first input's shape (the channel
        count); axes 2-3 of that shape as a list ``[height, width]``; and the
        label rows, each a list of strings.
    """
    # Validate the model file before building an inference session from it.
    onnx.checker.check_model(modelPath)
    sess = InferenceSession(modelPath)

    # Names of the input and output layers.
    inputs = sess.get_inputs()
    inname = [inp.name for inp in inputs]
    outname = [out.name for out in sess.get_outputs()]

    def extractLabels(filename, hasHeader):
        """Return the rows of CSV *filename*, skipping a header row if hasHeader > 0."""
        # `with` guarantees the file is closed even if parsing raises;
        # newline='' is how the csv module expects files to be opened.
        with open(filename, newline='') as file:
            csvreader = csv.reader(file)
            if hasHeader > 0:
                next(csvreader, None)  # drop the header; tolerate an empty file
            return list(csvreader)

    labels = extractLabels(labelsPath, hasHeader)

    # Read the geometry of the FIRST input, the one callers feed via inname[0].
    # The code assumes NCHW layout: [batch, channels, height, width].
    shape = inputs[0].shape
    numChannels = shape[1]
    inputSize = shape[2:4]  # [height, width]
    return sess, inname, outname, numChannels, inputSize, labels
def getData(image_folder, EXT, inputSize):
    """Load every image in *image_folder* whose filename ends with a suffix in *EXT*.

    Args:
        image_folder: directory to scan (not recursive).
        EXT: one filename suffix (str) or a list/tuple of suffixes, e.g.
            ``("jfif", "jpg")``. Matching is case-sensitive.
        inputSize: target size as ``[height, width]``. This is the order
            ``loadmodel`` returns it in (axes 2-3 of an NCHW input shape).

    Returns:
        ``(X, data, fullFilePath)``:
        ``X`` is a float32 batch in NCHW layout;
        ``data`` holds the same pixels as an int64 NHWC array;
        ``fullFilePath`` lists the loaded files in batch order.

    Raises:
        ValueError: if no matching image files are found.
    """
    def getImagesFromFolder(EXT):
        # Wrap a lone suffix in a list; iterating the bare string would test
        # each of its characters as a separate suffix.
        exts = EXT if isinstance(EXT, (list, tuple)) else [EXT]
        # Sort so the batch order is deterministic (os.listdir order is arbitrary).
        imageList = sorted(os.listdir(image_folder))
        return [os.path.join(image_folder, f)
                for ext in exts
                for f in imageList
                if f.endswith(ext) and os.path.isfile(os.path.join(image_folder, f))]

    def imFile2npArray(imFile, inputSize):
        # PIL sizes are (width, height); inputSize is [height, width].
        height, width = inputSize
        arrays = []
        for fname in imFile:
            # The context manager closes the file handle PIL keeps open.
            with Image.open(fname) as im:
                # convert("RGB") gives every image 3 channels, so grayscale or
                # CMYK files do not produce a ragged batch.
                resized = im.convert("RGB").resize((width, height))
                arrays.append(np.array(resized, dtype=np.int64))
        data = np.array(arrays)
        # NHWC -> NCHW, the layout the network input uses.
        X = data.transpose(0, 3, 1, 2).astype(np.float32)
        return X, data

    fullFilePath = getImagesFromFolder(EXT)
    if not fullFilePath:
        raise ValueError(f"no files ending in {EXT!r} found in {image_folder!r}")
    X, data = imFile2npArray(fullFilePath, inputSize)
    return X, data, fullFilePath
CPU times: user 4 µs, sys: 0 ns, total: 4 µs
Wall time: 7.63 µs
%%time
# Load the network and its labels, then the images resized to its input.
sess, inname, outname, numChannels, inputSize, labels = loadmodel(
    modelPath, labelsPath, hasHeader)
X, data, fullFilePath = getData(image_folder, EXT, inputSize)

# Report the discovered input geometry and the I/O layer names.
print(f"inputSize: {inputSize}")
print(f"numChannels: {numChannels}")
print("inputName: ", inname[0])
print("outputName: ", outname[0])
inputSize: [224, 224]
numChannels: 3
inputName: data
outputName: prob
CPU times: user 321 ms, sys: 81.3 ms, total: 402 ms
Wall time: 423 ms
%%time
# Run the whole batch through the network. Passing None as the output list
# requests every output; only the first one is used here.
out = sess.run(None, {inname[0]: X})
out = np.asarray(out[0])
print(out.shape)

# Top-1 class per image. np.argmax returns the first index of the maximum,
# the same index the earlier np.where(x == max)[0][0] lookup picked.
# Assumes out is 2-D (batch, classes) -- confirm for other models.
IND = list(np.argmax(out, axis=1))
PROB = list(out[np.arange(out.shape[0]), IND])  # score of each chosen class

l = [labels[ind] for ind in IND]  # label row for each prediction
print(l)
print(IND)
print(PROB)
(5, 1000)
[['tabby'], ['pug'], ['tabby'], ['daisy'], ['tank']]
[281, 254, 281, 985, 847]
[0.4752808, 0.9636962, 0.28503296, 0.9998487, 0.668438]
CPU times: user 533 ms, sys: 12.6 ms, total: 546 ms
Wall time: 546 ms
%%time
plt.figure(figsize=(15,15))

# Show at most six images: a 2x3 grid when six are available, otherwise one row.
nPlots = min(data.shape[0], 6)
subArray = [2, 3] if nPlots == 6 else [1, nPlots]

for i in range(nPlots):
    plt.subplot(*subArray, i + 1)
    plt.imshow(data[i])
    plt.axis('off')
    # Caption: top label, then its score as a rounded percentage.
    plt.title(l[i][0] + ' --- ' + str(round(100*PROB[i])) + '%')

plt.show()
CPU times: user 513 ms, sys: 129 ms, total: 642 ms
Wall time: 546 ms