为我提供了具有随机行的2D张量。应用tf.math.greater()
和tf.cast(tf.int32)
之后,剩下的张量为0和1。我现在想将归约和应用于该矩阵,但有一个条件:如果至少有一个1求和,并且后面跟随一个0,我也要删除所有后面的1,这意味着1 0 1
应该导致1
而不是2
。
我试图用tf.scan()
解决问题,但是我仍然无法提出一个能够处理以0开头的函数,因为该行可能看起来像:0 0 0 1 0 1
一种想法是将矩阵的下部设置为1(据我所知,对角线剩下的一切始终为0),然后运行类似tf.scan()
的函数来滤除斑点(请参见代码和错误消息)下面)。
Let z be the matrix after tf.cast.
helper = tf.matrix_band_part(tf.ones_like(z),-1,0)
z = tf.math.logical_or(tf.cast(z,tf.bool),tf.cast(helper,tf.bool))
z = tf.cast(z,tf.int32)
z = tf.scan(lambda a,x: x if a == 1 else 0,z)
结果:
ValueError: Incompatible shape for value ([]),expected ([5])