我可以从tensorflow 2.0数据集中获取numpy数组吗?

我可以从numpy获得一个tensorflow dataset数组吗?在下面的示例中,我可以迭代并从每个numpy array中获取一个tensor。但是我可以直接从dataset获得它吗?

>>> X = tf.reshape(tf.range(2*3),(2,3))
<tf.Tensor: id=33,shape=(2,3),dtype=int32,numpy=
 array([[0,1,2],[3,4,5]],dtype=int32)>

>>> dataset = tf.data.Dataset.from_tensor_slices(X)
<TensorSliceDataset shapes: (3,),types: tf.int32>

>>> t = next(iter(dataset))
<tf.Tensor: id=40,shape=(3,numpy=array([0,dtype=int32)> 

>>> t.numpy()
array([0,dtype=int32)
wangxiqi198706 回答:我可以从tensorflow 2.0数据集中获取numpy数组吗?

一种可能的解决方案(请参见here

def dataset_to_numpy_util(dataset,N):
    dataset = dataset.unbatch().batch(N)
    for images,labels in dataset:
        numpy_images = images.numpy()
        numpy_labels = labels.numpy()
        break;  
    return numpy_images,numpy_labels
本文链接:https://www.f2er.com/3151495.html

大家都在问