如何从CNN中正确提取权重?

首先,我训练了cnn体系结构:

adam =  optimizers.Adam(learning_rate=0.0001,beta_1=0.9,beta_2=0.999,amsgrad=False)
model = Sequential()
model.add(Conv2D(20,(3,3),activation='relu',input_shape =(5,5,1),padding='same',kernel_initializer='he_normal'))
model.add(Conv2D(30,kernel_initializer='he_normal'))
model.add(Dropout(0.5))
#model.add(MaxPooling2D(2,2)) # because the ROI is already small,we don't need subsampling
model.add(flatten())
model.add(Dense(1,activation='sigmoid',kernel_initializer='he_normal'))
model.summary()
#plot_model(model,to_file='model.png',show_shapes=True,show_layer_names=True)
# compile the model
model.compile(loss='binary_crossentropy',optimizer= adam,metrics=['accuracy'])

history = model.fit(X_train,Y_train,epochs=200,callbacks=[model_checkpoint],batch_size=1,verbose=1,shuffle=True,validation_split=0.5)

与此同时,由于以下原因,我保存了每个时期的所有权重:

model_checkpoint=ModelCheckpoint('model_test{epoch:02d}.h5',save_freq=1,save_weights_only=True)

然后我提取了权重,例如第一个时期“ model_test01.h5”的权重

import h5py
import numpy as np
def isGroup(obj):
    if isinstance(obj,h5py.Group):
        return True
    return False

def isDataset(obj):
    if isinstance(obj,h5py.Dataset):
        return True
    return False

def getdatasetFromGroup(datasets,obj):
    if isGroup(obj):
        for key in obj:
            x = obj[key]
            getdatasetFromGroup(datasets,x)
    else:
        datasets.append(obj)

def getWeightsForLayer(layerName,filename):
   weights = []
   with h5py.File(filename,mode='r') as f:
       for key in f:
           if layerName in key:
              obj = f[key]
              datasets = []
              getdatasetFromGroup(datasets,obj)

              for dataset in datasets:
                  w = np.array(dataset)
                  weights.append(w)
   return weights
           #print(key,f[key])
           #o = f[key]
           #for key1 in o:
               #print(key1,o[key1])
               #r = o[key1]
               #for key2 in r:
                   #print(key2,r[key2])
weights = getWeightsForLayer("conv2d_6","./model_test01.h5")
#for w in weights:
    #print(w.shape)
print(weights)

但是我无法理解输出,因为列表“权重”包含两个float32元素(基本上是两个numpy数组),第一个包含20个元素(猜测20是第一个卷积层中过滤器的数量),第二个尺寸为(3,3,1,20)的一个(因此不可能打开)。我如何理解此输出?

wwwwwssssss 回答:如何从CNN中正确提取权重?

您拥有的两个“ float32元素”分别对应于滤波器的权重和conv层的偏差。 过滤器权重将具有形状(3、3、1、20),而偏差将具有形状(20),因为您有20个过滤器,并且有一个偏差值每个钳工。

(3、3、1、20)的表示形式为(过滤器宽度,过滤器高度,过滤器深度,过滤器数量)

本文链接:https://www.f2er.com/2623369.html

大家都在问