使用dataset.shuffle(100000),但数据并没有打乱

来源:4-6 tfrecord基础API使用

wxz123

2020-04-19

环境为python3.6 tensorflow1.13
老师我在一个数据集上做数据处理阶段用一下代码转为tfrecord格式的数据

import numpy as np 
a=np.array([[1,2,3],[1,2,3]])
cls_dict={'cat':0,'dog':1,'sheep':2}
input_dict={'cat/1.jpg':a,'cat/2.jpg':a,'dog/3.jpg':a,'sheep/4.jpg':a,'sheep/5.jpg':a} 
path='.' 
def to_tfrecords_reshape(input_dict,path,compression_type = None):
    options = tf.io.TFRecordOptions(
        compression_type = compression_type)
    output_path=os.path.join(path,'output_reshape.tfrecords')
    i=0
    with tf.io.TFRecordWriter(output_path, options) as writer:
        for file_path ,feature in input_dict.items():
            cls=file_path.split(r'/')[-2]
            cls_id=[cls_dict[cls]]
            feature_reshape=feature.reshape(-1)
    
            example = tf.train.Example(features = 
                                       tf.train.Features(feature = {
                                               'cls':tf.train.Feature(int64_list=tf.train.Int64List(value =cls_id)),
                                               'feature_reshape':tf.train.Feature(float_list=tf.train.FloatList(value=feature_reshape))
                                               }))
            writer.write(example.SerializeToString())
            print(i)
            i+=1
to_tfrecords_reshape(input_dict,path)

由于数据集没法传,我手写了几条数据。。。
用下面的代码再提取出来

expected_features = {
    "cls": tf.io.FixedLenFeature([], dtype=tf.int64),
    "feature_reshape": tf.io.VarLenFeature(dtype=tf.float32)
}

def parse_example(serialized_example):
    example = tf.io.parse_single_example(serialized_example,
                                         expected_features)
    features = tf.sparse.to_dense(example["feature_reshape"],
                           default_value=0)
    return  example["cls"],features


def tfrecords_to_dataset(filenames, batch_size=5, n_parse_threads=1,
                             shuffle_buffer_size=1):
    dataset=tf.data.TFRecordDataset(
            filenames)
    dataset = dataset.map(parse_example)
                          #num_parallel_calls=n_parse_threads)
    dataset.shuffle(100000)
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat(5)
    return dataset

filenames='output_reshape.tfrecords'
dataset=tfrecords_to_dataset(filenames)
dataset.shuffle(100000)
dataset_iter = dataset.make_one_shot_iterator()
cls,feature = dataset_iter.get_next()

with tf.Session() as sess:
    for i in range(1):
        cls_val,feature_val = sess.run([cls,feature])
        print(cls_val)
        print(feature_val)

得到的结果如下

[0 0 1 2 2]
[[ 1.  2.  3.  1.  2.  3.]
 [ 1.  2.  3.  1.  2.  3.]
 [ 1.  2.  3.  1.  2.  3.]
 [ 1.  2.  3.  1.  2.  3.]
 [ 1.  2.  3.  1.  2.  3.]]

我的问题:无论我运行多少次,输出的数据永远都是顺序输出的,dataset.shuffle(100000)根本没起到作用,(因为在我处理自己的数据集时,我生产tfrecord文件时还把路径保存下来了,但提取tfrecord文件时输出的路径不管运行几次顺序都完全一致,且都是保存前顺序的路径,dataset.shuffle(100000)根本没起到作用),从本例中也可看出输出的类别为

[0 0 1 2 2]

完全就是存数据前的类别序号,即

cls_dict={'cat':0,'dog':1,'sheep':2}
input_dict={'cat/1.jpg':a,'cat/2.jpg':a,'dog/3.jpg':a,'sheep/4.jpg':a,'sheep/5.jpg':a} 
#input_dict即对应着[0 0 1 2 2],如果shuffle有效果的话,输出应该是一个将[0 0 1 2 2]打乱的数组吧

请老师帮我看看我这个保存和提取数据的代码哪里存在问题,十分感谢!!!

写回答

1回答

正十七

2020-04-21

问题可能在两个地方,第一个是shuffle要有返回值,即dataset = dataset.shuffle(), 第二个是如果想每次遍历数据集数据的顺序都不一样,需要设置reshuffle_each_iteration参数:

dataset = tf.data.Dataset.range(3) 
dataset = dataset.shuffle(3, reshuffle_each_iteration=True) 
dataset = dataset.repeat(2)

你的问题应该是地一个。

https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shuffle

0
0

Google老师亲授 TensorFlow2.0 入门到进阶

Tensorflow2.0实战—以实战促理论的方式学习深度学习

1849 学习 · 896 问题

查看课程