使用dataset.shuffle(100000),但数据并没有打乱
来源:4-6 tfrecord基础API使用
wxz123
2020-04-19
环境为python3.6 tensorflow1.13
老师我在一个数据集上做数据处理阶段用一下代码转为tfrecord格式的数据
import numpy as np
a=np.array([[1,2,3],[1,2,3]])
cls_dict={'cat':0,'dog':1,'sheep':2}
input_dict={'cat/1.jpg':a,'cat/2.jpg':a,'dog/3.jpg':a,'sheep/4.jpg':a,'sheep/5.jpg':a}
path='.'
def to_tfrecords_reshape(input_dict,path,compression_type = None):
options = tf.io.TFRecordOptions(
compression_type = compression_type)
output_path=os.path.join(path,'output_reshape.tfrecords')
i=0
with tf.io.TFRecordWriter(output_path, options) as writer:
for file_path ,feature in input_dict.items():
cls=file_path.split(r'/')[-2]
cls_id=[cls_dict[cls]]
feature_reshape=feature.reshape(-1)
example = tf.train.Example(features =
tf.train.Features(feature = {
'cls':tf.train.Feature(int64_list=tf.train.Int64List(value =cls_id)),
'feature_reshape':tf.train.Feature(float_list=tf.train.FloatList(value=feature_reshape))
}))
writer.write(example.SerializeToString())
print(i)
i+=1
to_tfrecords_reshape(input_dict,path)
由于数据集没法传,我手写了几条数据。。。
用下面的代码再提取出来
expected_features = {
"cls": tf.io.FixedLenFeature([], dtype=tf.int64),
"feature_reshape": tf.io.VarLenFeature(dtype=tf.float32)
}
def parse_example(serialized_example):
example = tf.io.parse_single_example(serialized_example,
expected_features)
features = tf.sparse.to_dense(example["feature_reshape"],
default_value=0)
return example["cls"],features
def tfrecords_to_dataset(filenames, batch_size=5, n_parse_threads=1,
shuffle_buffer_size=1):
dataset=tf.data.TFRecordDataset(
filenames)
dataset = dataset.map(parse_example)
#num_parallel_calls=n_parse_threads)
dataset.shuffle(100000)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(5)
return dataset
filenames='output_reshape.tfrecords'
dataset=tfrecords_to_dataset(filenames)
dataset.shuffle(100000)
dataset_iter = dataset.make_one_shot_iterator()
cls,feature = dataset_iter.get_next()
with tf.Session() as sess:
for i in range(1):
cls_val,feature_val = sess.run([cls,feature])
print(cls_val)
print(feature_val)
得到的结果如下
[0 0 1 2 2]
[[ 1. 2. 3. 1. 2. 3.]
[ 1. 2. 3. 1. 2. 3.]
[ 1. 2. 3. 1. 2. 3.]
[ 1. 2. 3. 1. 2. 3.]
[ 1. 2. 3. 1. 2. 3.]]
我的问题:无论我运行多少次,输出的数据永远都是顺序输出的,dataset.shuffle(100000)根本没起到作用,(因为在我处理自己的数据集时,我生产tfrecord文件时还把路径保存下来了,但提取tfrecord文件时输出的路径不管运行几次顺序都完全一致,且都是保存前顺序的路径,dataset.shuffle(100000)根本没起到作用),从本例中也可看出输出的类别为
[0 0 1 2 2]
完全就是存数据前的类别序号,即
cls_dict={'cat':0,'dog':1,'sheep':2}
input_dict={'cat/1.jpg':a,'cat/2.jpg':a,'dog/3.jpg':a,'sheep/4.jpg':a,'sheep/5.jpg':a}
#input_dict即对应着[0 0 1 2 2],如果shuffle有效果的话,输出应该是一个将[0 0 1 2 2]打乱的数组吧
请老师帮我看看我这个保存和提取数据的代码哪里存在问题,十分感谢!!!
写回答
1回答
-
正十七
2020-04-21
问题可能在两个地方,第一个是shuffle要有返回值,即dataset = dataset.shuffle(), 第二个是如果想每次遍历数据集数据的顺序都不一样,需要设置reshuffle_each_iteration参数:
dataset = tf.data.Dataset.range(3) dataset = dataset.shuffle(3, reshuffle_each_iteration=True) dataset = dataset.repeat(2)
你的问题应该是地一个。
https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shuffle
00
相似问题