Keras CNN分类器

我确实对Keras的cnn有疑问,如果您想帮助我,我将非常感谢。

免责声明:我是cnn和Keras的菜鸟,我现在正在学习它们。


我的数据:

2个班级(狗和猫)

交易:每个类别30张图片

测试:每个类别14张图片

有效:每个类别30张图片


我的代码:

data_path = Path("../data")

train_path = data_path / "train"
test_path = data_path / "test"
valid_path = data_path / "valid"

train_batch = ImageDataGenerator().flow_from_directory(directory=train_path,target_size=(200,200),classes=animals,batch_size=10)

valid_batch = ImageDataGenerator().flow_from_directory(directory=valid_path,batch_size=10)

test_path = ImageDataGenerator().flow_from_directory(directory=test_path,batch_size=4)

imgs,labels = next(train_batch)

model = Sequential(
    [Conv2D(32,(3,3),activation="relu",input_shape=(200,200,3)),flatten(),Dense(len(animals),activation='softmax')])

model.compile(Adam(lr=.0001),loss='categorical_crossentropy',metrics=['accuracy'])

model.fit_generator(train_path,steps_per_epoch=4,validation_data=valid_batch,validation_steps=3,epochs=5,verbose=2)

这是我的错误消息:

我已将路径替换为“”

Traceback (most recent call last):
  File "",line 191,in <module>
    model.fit_generator(train_path,verbose=2)
  File "y",line 91,in wrapper
    return func(*args,**kwargs)
  File "",line 1732,in fit_generator
    initial_epoch=initial_epoch)
  File "",line 185,in fit_generator
    generator_output = next(output_generator)
  File "",line 742,in get
    six.reraise(*sys.exc_info())
  File "",line 693,in reraise
    raise value
  File "",line 711,in get
    inputs = future.get(timeout=30)
  File "",line 657,in get
    raise self._value
  File "",line 121,in worker
    result = (True,func(*args,**kwds))
  File "",line 650,in next_sample
    return six.next(_SHARED_SEQUENCES[uid])
TypeError: 'PosixPath' object is not an iterator

有人可以告诉我我做错了什么吗?另外,如果这是一个离题的问题,请告诉我在哪里可以问到。

cntong 回答:Keras CNN分类器

此行不是必需的

imgs,labels = next(train_batch)

来自docs fit_generator第一个参数的

是一个生成器对象,没有提供的字符串。像这样

model.fit_generator(train_path,steps_per_epoch=4,validation_data=valid_batch,validation_steps=3,epochs=5,verbose=2)

,

您遇到的问题是您没有通过训练的生成器,而是文件的路径(您使用的是 train_path 而不是 {{1 }}

使用train_batch时需要为对象传递生成器:

.fit_generator()
本文链接:https://www.f2er.com/3127084.html

大家都在问