我正在尝试在4个GPU上训练大型模型。该模型太大,以至于无法在单个GPU中容纳。所以我要做的是将不同的层放置到不同的GPU中。它适用于小批量生产,但是对于大批量生产,GPU 0始终会因OOM错误而首先破裂。我进行了测试,并从模型中删除了GPU 0,在调用fit函数之前,它没有显示任何用法(参见图片)。但是,当模型开始加载批处理时,GPU0已与15488 MiB一起使用,仍然会导致OOM错误。谁能告诉我为什么会这样,以及是否有办法解决这个问题?
keras 2.2.5 tensorflow 1.15