Predicting with TensorFlow's LinearClassifier
Source: 5-5 Using predefined estimators
慕仰6273561
2020-03-31
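A question: calling linear_estimator.predict returns a generator object Estimator.predict at 0x150c2fad0, and I don't know how to get the values out of this generator.
Here linear_estimator is the estimator produced by training a LinearClassifier on the dataset.
I tried both list() and a for loop to print the contents of the generator object Estimator.predict, but neither gave me any values.
I am currently on TensorFlow 2.1.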
1 answer
正十七
2020-04-07
Hi, input_fn should be a function that returns a dataset:
predicted_value = linear_estimator.predict(
    input_fn=lambda: make_dataset(eval_df, y_eval, epochs=1, shuffle=False))
counter = 0
for i in predicted_value:
    counter += 1
    print(i)
    if counter > 10:
        break
This prints:
{'logits': array([0.68644077], dtype=float32), 'logistic': array([0.66517466], dtype=float32), 'probabilities': array([0.3348253 , 0.66517466], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}
{'logits': array([-0.33140332], dtype=float32), 'logistic': array([0.41789922], dtype=float32), 'probabilities': array([0.5821008 , 0.41789922], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}
{'logits': array([0.76469857], dtype=float32), 'logistic': array([0.682373], dtype=float32), 'probabilities': array([0.31762704, 0.682373 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}
{'logits': array([-1.6867031], dtype=float32), 'logistic': array([0.1562099], dtype=float32), 'probabilities': array([0.84379005, 0.1562099 ], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}
{'logits': array([1.8905865], dtype=float32), 'logistic': array([0.8688224], dtype=float32), 'probabilities': array([0.1311776, 0.8688224], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}
{'logits': array([2.1999667], dtype=float32), 'logistic': array([0.90024656], dtype=float32), 'probabilities': array([0.09975349, 0.90024656], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}
{'logits': array([0.98448634], dtype=float32), 'logistic': array([0.7279975], dtype=float32), 'probabilities': array([0.2720025, 0.7279975], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}
{'logits': array([-2.7839499], dtype=float32), 'logistic': array([0.05819768], dtype=float32), 'probabilities': array([0.9418023 , 0.05819768], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}
{'logits': array([1.3115588], dtype=float32), 'logistic': array([0.7877739], dtype=float32), 'probabilities': array([0.21222611, 0.7877739 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}
{'logits': array([0.8499193], dtype=float32), 'logistic': array([0.7005502], dtype=float32), 'probabilities': array([0.29944977, 0.7005502 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}
{'logits': array([0.8553443], dtype=float32), 'logistic': array([0.70168704], dtype=float32), 'probabilities': array([0.298313 , 0.70168704], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}
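If you only need particular fields rather than the full dicts, a sketch like the following may help. It does not call TensorFlow: fake_predict is a hypothetical stand-in that yields dicts shaped like the predictions in this thread, with plain lists in place of NumPy arrays, so the snippet runs on its own. With the real estimator you would iterate over the result of linear_estimator.predict(...) instead. It also shows that a generator can be consumed only once, which is one possible reason a second list() or for loop over the same object prints nothing.

```python
from itertools import islice


def fake_predict():
    """Hypothetical stand-in for linear_estimator.predict(...).

    Yields one dict per example, mimicking the keys seen in the real output.
    """
    rows = [
        (0.68644077, 0.66517466),
        (-0.33140332, 0.41789922),
        (0.76469857, 0.682373),
    ]
    for logit, p1 in rows:
        yield {
            "logits": [logit],
            "probabilities": [1.0 - p1, p1],
            "class_ids": [1 if p1 > 0.5 else 0],
        }


predictions = fake_predict()

# Take the first two predictions without keeping a manual counter.
first_two = list(islice(predictions, 2))

# Pull individual fields out of each prediction dict.
class_ids = [int(p["class_ids"][0]) for p in first_two]
positive_probs = [float(p["probabilities"][1]) for p in first_two]

# The generator remembers its position: only the third prediction is left.
remaining = list(predictions)

# It is now exhausted, so another pass yields nothing.
second_pass = list(predictions)

print(class_ids, positive_probs, len(remaining), len(second_pass))
```

To iterate again with the real estimator, call predict once more to get a fresh generator, or store the results in a list the first time through.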