Problem when computing accuracy

Source: 7-5 Hands-on Practice (Part 2)

慕码人1499

2025-04-15

My code is as follows:
#load the mnist data
from keras.datasets import mnist
(x_train,y_train),(x_test,y_test)=mnist.load_data()

# ... (preprocessing into x_train_normal and building/training mlp not shown)
#evaluate the model
y_train_predict=(mlp.predict(x_train_normal)>0.5).astype(int)
print(y_train_predict)

from sklearn.metrics import accuracy_score
accuracy_train=accuracy_score(y_train,y_train_predict)
print(accuracy_train)

The error is as follows:

ValueError Traceback (most recent call last)
Cell In[23], line 2
1 from sklearn.metrics import accuracy_score
----> 2 accuracy_train=accuracy_score(y_train,y_train_predict)
3 print(accuracy_train)

File F:\aconda20241219\envs\imooc_ai\Lib\site-packages\sklearn\utils\_param_validation.py:216, in validate_params.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
210 try:
211 with config_context(
212 skip_parameter_validation=(
213 prefer_skip_nested_validation or global_skip_validation
214 )
215 ):
--> 216 return func(*args, **kwargs)
217 except InvalidParameterError as e:
218 # When the function is just a wrapper around an estimator, we allow
219 # the function to delegate validation to the estimator, but we replace
220 # the name of the estimator by the name of the function in the error
221 # message to avoid confusion.
222 msg = re.sub(
223 r"parameter of \w+ must be",
224 f"parameter of {func.__qualname__} must be",
225 str(e),
226 )

File F:\aconda20241219\envs\imooc_ai\Lib\site-packages\sklearn\metrics\_classification.py:227, in accuracy_score(y_true, y_pred, normalize, sample_weight)
225 # Compute accuracy for each possible representation
226 y_true, y_pred = attach_unique(y_true, y_pred)
--> 227 y_type, y_true, y_pred = _check_targets(y_true, y_pred)
228 check_consistent_length(y_true, y_pred, sample_weight)
230 if y_type.startswith("multilabel"):

File F:\aconda20241219\envs\imooc_ai\Lib\site-packages\sklearn\metrics\_classification.py:107, in _check_targets(y_true, y_pred)
104 y_type = {"multiclass"}
106 if len(y_type) > 1:
--> 107 raise ValueError(
108 "Classification metrics can't handle a mix of {0} and {1} targets".format(
109 type_true, type_pred
110 )
111 )
113 # We can't have more than one value on y_type => The set is no more needed
114 y_type = y_type.pop()

ValueError: Classification metrics can't handle a mix of multiclass and multilabel-indicator targets
How can I fix this? Thanks.
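The error means scikit-learn classified the two arguments as different target types. A likely cause, assuming `mlp` ends in a softmax layer with one unit per digit (the error points that way, but the model code is not shown): `y_train` from `mnist.load_data()` is a 1-D array of integer labels, which scikit-learn treats as "multiclass", while `(mlp.predict(...) > 0.5).astype(int)` keeps the 2-D shape (n_samples, 10) and becomes a 0/1 matrix, which it treats as "multilabel-indicator". The sketch below reproduces this with `sklearn.utils.multiclass.type_of_target`, using small made-up arrays as hypothetical stand-ins for the real labels and model output.

```python
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.utils.multiclass import type_of_target

# Hypothetical stand-ins: integer labels (like y_train) and softmax-style
# probabilities (like mlp.predict(...)) for 3 samples and 3 classes.
y_true = np.array([1, 0, 2])
probs = np.array([
    [0.1, 0.7, 0.2],
    [0.8, 0.1, 0.1],
    [0.2, 0.3, 0.5],  # no probability is strictly above 0.5 in this row
])

# Thresholding keeps the 2-D shape, producing a 0/1 indicator matrix.
# The last row becomes all zeros, i.e. "no class predicted".
y_pred_thresholded = (probs > 0.5).astype(int)

print(type_of_target(y_true))              # multiclass
print(type_of_target(y_pred_thresholded))  # multilabel-indicator

# Mixing the two target types triggers the same ValueError as in the question.
try:
    accuracy_score(y_true, y_pred_thresholded)
    message = None
except ValueError as err:
    message = str(err)
print(message)
```

Note that a 0.5 threshold on softmax output can also leave a row with no predicted class at all, which is another reason it is not a good way to turn multiclass probabilities into labels.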


1 Answer

flare_zhao

6 days ago

# Use argmax to get the index of the class with the highest probability
y_train_predict = mlp.predict(x_train_normal).argmax(axis=1)
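A minimal sketch of why the `argmax(axis=1)` fix works, again with small made-up arrays in place of the real model output: taking the column index of the largest probability in each row turns the (n_samples, n_classes) output back into a 1-D array of class labels, the same "multiclass" form as `y_train`, so `accuracy_score` accepts both arguments. The last lines cover a case the thread does not mention: if the labels had also been one-hot encoded (an assumption, e.g. via `keras.utils.to_categorical`), the same `argmax` converts them back to integers before scoring.

```python
import numpy as np
from sklearn.metrics import accuracy_score

# Hypothetical model output: one row of class probabilities per sample.
probs = np.array([
    [0.1, 0.7, 0.2],
    [0.8, 0.1, 0.1],
    [0.2, 0.3, 0.5],
])
# Integer labels, in the same 1-D form as y_train from mnist.load_data().
y_true = np.array([1, 0, 1])

# argmax over the class axis picks the most probable class for each sample,
# and every row gets exactly one predicted class.
y_pred = probs.argmax(axis=1)
print(y_pred)  # [1 0 2]

# Both arguments are now 1-D label arrays: 2 of 3 predictions are correct.
acc = accuracy_score(y_true, y_pred)
print(acc)

# If the labels are one-hot encoded, convert them back to integers first.
y_true_onehot = np.eye(3, dtype=int)[y_true]
print(accuracy_score(y_true_onehot.argmax(axis=1), y_pred))
```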


慕码人1499
Thank you very much! Could you tell me how I should go about learning this material?
1 day ago

Python3 Introduction to Artificial Intelligence: Master Machine Learning + Deep Learning