Problem when computing accuracy
Source: 7-5 Hands-on Practice (Part 2)

慕码人1499
2025-04-15
My code is as follows:
#load the mnist data
from keras.datasets import mnist
(x_train,y_train),(x_test,y_test)=mnist.load_data()
#evaluate the model
y_train_predict=(mlp.predict(x_train_normal)>0.5).astype(int)
print(y_train_predict)
from sklearn.metrics import accuracy_score
accuracy_train=accuracy_score(y_train,y_train_predict)
print(accuracy_train)
The error is as follows:
ValueError Traceback (most recent call last)
Cell In[23], line 2
1 from sklearn.metrics import accuracy_score
----> 2 accuracy_train=accuracy_score(y_train,y_train_predict)
3 print(accuracy_train)
File F:\aconda20241219\envs\imooc_ai\Lib\site-packages\sklearn\utils\_param_validation.py:216, in validate_params.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
210 try:
211 with config_context(
212 skip_parameter_validation=(
213 prefer_skip_nested_validation or global_skip_validation
214 )
215 ):
--> 216 return func(*args, **kwargs)
217 except InvalidParameterError as e:
218 # When the function is just a wrapper around an estimator, we allow
219 # the function to delegate validation to the estimator, but we replace
220 # the name of the estimator by the name of the function in the error
221 # message to avoid confusion.
222 msg = re.sub(
223 r"parameter of \w+ must be",
224 f"parameter of {func.__qualname__} must be",
225 str(e),
226 )
File F:\aconda20241219\envs\imooc_ai\Lib\site-packages\sklearn\metrics\_classification.py:227, in accuracy_score(y_true, y_pred, normalize, sample_weight)
225 # Compute accuracy for each possible representation
226 y_true, y_pred = attach_unique(y_true, y_pred)
--> 227 y_type, y_true, y_pred = _check_targets(y_true, y_pred)
228 check_consistent_length(y_true, y_pred, sample_weight)
230 if y_type.startswith("multilabel"):
File F:\aconda20241219\envs\imooc_ai\Lib\site-packages\sklearn\metrics\_classification.py:107, in _check_targets(y_true, y_pred)
104 y_type = {"multiclass"}
106 if len(y_type) > 1:
--> 107 raise ValueError(
108 "Classification metrics can't handle a mix of {0} and {1} targets".format(
109 type_true, type_pred
110 )
111 )
113 # We can't have more than one value on y_type => The set is no more needed
114 y_type = y_type.pop()
ValueError: Classification metrics can't handle a mix of multiclass and multilabel-indicator targets
How can I fix this? Thanks.
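Why the two inputs clash, as a minimal sketch: y_train from mnist.load_data() is a 1D array of integer labels 0 to 9, which scikit-learn classifies as "multiclass". Assuming mlp ends in a 10-unit softmax layer (the error message suggests one column per class), thresholding its output with > 0.5 keeps the 2D shape and yields a 0/1 matrix, which scikit-learn classifies as "multilabel-indicator". The arrays below are hypothetical stand-ins, not the real MNIST data or model output; sklearn.utils.multiclass.type_of_target is the function scikit-learn uses inside _check_targets to decide each target type.

```python
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.utils.multiclass import type_of_target

# Hypothetical stand-in for y_train: a 1D array of integer class labels (0-9).
y_true = np.array([5, 0, 4, 1])

# Hypothetical stand-in for mlp.predict(...), assuming a 10-unit softmax output:
# each row puts 0.91 on the true class and 0.01 on the other nine (rows sum to 1).
probs = np.eye(10)[y_true] * 0.9 + 0.01

# Thresholding at 0.5 keeps the 2D shape: a (4, 10) matrix of 0s and 1s.
y_pred_thresholded = (probs > 0.5).astype(int)

print(type_of_target(y_true))              # multiclass
print(type_of_target(y_pred_thresholded))  # multilabel-indicator

# Mixing the two target types reproduces the ValueError from the traceback.
error_message = ""
try:
    accuracy_score(y_true, y_pred_thresholded)
except ValueError as err:
    error_message = str(err)
print(error_message)
```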
1 answer
flare_zhao
6 days ago
# Use argmax to get the index of the class with the highest probability
y_train_predict = mlp.predict(x_train_normal).argmax(axis=1)
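A small self-contained check of the argmax approach. The probability matrix is a hypothetical stand-in for mlp.predict(x_train_normal), assuming the model outputs one probability column per class: argmax(axis=1) reduces each row to the index of its largest probability, so the prediction becomes a 1D integer array of the same kind as y_train, and accuracy_score accepts the pair.

```python
import numpy as np
from sklearn.metrics import accuracy_score

# Hypothetical integer labels, standing in for y_train.
y_true = np.array([5, 0, 4, 1, 9])

# Hypothetical (n_samples, 10) probability matrix, standing in for mlp.predict(...).
probs = np.full((5, 10), 0.01)
# Put most of the mass on one class per row; the last row is deliberately wrong (7 vs 9).
predicted_classes = [5, 0, 4, 1, 7]
probs[np.arange(5), predicted_classes] = 0.91

# argmax along axis=1 turns each row into the index of its largest probability,
# giving 1D integer labels that can be compared directly with y_true.
y_pred = probs.argmax(axis=1)

acc = accuracy_score(y_true, y_pred)
print(y_pred)  # [5 0 4 1 7]
print(acc)     # 0.8
```

The 0.5 threshold in the original code is the usual decision rule for a single sigmoid output (binary classification); with one output per class, picking the largest probability is the multiclass counterpart.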