关于TensorFlow的LinearClassifier的预测

来源:5-5 预定义estimator使用

慕仰6273561

2020-03-31

问一下 就是这个linear_estimator.predict 之后得到一个generator object Estimator.predict at 0x150c2fad0 得到这个generator以后 我不知道怎么取出里面的值
这个linear_estimator就是LinearClassifier带入dataset以后跑出来的estimator
我用了list() 和 for 循环打印这个generator object Estimator.predict 也得不到值
我现在的tensorflow是2.1版本
图片描述

写回答

1回答

正十七

2020-04-07

同学你好,input_fn应该是一个dataset:

predicted_value = linear_estimator.predict(
    input_fn = lambda : make_dataset(eval_df, y_eval, epochs = 1, shuffle = False))

counter = 0
for i in predicted_value:
    counter += 1
    print(i)
    if counter > 10:
        break

打印出结果:

{'logits': array([0.68644077], dtype=float32), 'logistic': array([0.66517466], dtype=float32), 'probabilities': array([0.3348253 , 0.66517466], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}
{'logits': array([-0.33140332], dtype=float32), 'logistic': array([0.41789922], dtype=float32), 'probabilities': array([0.5821008 , 0.41789922], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}
{'logits': array([0.76469857], dtype=float32), 'logistic': array([0.682373], dtype=float32), 'probabilities': array([0.31762704, 0.682373  ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}
{'logits': array([-1.6867031], dtype=float32), 'logistic': array([0.1562099], dtype=float32), 'probabilities': array([0.84379005, 0.1562099 ], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}
{'logits': array([1.8905865], dtype=float32), 'logistic': array([0.8688224], dtype=float32), 'probabilities': array([0.1311776, 0.8688224], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}
{'logits': array([2.1999667], dtype=float32), 'logistic': array([0.90024656], dtype=float32), 'probabilities': array([0.09975349, 0.90024656], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}
{'logits': array([0.98448634], dtype=float32), 'logistic': array([0.7279975], dtype=float32), 'probabilities': array([0.2720025, 0.7279975], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}
{'logits': array([-2.7839499], dtype=float32), 'logistic': array([0.05819768], dtype=float32), 'probabilities': array([0.9418023 , 0.05819768], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}
{'logits': array([1.3115588], dtype=float32), 'logistic': array([0.7877739], dtype=float32), 'probabilities': array([0.21222611, 0.7877739 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}
{'logits': array([0.8499193], dtype=float32), 'logistic': array([0.7005502], dtype=float32), 'probabilities': array([0.29944977, 0.7005502 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}
{'logits': array([0.8553443], dtype=float32), 'logistic': array([0.70168704], dtype=float32), 'probabilities': array([0.298313  , 0.70168704], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}
0
0

Google老师亲授 TensorFlow2.0 入门到进阶

Tensorflow2.0实战—以实战促理论的方式学习深度学习

1849 学习 · 896 问题

查看课程