我是ML和TensorFlow的新手。我正在尝试构建一个cnn,以便对损坏的图像进行分类,类似于张量流中的剪刀石头布教程,但只有两个类别。
模型体系结构
train_generator = training_datagen.flow_from_directory(
TRAINING_DIR,target_size=(150,150),class_mode='categorical'
)
validation_generator = validation_datagen.flow_from_directory(
VALIDATION_DIR,class_mode='categorical'
)
model = tf.keras.models.Sequential([
# Note the input shape is the desired size of the image 150x150 with 3 bytes color
# This is the first convolution
tf.keras.layers.Conv2D(64,(3,3),activation='relu',input_shape=(150,150,3)),tf.keras.layers.MaxPooling2D(2,2),# The second convolution
tf.keras.layers.Conv2D(64,activation='relu'),# The third convolution
tf.keras.layers.Conv2D(128,# The fourth convolution
tf.keras.layers.Conv2D(128,# flatten the results to feed into a DNN
tf.keras.layers.flatten(),tf.keras.layers.Dropout(0.5),# 512 neuron hidden layer
tf.keras.layers.Dense(512,tf.keras.layers.Dense(2,activation='softmax')
])
model.summary()
model.compile(loss = 'categorical_crossentropy',optimizer='rmsprop',metrics=['accuracy'])
history = model.fit_generator(train_generator,epochs=25,validation_data = validation_generator,verbose = 1)
model.save("rps.h5")
我所做的唯一更改是将输入形状更改为(150,1)更改为(150,3),并将最后一层的输出从3更改为 2个神经元。培训使我在每个课程中 600张图像的数据集始终保持 90以上的准确性。但是,当我在本教程中使用代码进行预测时,即使对于数据集中的数据,它也会给我带来非常错误的值。
预测
TensorFlow教程中的原始代码
for file in onlyfiles:
path = fn
img = image.load_img(path,3)) # changed target_size to (150,3)) from (150,150 )
x = image.img_to_array(img)
x = np.expand_dims(x,axis=0)
images = np.vstack([x])
classes = model.predict(images,batch_size=10)
print(fn)
print(classes)
我相信target_size从(150,150)更改为(150,150,3)),因为我的输入是3通道图像,所以
结果
它甚至为数据集中的图像给出了非常错误的[0,1] [0,1]值
但是当我将代码更改为此
for file in onlyfiles:
path = fn
img = image.load_img(path,3))
x = image.img_to_array(img)
x = np.expand_dims(x,axis=0)
x /= 255.
classes = model.predict(images,batch_size=10)
print(fn)
print(classes)
在这种情况下,值类似于
[[9.9999774e-01 2.2242968e-06]]
[[9.9999785e-01 2.1864464e-06]]
[[9.9999785e-01 2.1641024e-06]]
一两个错误,但是非常正确
所以我的问题是,即使最后一次激活是softmax,为什么现在以十进制值显示,我进行预测的方式是否存在逻辑错误?我也尝试过二进制文件,但没有太大区别。