TensorFlow 2.3.1 multi-GPU NaN loss values

There are only a few reports of this issue, and I still haven't had any luck finding an answer. Put simply, here is a short snippet:

import tensorflow as tf
from tensorflow.keras import layers
print(tf.__version__)
# 2.3.1
mirrored_strategy = tf.distribute.MirroredStrategy()

with mirrored_strategy.scope():
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])

model.compile(loss='mse', optimizer='sgd')
dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100).batch(10)
model.fit(dataset, epochs=4)
model.evaluate(dataset)  # produces the last progress line in each log below

After running it, I get:

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')

Epoch 1/4
INFO:tensorflow:batch_all_reduce: 2 all-reduces with algorithm = nccl, num_packs = 1
INFO:tensorflow:batch_all_reduce: 2 all-reduces with algorithm = nccl, num_packs = 1
10/10 [==============================] - 1s 93ms/step - loss: 807385211185512087331799040.0000
Epoch 2/4
10/10 [==============================] - 1s 93ms/step - loss: nan
Epoch 3/4
10/10 [==============================] - 1s 93ms/step - loss: nan
Epoch 4/4
10/10 [==============================] - 1s 93ms/step - loss: nan
10/10 [==============================] - 0s 48ms/step - loss: nan

Without the strategy, the output looks normal and the loss is computed correctly:

Epoch 1/4
10/10 [==============================] - 0s 2ms/step - loss: 4.2581
Epoch 2/4
10/10 [==============================] - 0s 2ms/step - loss: 1.8821
Epoch 3/4
10/10 [==============================] - 0s 2ms/step - loss: 0.8319
Epoch 4/4
10/10 [==============================] - 0s 2ms/step - loss: 0.3677
10/10 [==============================] - 0s 1ms/step - loss: 0.2284
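Since the log shows MirroredStrategy picking NCCL for its all-reduce, one way to narrow the problem down is to train the same toy model under several strategies and compare the final losses. This is a minimal diagnostic sketch, not a fix: `final_loss` is a hypothetical helper, `compile` is moved inside the scope for simplicity, and the variants built with `tf.distribute.NcclAllReduce`, `tf.distribute.HierarchicalCopyAllReduce` and `tf.distribute.ReductionToOneDevice` only run when at least two GPUs are visible.

```python
import math

import tensorflow as tf


def final_loss(strategy, epochs=4):
    """Train the question's toy model under `strategy`; return the last epoch's loss."""
    with strategy.scope():
        model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
        model.compile(loss='mse', optimizer='sgd')
    dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100).batch(10)
    history = model.fit(dataset, epochs=epochs, verbose=0)
    return history.history['loss'][-1]


# Baseline: the default strategy, i.e. ordinary single-device training.
results = {'default': final_loss(tf.distribute.get_strategy())}

# The multi-GPU variants are only meaningful with two or more visible GPUs.
if len(tf.config.list_logical_devices('GPU')) >= 2:
    results['mirrored_nccl'] = final_loss(tf.distribute.MirroredStrategy(
        cross_device_ops=tf.distribute.NcclAllReduce()))
    results['mirrored_hierarchical_copy'] = final_loss(tf.distribute.MirroredStrategy(
        cross_device_ops=tf.distribute.HierarchicalCopyAllReduce()))
    results['mirrored_one_device'] = final_loss(tf.distribute.MirroredStrategy(
        cross_device_ops=tf.distribute.ReductionToOneDevice()))

for name, loss in results.items():
    print(f'{name}: final loss = {loss:.4f}, finite = {math.isfinite(loss)}')
```

If only the NCCL variant diverges, swapping `cross_device_ops` may be a workaround; if every multi-GPU variant produces NaN, the fault more likely sits below the all-reduce, for example in the GPU-to-GPU copies themselves.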

As the runtime environment I use the TensorFlow container from NVIDIA GPU Cloud, nvcr.io/nvidia/tensorflow:20.10-tf2-py3, so it is up to date and compatible with all kinds of drivers. I also tried the newer version, 20.12-tf2-py3.

Answer from bohu1025:

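I ran into the same problem with multi-GPU training on TensorFlow 2.4 using 2x AMD GPUs. In the meantime, I managed to fix my system, and on the fixed system your snippet runs correctly (no NaN loss).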

It turned out that transferring variable data between GPUs was failing silently. You can check whether you have the same problem by running:
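import tensorflow as tf

# Create a scalar variable; with GPUs visible, it is placed on GPU:0 by default.
v = tf.Variable(1.0)

# Print (read) the variable from GPU:1, which involves a device-to-device transfer.
with tf.device('/gpu:1'):
    print(v)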

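Because the device-to-device transfer failed, it did not give the expected result <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>. If your output also shows a numpy value other than 1.0, something is wrong with your system.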
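The single-variable check above only tests one direction (GPU:0 to GPU:1). A minimal sketch that extends the same idea to every ordered pair of visible GPUs might look like the following. `check_gpu_copies` and its `value`/`size` parameters are hypothetical names, and instead of a `tf.Variable` it creates a tensor with `tf.fill` on the source GPU and runs `tf.identity` on the destination GPU, relying on eager execution to copy the input to the op's device. The comparison also runs on the destination GPU, so no second GPU-to-GPU copy is involved.

```python
import itertools

import tensorflow as tf


def check_gpu_copies(value=1.0, size=1024):
    """Copy a tensor from every GPU to every other GPU; return the pairs that came back wrong."""
    gpus = [d.name for d in tf.config.list_logical_devices('GPU')]
    failures = []
    for src, dst in itertools.permutations(gpus, 2):
        with tf.device(src):
            x = tf.fill([size], value)  # computed on the source GPU
        with tf.device(dst):
            y = tf.identity(x)  # eager execution copies x to dst to run this op
            ok = bool(tf.reduce_all(tf.equal(y, value)))  # compared on dst
        if not ok:
            failures.append((src, dst))
    return failures


failures = check_gpu_copies()
if failures:
    for src, dst in failures:
        print(f'copy {src} -> {dst} returned wrong values')
else:
    print('all GPU-to-GPU copies returned the expected values')
```

With fewer than two GPUs there are no pairs to test, so the success message is printed trivially.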
