Keras with TensorFlow Backend: Using TensorFlow in a Custom Layer

I am trying to add a custom layer to a Keras model, so I wrote my code following the TensorFlow guide https://www.tensorflow.org/guide/keras/custom_layers_and_models#layers_are_recursively_composable .
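
For reference, the "recursively composable" pattern from that guide looks roughly like this (a minimal sketch; Block and the layer sizes are my own placeholders): layers assigned as attributes of another layer are tracked automatically, so the outer layer collects the inner layers' weights.

class Block(keras.layers.Layer):
    def __init__(self):
        super(Block, self).__init__()
        # sub-layers created here are tracked by the outer layer,
        # so their weights show up in Block().trainable_weights
        self.dense_1 = keras.layers.Dense(32, activation='relu')
        self.dense_2 = keras.layers.Dense(1)

    def call(self, inputs):
        return self.dense_2(self.dense_1(inputs))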

But I run into this error: AttributeError: 'tuple' object has no attribute 'layer'

Here are my custom layers; I want them to be recursively composable.

import tensorflow as tf
import tensorflow.contrib as tf_contrib
import keras
from keras import backend as K
from keras.layers import Input, BatchNormalization, Convolution2D
from keras.models import Model


class Hw_flatten(keras.layers.Layer):
    def __init__(self):
        super(Hw_flatten, self).__init__()

    def call(self, inputs, **kwargs):
        # Flatten the spatial dims: [bs, h, w, c] -> [bs, h*w, c]
        # (note: inputs.shape[0] assumes a static batch dimension)
        return tf.reshape(inputs, shape=[inputs.shape[0], -1, inputs.shape[-1]])


class Max_Pooling(keras.layers.Layer):
    def __init__(self):
        super(Max_Pooling, self).__init__()

    def call(self, inputs, **kwargs):
        return tf.layers.max_pooling2d(inputs, pool_size=2, strides=2, padding='SAME')


class Convolution(keras.layers.Layer):
    def __init__(self, use_bias=True):
        super(Convolution, self).__init__()
        self.use_bias = use_bias
        self.weight_init = tf_contrib.layers.xavier_initializer()
        self.weight_regularizer = None
        self.weight_regularizer_fully = None

    def call(self, inputs, channels, kernel=4, stride=2, **kwargs):
        return tf.layers.conv2d(inputs=inputs, filters=channels, kernel_size=kernel,
                                kernel_initializer=self.weight_init,
                                kernel_regularizer=self.weight_regularizer,
                                strides=stride, use_bias=self.use_bias)


class google_attention(keras.layers.Layer):
    def __init__(self, output_dim, **kwargs):
        super(google_attention, self).__init__(**kwargs)
        self.shape = (64, 64, 128)
        self.channels = 1024
        self.name = 'attention'
        self.output_dim = output_dim
        self.conv = Convolution()
        self.max_pooling = Max_Pooling()
        self.hw_flatten = Hw_flatten()

    def build(self, input_shape):
        # add the trainable attention weight gamma, initialised to 0
        self.gamma = K.variable([0.0])  # tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))
        self.trainable_weights = [self.gamma]
        super(google_attention, self).build(input_shape)

    def call(self, inputs, **kwargs):
        f = self.conv(inputs, channels=self.channels // 8, kernel=1, stride=1)  # [bs, h, w, c']
        f = self.max_pooling(f)

        g = self.conv(inputs, channels=self.channels // 8, kernel=1, stride=1)  # [bs, h, w, c']

        h = self.conv(inputs, channels=self.channels // 2, kernel=1, stride=1)  # [bs, h, w, c]
        h = self.max_pooling(h)

        # N = h * w
        s = tf.matmul(self.hw_flatten(g), self.hw_flatten(f), transpose_b=True)  # [bs, N, N]

        beta = tf.nn.softmax(s)  # attention map

        o = tf.matmul(beta, self.hw_flatten(h))  # [bs, N, C]

        o = tf.reshape(o, shape=[-1, self.shape[0], self.shape[1], self.shape[2] // 2])  # [bs, h, w, C]
        o = self.conv(o, channels=self.output_dim, stride=1)
        x = self.gamma * o + inputs  # residual connection scaled by gamma

        return x

def generator_model():
    inputs = Input((img_cols, img_rows, IN_CH))
    e1 = BatchNormalization()(inputs)
    e1 = Convolution2D(64, 4, subsample=(2, 2), activation='relu', init='uniform', border_mode='same')(e1)
    e1 = BatchNormalization()(e1)
    e2 = Convolution2D(128, 4, subsample=(2, 2), activation='relu', init='uniform', border_mode='same')(e1)
    e2 = BatchNormalization()(e2)
    atten = google_attention(128)
    e2 = atten(e2)
    model = Model(input=inputs, output=e2)
    return model
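
The model-building code above uses the old Keras 1 keyword names (subsample, init, border_mode, Model(input=..., output=...)). Under Keras 2 the same graph would read roughly as below; this is a sketch assuming an otherwise identical architecture:

from keras.layers import Conv2D

inputs = Input((img_cols, img_rows, IN_CH))
e1 = BatchNormalization()(inputs)
e1 = Conv2D(64, 4, strides=(2, 2), activation='relu',
            kernel_initializer='uniform', padding='same')(e1)
e1 = BatchNormalization()(e1)
e2 = Conv2D(128, 4, strides=(2, 2), activation='relu',
            kernel_initializer='uniform', padding='same')(e1)
e2 = BatchNormalization()(e2)
e2 = google_attention(128)(e2)
model = Model(inputs=inputs, outputs=e2)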

The full error output is:

tracking <tf.Variable 'attention/Variable:0' shape=(1,) dtype=float32> gamma
WARNING:tensorflow:From E:\elts\cgan\custom_layer.py:110: conv2d (from tensorflow.python.layers.convolutional) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.keras.layers.Conv2D` instead.
Traceback (most recent call last):
  File "E:/elts/cgan/custome_class.py", line 376, in <module>
    train(BATCH_SIZE)
  File "E:/elts/cgan/custome_class.py", line 277, in train
    generator = generator_model()
  File "E:/elts/cgan/custome_class.py", line 144, in generator_model
    e2 = atten(e2)
  File "E:\elts\Self-Attention-GAN-Tensorflow-master\venv\Include\lib\site-packages\keras\engine\base_layer.py", line 489, in __call__
    output = self.call(inputs, **kwargs)
  File "E:\elts\cgan\custom_layer.py", line 130, in call
    f = self.conv(inputs, channels=self.channels // 8, kernel=1, stride=1)
  File "E:\elts\Self-Attention-GAN-Tensorflow-master\venv\Include\lib\site-packages\keras\engine\base_layer.py", line 489, in __call__
    output = self.call(inputs, **kwargs)
  File "E:\elts\cgan\custom_layer.py", line 110, in call
    strides=stride, use_bias=self.use_bias)
  File "E:\elts\Self-Attention-GAN-Tensorflow-master\venv\Include\lib\site-packages\tensorflow\python\util\deprecation.py", line 324, in new_func
    return func(*args, **kwargs)
  File "E:\elts\Self-Attention-GAN-Tensorflow-master\venv\Include\lib\site-packages\tensorflow\python\layers\convolutional.py", line 424, in conv2d
    return layer.apply(inputs)
  File "E:\elts\Self-Attention-GAN-Tensorflow-master\venv\Include\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 1479, in apply
    return self.__call__(inputs, *args, **kwargs)
  File "E:\elts\Self-Attention-GAN-Tensorflow-master\venv\Include\lib\site-packages\tensorflow\python\layers\base.py", line 537, in __call__
    outputs = super(Layer, self).__call__(inputs, **kwargs)
  File "E:\elts\Self-Attention-GAN-Tensorflow-master\venv\Include\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 663, in __call__
    inputs, outputs, args, kwargs)
  File "E:\elts\Self-Attention-GAN-Tensorflow-master\venv\Include\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 1708, in _set_connectivity_metadata_
    input_tensors=inputs, output_tensors=outputs, arguments=kwargs)
  File "E:\elts\Self-Attention-GAN-Tensorflow-master\venv\Include\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 1795, in _add_inbound_node
    input_tensors)
  File "E:\elts\Self-Attention-GAN-Tensorflow-master\venv\Include\lib\site-packages\tensorflow\python\util\nest.py", line 515, in map_structure
    structure[0], [func(*x) for x in entries],
  File "E:\elts\Self-Attention-GAN-Tensorflow-master\venv\Include\lib\site-packages\tensorflow\python\util\nest.py", line 515, in <listcomp>
    structure[0], [func(*x) for x in entries],
  File "E:\elts\Self-Attention-GAN-Tensorflow-master\venv\Include\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 1794, in <lambda>
    inbound_layers = nest.map_structure(lambda t: t._keras_history.layer,
AttributeError: 'tuple' object has no attribute 'layer'
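
From the traceback, the failing frames are inside tf.keras's base_layer (tensorflow\python\keras\engine\base_layer.py), which expects each input tensor's _keras_history to have a .layer attribute, while my outer layer goes through standalone Keras (keras\engine\base_layer.py), which records _keras_history as a plain tuple. So I suspect the problem is mixing tf.layers.* (which routes through tf.keras) inside standalone-Keras layers. A version of the two wrappers that stays entirely inside standalone Keras would look roughly like this (an untested sketch; note the filter count moves into __init__, so each branch of the attention layer would need its own Convolution instance instead of sharing one self.conv):

class Max_Pooling(keras.layers.Layer):
    def __init__(self):
        super(Max_Pooling, self).__init__()
        self.pool = keras.layers.MaxPooling2D(pool_size=2, strides=2, padding='same')

    def call(self, inputs, **kwargs):
        return self.pool(inputs)


class Convolution(keras.layers.Layer):
    def __init__(self, channels, kernel=4, stride=2, use_bias=True):
        super(Convolution, self).__init__()
        # 'glorot_uniform' is Keras's name for the Xavier initializer
        self.conv = keras.layers.Conv2D(channels, kernel, strides=stride,
                                        padding='same',
                                        kernel_initializer='glorot_uniform',
                                        use_bias=use_bias)

    def call(self, inputs, **kwargs):
        return self.conv(inputs)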

Any comments or suggestions would be highly appreciated. Thanks!
