如何在TensorFlow中获取日志概率?

我正在尝试将pytorch脚本转换为tensorflow,我需要从分类分发中获取日志概率。但是即使使用相同的种子,张量流计算的对数概率也不同于pytorch的对数概率。这是我到目前为止所做的

import torch 
from torch.distributions import Categorical
import tensorflow as tf
import tensorflow_probability as tfp

torch.manual_seed(1)
tf.random.set_seed(1)

probs =[0.4,0.6]
m = Categorical(torch.tensor(probs))
action = m.sample()

n = tfp.distributions.Categorical(probs)
print("pytorch",m.log_prob(action))
print("Tensorflow",tf.math.log(n.prob(action.item())))
z492572861 回答:如何在TensorFlow中获取日志概率?

tfp.distributions.Categorical(probs)

将日志作为默认参数。它们正在规范化,因此生成的分布概率为[.45,.55]。

您需要将tfp分发构建为:

 tfp.distributions.Categorical(probs=probs)
本文链接:https://www.f2er.com/3147207.html

大家都在问