关于模型的预测
来源:8-20 训练流程代码
Shuren_Yu
2021-02-12
老师,看了您的最后一个预测的代码,如果想预测一张自己的图片,而不是数据集中随机挑选的测试图片,应该在image_caption_eval.ipynb代码中做怎样的调整呢?我尝试了使用feature_extraction.ipynb的代码提取了想要预测图片的特征,同样打包成了pickle文件,但是导入image_caption_eval.ipynb的代码中预测时就会出错,因此非常想知道该如何调整image_caption_eval.ipynb以达到预测任意一张图片的目的。望赐教。
写回答
1回答
-
正十七
2021-02-24
你需要解构ImageCaptionData这个类。在这里你需要的数据就是Image经过vgg提取到的特征数据而已。
在如下的代码中,只有single_img_feature是模型所需要的输入,所以对于一张图像,你需要抽取特征得到single_img_features, 然后就照着运行就可以了啊。
with tf.Session() as sess: sess.run(init_op) logging.info("[*] Reading checkpoint ...") ckpt = tf.train.get_checkpoint_state(output_dir) if ckpt and ckpt.model_checkpoint_path: ckpt_name = os.path.basename(ckpt.model_checkpoint_path) saver.restore(sess, os.path.join(output_dir, ckpt_name)) logging.info("[*] Success Read Checkpoint From %s" % (ckpt_name)) else: raise Exception("[*] Failed load checkpoint") for i in range(test_examples): single_img_features, single_sentence_ids, single_weights, single_img_name = caption_data.next(hps.batch_size) print(single_img_name) pprint.pprint(img_name_to_tokens[single_img_name[0]]) pprint.pprint(img_name_to_token_ids[single_img_name[0]]) embed_img_val = sess.run(embed_img, feed_dict={img_feature: single_img_features}) state_val = np.zeros((1, num_hidden_states)) embed_input_val = embed_img_val generated_sequence = [] for j in range(hps.num_timesteps): logits_val, state_val = sess.run([logits, output_state], feed_dict = { embed_input: embed_input_val, input_state: state_val }) predicted_word_id = np.argmax(logits_val[0]) generated_sequence.append(predicted_word_id) embed_input_val = sess.run(embed_word, feed_dict={word: [[predicted_word_id]]}) pprint.pprint("generated words: ") pprint.pprint(generated_sequence) pprint.pprint(vocab.decode(generated_sequence))
00
相似问题