关于模型的预测

来源:8-20 训练流程代码

Shuren_Yu

2021-02-12

老师,看了您的最后一个预测的代码,如果想预测一张自己的图片,而不是数据集中随机挑选的测试图片,应该在image_caption_eval.ipynb代码中做怎样的调整呢?我尝试了使用feature_extraction.ipynb的代码提取了想要预测图片的特征,同样打包成了pickle文件,但是导入image_caption_eval.ipynb的代码中预测时就会出错,因此非常想知道该如何调整image_caption_eval.ipynb以达到预测任意一张图片的目的。望赐教。

写回答

1回答

正十七

2021-02-24

你需要解构ImageCaptionData这个类。在这里你需要的数据就是Image经过vgg提取到的特征数据而已。

在如下的代码中,只有single_img_feature是模型所需要的输入,所以对于一张图像,你需要抽取特征得到single_img_features, 然后就照着运行就可以了啊。

with tf.Session() as sess:
    sess.run(init_op)
    logging.info("[*] Reading checkpoint ...")
    ckpt = tf.train.get_checkpoint_state(output_dir)
    if ckpt and ckpt.model_checkpoint_path:
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        saver.restore(sess, os.path.join(output_dir, ckpt_name))
        logging.info("[*] Success Read Checkpoint From %s" % (ckpt_name))    else:        raise Exception("[*] Failed load checkpoint")    
    for i in range(test_examples):
        single_img_features, single_sentence_ids, single_weights, single_img_name = caption_data.next(hps.batch_size)
        print(single_img_name)

        pprint.pprint(img_name_to_tokens[single_img_name[0]])
        pprint.pprint(img_name_to_token_ids[single_img_name[0]])

        embed_img_val = sess.run(embed_img, feed_dict={img_feature: single_img_features})

        state_val = np.zeros((1, num_hidden_states))
        embed_input_val = embed_img_val
        generated_sequence = []
        for j in range(hps.num_timesteps):
            logits_val, state_val = sess.run([logits, output_state],
                                             feed_dict = {
                                                 embed_input: embed_input_val,
                                                 input_state: state_val
                                             })
            predicted_word_id = np.argmax(logits_val[0])
            generated_sequence.append(predicted_word_id)
            embed_input_val = sess.run(embed_word,
                                       feed_dict={word: [[predicted_word_id]]})
        pprint.pprint("generated words: ")
        pprint.pprint(generated_sequence)
        pprint.pprint(vocab.decode(generated_sequence))


0
0

深度学习之神经网络(CNN/RNN/GAN)算法原理+实战

深度学习算法工程师必学,深入理解深度学习核心算法CNN RNN GAN

2617 学习 · 935 问题

查看课程