TensorFlow 2.0 runs out of GPU memory when using tf.function

I am using TensorFlow 2.0 with a GradientTape training loop, and I decorate the training step with @tf.function. As training proceeds, memory keeps growing even though each document I feed in is only about 550 words. My dataset is only about 30 MB in total, yet memory usage climbs to 290 GB. GPU memory usage also keeps increasing, and by the time one epoch finishes I get a GPU out-of-memory error. Can anyone help me figure this out?
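For context, one common cause of steadily growing memory with @tf.function is retracing when input shapes or Python arguments vary between calls; an input_signature (as in the code below) is the usual guard against that. A toy sketch (not the poster's model) that makes retracing visible by counting how often the Python body runs:

```python
import tensorflow as tf

trace_count = 0

@tf.function
def f(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only when the function is (re)traced
    return x + 1

f(tf.constant([1.0]))        # first call: traces once
f(tf.constant([1.0, 2.0]))   # new input shape, no input_signature: retraces
```

Each retrace builds and caches another concrete graph; with many distinct shapes this alone can grow memory without bound.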

@tf.function(input_signature=train_step_signature)
def train_step(group, inp, tar, label):
    tar_inp = tar[:, :-1]
    tar_real = tar[:, 1:]
    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)
    with tf.GradientTape(persistent=True) as tape:
        classification, predictions, _ = transformer(inp, tar_inp, True, enc_padding_mask, dec_padding_mask)
        loss = loss_function(tar_real, predictions)
        loss2 = tf.nn.softmax_cross_entropy_with_logits(label, classification)
        loss = loss + loss2

    gradients = tape.gradient(loss, transformer.trainable_variables)
    optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))
    class_loss(loss2)
    train_loss(loss)
    train_accuracy(tar_real, predictions)

    # gra = tape.gradient(loss2, transformer.trainable_variables)
    # optimizer.apply_gradients(zip(gra, transformer.trainable_variables))
    class_accuracy(tf.argmax(label, 1), classification)
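One thing worth checking in the step above: the tape is created with persistent=True, but tape.gradient is only called once and the tape is never deleted. A persistent tape keeps all its recorded operations alive until it is explicitly freed, so either drop persistent=True or del the tape when done. A minimal standalone sketch of the pattern (a toy function, not the poster's transformer):

```python
import tensorflow as tf

x = tf.Variable(3.0)
with tf.GradientTape(persistent=True) as tape:
    y = x * x   # y = x^2
    z = y * y   # z = x^4

# A persistent tape allows multiple gradient calls...
dy = tape.gradient(y, x)   # dy/dx = 2x = 6.0
dz = tape.gradient(z, x)   # dz/dx = 4x^3 = 108.0

# ...but it must be freed explicitly, or its recorded ops stay in memory.
del tape
```

If only a single gradient is needed, as in train_step above, a non-persistent tape (the default) is freed automatically after the first tape.gradient call.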

I train the tf.function with the following loop:

tf.compat.v2.summary.trace_on(graph=True,profiler=True)
for epoch in range(EPOCHS):
    start = time.time()
    train_loss.reset_states()
    train_accuracy.reset_states()
    class_loss.reset_states()
    class_accuracy.reset_states()
    # inp -> portuguese, tar -> english
    for (batch, (group, inp, tar, label)) in enumerate(train_dataset):
        train_step(group, inp, tar, label)
        if batch % 50 == 0:
            print(
                'Epoch {} Batch {} correct_loss {:.4f} correct_accuracy {:.4f} class_accuracy {:.4f} class_loss {:.4f}'.format(
                    epoch + 1, batch, train_loss.result(), train_accuracy.result(), class_accuracy.result(), class_loss.result()))

    if (epoch + 1) % 5 == 0:
        ckpt_save_path = ckpt_manager.save()
        print('Saving checkpoint for epoch {} at {}'.format(epoch + 1,ckpt_save_path))

    print('Epoch {} correct_loss {:.4f} correct_accuracy {:.4f} class_accuracy {:.4f} class_loss {:.4f}'.format(
        epoch + 1, train_loss.result(), train_accuracy.result(), class_accuracy.result(), class_loss.result()))

    print('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))
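Another detail in the loop above: tf.compat.v2.summary.trace_on(graph=True, profiler=True) is enabled before training but trace_export is never called, so trace and profiler data keep accumulating for the entire run, which can itself consume a lot of memory. A minimal sketch of pairing trace_on with trace_export (the log directory is a placeholder):

```python
import tempfile
import tensorflow as tf

logdir = tempfile.mkdtemp()  # placeholder log directory for this sketch
writer = tf.summary.create_file_writer(logdir)

@tf.function
def square(x):
    return x * x

tf.summary.trace_on(graph=True)   # start recording the traced graph
y = square(tf.constant(3.0))      # first call traces the function
with writer.as_default():
    # stop tracing and flush the collected graph to the writer
    tf.summary.trace_export(name='square_trace', step=0)
```

If the trace is only needed for one step, export (or tf.summary.trace_off()) right after that step rather than leaving tracing on for all epochs.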
Answer from ludcy2009:

No good solution for now. If you have one, please email: iooj@foxmail.com
Article link: https://www.f2er.com/3065194.html
