自定义TensorFlow Keras优化器

假设我想编写一个符合tf.keras API的自定义优化器类(请注意,我当前正在使用TensorFlow 2.0.0)。我对记录下来的实现方式与实现中的实现方式感到困惑。

tf.keras.optimizers.Optimizer states的文档,

  ### Write a customized optimizer.
  If you intend to create your own optimization algorithm,simply inherit from
  this class and override the following methods:
    - resource_apply_dense (update variable given gradient tensor is dense)
    - resource_apply_sparse (update variable given gradient tensor is sparse)
    - create_slots (if your optimizer algorithm requires additional variables)

但是,当前的tf.keras.optimizers.Optimizer实现并没有定义resource_apply_dense方法,但是它确实定义了一个看起来很私密的_resource_apply_dense方法存根。同样,没有resource_apply_sparsecreate_slots方法,但是有一个_resource_apply_sparse方法存根和一个_create_slots方法调用。

在正式的tf.keras.optimizers.Optimizer子类中(以tf.keras.optimizers.Adam为例),有_resource_apply_dense_resource_apply_sparse_create_slots方法,没有这样的方法没有下划线的方法。

非正式程度较低的tf.keras.optimizers.Optimizer子类(例如,来自TensorFlow Addons的tfa.optimizers.MovingAverage)中也有类似的前导下划线方法。

对我来说,另一个困惑点是TensorFlow Addons优化器覆盖了apply_gradients方法,而tf.keras.optimizers优化器却没有。

此外,我注意到apply_gradients方法calls _create_slotstf.keras.optimizers.Optimizer方法,但是基tf.keras.optimizers.Optimizer类没有_create_slots方法。 因此,似乎_create_slots方法 必须在优化器子类中定义,如果该子类未覆盖apply_gradients


问题

继承tf.keras.optimizers.Optimizer的正确方法是什么?具体来说,

  1. 顶部列出的tf.keras.optimizers.Optimizer文档只是意味着要覆盖他们提到的方法的前导下划线版本(例如,_resource_apply_dense而不是resource_apply_dense)吗?如果是这样,是否有关于这些私有方法的API保证不会在TensorFlow的未来版本中更改其行为?
  2. 除了apply_gradients方法之外,何时还能覆盖_apply_resource_[dense|sparse]
wei981106811 回答:自定义TensorFlow Keras优化器

更新:TF2.2迫使我清理所有实现-因此现在 可以用作TF最佳实践的参考。还添加了以下关于_get_hyper_set_hyper的部分。


我已经在所有主要TF和Keras版本中实现了Keras AdamW-我邀请您研究optimizers_v2.py。几点:

  • 您应该继承OptimizerV2,这实际上是您链接的内容;这是tf.keras优化器的最新和最新基类
  • 您在(1)中是正确的-这是文档错误;这些方法是私有的,因为它们并不意味着用户直接使用。
  • apply_gradients(或任何其他方法)仅在默认值不能满足给定优化器所需的条件时才被覆盖;在您的链接示例中,它只是原始文件的一线附加程序
  • ”“因此,似乎必须在优化器子类中定义_create_slots方法,如果该子类未覆盖apply_gradients”。 –两者无关;这是巧合。

  • _resource_apply_dense_resource_apply_sparse有什么区别?

后者处理稀疏层-例如Embedding-以及其他所有内容; example

  • 何时应使用_create_slots()

定义可训练的 tf.Variable时;例如:重量的一阶和二阶矩(例如Adam)。它使用add_slot()


_get_hyper_set_hyper的比较:它们允许设置和获取Python文字(intstr等),可调用对象和张量。它们的存在主要是为了方便起见:通过_set_hyper设置的任何内容都可以通过_get_hyper进行检索,避免重复样板代码。我对here进行了问答。

,
  1. 是的,这似乎是文档错误。前面的下划线名称是正确的重写方法。与之相关的是非Keras Optimizer,它已定义了所有这些,但未在基类https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/optimizer.py中实现
  def _create_slots(self,var_list):
    """Create all slots needed by the variables.
    Args:
      var_list: A list of `Variable` objects.
    """
    # No slots needed by default
    pass

  def _resource_apply_dense(self,grad,handle):
    """Add ops to apply dense gradients to the variable `handle`.
    Args:
      grad: a `Tensor` representing the gradient.
      handle: a `Tensor` of dtype `resource` which points to the variable
       to be updated.
    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError()

  def _resource_apply_sparse(self,handle,indices):
    """Add ops to apply sparse gradients to the variable `handle`.
    Similar to `_apply_sparse`,the `indices` argument to this method has been
    de-duplicated. Optimizers which deal correctly with non-unique indices may
    instead override `_resource_apply_sparse_duplicate_indices` to avoid this
    overhead.
    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable
       to be updated.
      indices: a `Tensor` of integral type representing the indices for
       which the gradient is nonzero. Indices are unique.
    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError()
  1. 我不知道apply_dense。一方面,如果您重写它,则代码会提到每个副本的DistributionStrategy可能是“危险的”
    # TODO(isaprykin): When using a DistributionStrategy,and when an
    # optimizer is created in each replica,it might be dangerous to
    # rely on some Optimizer methods.  When such methods are called on a
    # per-replica optimizer,an exception needs to be thrown.  We do
    # allow creation per-replica optimizers however,because the
    # compute_gradients()->apply_gradients() sequence is safe.
本文链接:https://www.f2er.com/3135451.html

大家都在问