VGG16 on a spectrogram dataset

I'm following the guide written by Rajsha: https://github.com/rajshah4/image_keras/blob/master/notebook_extras.ipynb

The idea is to apply VGG16 to my dataset of spectrograms and have it decide between two classes: normal and abnormal.

However, the model doesn't learn: even though I'm only training the top model, I still get a val_acc of about 0.5.

Am I doing something wrong? My code is below:

```python
import numpy as np
from keras import applications
from keras.applications.vgg16 import preprocess_input
from keras.layers import Dense, Dropout, Flatten
from keras.models import Sequential
from keras.preprocessing.image import ImageDataGenerator

# dimensions of our images
img_width, img_height = 240, 240

train_data_dir = '/content/gdrive/My Drive/Melspec/melspecimages/train'
validation_data_dir = '/content/gdrive/My Drive/Melspec/melspecimages/val'

batch_size = 32
epochs = 50  # not defined in the original post; placeholder value
datagen = ImageDataGenerator(preprocessing_function=preprocess_input)

# Convolutional base only, with frozen ImageNet weights
model_vgg = applications.VGG16(include_top=False, weights='imagenet', input_shape=(240, 240, 3))
model_vgg.trainable = False

train_generator_bottleneck = datagen.flow_from_directory(
        train_data_dir, target_size=(img_width, img_height), batch_size=batch_size, class_mode='binary', shuffle=True)

# Validation images must be resized to the same 240x240 input as the training images
validation_generator_bottleneck = datagen.flow_from_directory(
        validation_data_dir, target_size=(img_width, img_height), batch_size=batch_size, class_mode='binary', shuffle=False)

train_samples = 30272
validation_samples = 7584

# Run every image through VGG16 once and cache the bottleneck features
bottleneck_features_train = model_vgg.predict_generator(train_generator_bottleneck, train_samples // batch_size)
np.save('/content/gdrive/My Drive/Melspec/spec_vgg_bottleneck_features_train.npy', bottleneck_features_train)

bottleneck_features_validation = model_vgg.predict_generator(validation_generator_bottleneck, validation_samples // batch_size)
np.save('/content/gdrive/My Drive/Melspec/spec_vgg_bottleneck_features_validation.npy', bottleneck_features_validation)

# Labels built by hand: assumes the features come out class 0 first, then class 1
train_data = np.load('/content/gdrive/My Drive/Melspec/spec_vgg_bottleneck_features_train.npy')
train_labels = np.array([0] * (train_samples // 2) + [1] * (train_samples // 2))

validation_data = np.load('/content/gdrive/My Drive/Melspec/spec_vgg_bottleneck_features_validation.npy')
validation_labels = np.array([0] * (validation_samples // 2) + [1] * (validation_samples // 2))

# Small binary classifier trained on top of the cached features
model_top = Sequential()
model_top.add(Flatten(input_shape=train_data.shape[1:]))
model_top.add(Dense(256, activation='relu'))
model_top.add(Dropout(0.5))
model_top.add(Dense(1, activation='sigmoid'))

model_top.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['accuracy'])

model_top.fit(train_data, train_labels, epochs=epochs, batch_size=batch_size,
              validation_data=(validation_data, validation_labels))
```
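For reference, here is a small sketch of what `train_data.shape[1:]` should be for 240x240 inputs. It assumes the standard VGG16 layout: 3x3 convolutions with 'same' padding, which keep the spatial size, and five 2x2 max-pooling layers with stride 2 and 'valid' padding, which halve it and round down. The helper name is hypothetical.

```python
def vgg16_bottleneck_shape(height, width, num_pools=5, channels=512):
    """Shape of VGG16's last pooling output when include_top=False.

    Assumes 'same'-padded convolutions (size unchanged) and 2x2, stride-2
    'valid' max pooling (size halved, rounded down), repeated num_pools times.
    """
    for _ in range(num_pools):
        height //= 2
        width //= 2
    return (height, width, channels)


shape = vgg16_bottleneck_shape(240, 240)  # 240 -> 120 -> 60 -> 30 -> 15 -> 7
flat = shape[0] * shape[1] * shape[2]     # size of the Flatten output
dense_params = flat * 256 + 256           # weights + biases of Dense(256)
print(shape, flat, dense_params)  # (7, 7, 512) 25088 6422784
```

So the top model starts with about 6.4 million trainable weights in its first Dense layer, fed by 25088 features per spectrogram.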
countmachine answered: VGG16 on a spectrogram dataset

Found the answer: my labels were wrong.

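I had read online that we should use shuffle = True when feeding train_generator, but that only shuffles the files, not the labels: the label array I built by hand stayed in class order, so the labels no longer matched the features.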
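The effect can be reproduced without Keras. The sketch below is a NumPy simulation with made-up sizes (not the poster's data): it lists files class by class, optionally shuffles their order the way shuffle = True would, and compares each file's real class with a label array built as "first half 0, second half 1", like `train_labels` in the question. The function name is hypothetical.

```python
import numpy as np


def label_agreement(n_per_class, shuffle, seed=0):
    """Fraction of hand-built labels that match the files' real classes."""
    # Files as listed class by class: all of class 0, then all of class 1.
    true_classes = np.array([0] * n_per_class + [1] * n_per_class)
    order = np.arange(true_classes.size)
    if shuffle:
        # shuffle=True reorders the files, and therefore the predicted features...
        order = np.random.default_rng(seed).permutation(order)
    classes_in_feature_order = true_classes[order]
    # ...but the labels are still built in class order, as in the question.
    hand_built_labels = np.array([0] * n_per_class + [1] * n_per_class)
    return float(np.mean(classes_in_feature_order == hand_built_labels))


print(label_agreement(5000, shuffle=False))  # 1.0
print(label_agreement(5000, shuffle=True))   # roughly 0.5
```

With about half of the training labels wrong at random, the top model is effectively fitting noise, which is consistent with the val_acc of about 0.5.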

So I switched to shuffle = False, and also to class_mode = None.
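With shuffle = False, the labels can also be derived from the folder layout instead of assuming a 50/50 split. In Keras the simplest route is probably the iterator's `classes` attribute (e.g. `train_generator_bottleneck.classes`), which lists each file's class index in iteration order. The stdlib sketch below only illustrates that ordering; it assumes one subfolder per class, classes indexed by sorted folder name, and only whole batches being predicted. The function name and extension list are hypothetical.

```python
import os

import numpy as np

IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".bmp"}  # assumed image file extensions


def labels_in_directory_order(root, batch_size):
    """Labels aligned with features predicted from a non-shuffled directory iterator.

    Assumes one subfolder per class and class indices assigned in sorted
    folder-name order; files within a class all share the same label.
    """
    class_names = sorted(
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
    )
    labels = []
    for index, name in enumerate(class_names):
        n_files = sum(
            1 for f in os.listdir(os.path.join(root, name))
            if os.path.splitext(f)[1].lower() in IMAGE_EXTS
        )
        labels.extend([index] * n_files)
    # predict_generator(gen, n // batch_size) only produces whole batches.
    usable = (len(labels) // batch_size) * batch_size
    return np.array(labels[:usable]), class_names
```

Note that sorting puts a folder named "abnormal" before "normal", so under this assumption "abnormal" would get index 0.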

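I also had to make sure that both classes had the same number of files in my dataset, and that the counts were divisible by my batch_size, since the labels are built as half 0s and half 1s and predict_generator only processes samples // batch_size full batches.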
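The two conditions can be checked with plain arithmetic before extracting features. This is a hypothetical helper; the example calls assume the question's totals (30272 training and 7584 validation files) split evenly between the two classes, which the post implies but does not state per class.

```python
def check_bottleneck_setup(count_class0, count_class1, batch_size):
    """List problems that break hand-built [0] * (n // 2) + [1] * (n // 2) labels."""
    problems = []
    total = count_class0 + count_class1
    if count_class0 != count_class1:
        problems.append("classes differ in size, so the half/half label split is wrong")
    if total % batch_size != 0:
        dropped = total % batch_size
        problems.append(f"{dropped} files are never predicted (total not divisible by batch_size)")
    return problems


print(check_bottleneck_setup(15136, 15136, 32))  # []  (30272 = 946 * 32)
print(check_bottleneck_setup(3792, 3792, 32))    # []  (7584 = 237 * 32)
```

When the total is not divisible, the skipped files come from the end of the listing (the last class when shuffle = False), so the feature array ends up shorter than the hand-built label array.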

Hope this helps other beginners!
