中文
模型训练与评估入门:损失函数、过拟合和准确率怎么理解
很多人第一次训练模型时,会把注意力放在“准确率是多少”。但要真正理解模型是否可靠,需要先知道训练到底在调整什么,损失函数表示什么,过拟合为什么会发生,以及测试集为什么必须保留。
这篇文章围绕模型训练与评估展开,重点讲清楚几个基础概念:参数、损失函数、训练轮次、过拟合、验证集、测试集和常见分类指标。
如果说上一篇机器学习流程解决的是“项目怎么组织”,这一篇解决的就是“训练过程是否可信”。
一、模型训练到底在训练什么
一个模型可以理解成一个带参数的函数:
预测结果 = model(输入特征, 参数)
训练前,参数可能是随机的,也可能有默认初始值。训练的目标就是不断调整这些参数,让模型输出更接近真实标签。
以一个非常简单的线性模型为例:
y = w1 * x1 + w2 * x2 + b
这里的 w1、w2 和 b 就是参数。训练过程会尝试找到更合适的参数组合。
二、损失函数是什么
模型需要一个标准来判断“预测得有多错”。这个标准就是损失函数。
对回归问题,常见损失可以是预测值和真实值之间的平方差:
loss = (y_true - y_pred) ** 2
对分类问题,常用的是交叉熵损失。你不需要一开始就推导公式,但要理解它的直觉:
模型越自信地给出错误答案,损失越大;模型越接近正确答案,损失越小。
训练模型时,算法会尝试让整体损失下降。
三、梯度下降的直觉
很多模型使用梯度下降或它的变体来更新参数。可以把它想成一个下山过程:
- 当前参数对应一个损失值
- 计算往哪个方向移动能让损失下降
- 参数朝这个方向移动一点
- 重复很多次
这里有一个重要超参数叫学习率。学习率太小,下降很慢;学习率太大,可能来回震荡甚至无法收敛。
new_weight = old_weight - learning_rate * gradient
这不是全部细节,但足够帮助你理解训练循环为什么需要反复执行。
四、epoch、batch 和 iteration
训练深度学习模型时,经常会看到三个词:
- epoch:完整看完一遍训练集
- batch:每次拿一小批样本计算损失和梯度
- iteration:一次参数更新
如果训练集有 1000 条样本,batch size 是 100,那么一个 epoch 里会有 10 次 iteration。
传统机器学习库不一定直接暴露这些词,但背后的思想相似:模型需要通过训练数据反复调整参数。
五、为什么会过拟合
过拟合指的是:模型在训练集上表现很好,但在新数据上表现明显变差。
常见原因包括:
- 模型太复杂,记住了训练数据里的细节和噪声
- 训练数据太少,无法代表真实场景
- 特征里包含了不该使用的信息,也就是数据泄漏
- 训练太久,但没有监控验证集表现
过拟合的危险在于:你只看训练集指标时会觉得模型很好,但上线或遇到新样本后效果很差。
六、训练集、验证集和测试集
为了更可靠地评估模型,通常会拆出三类数据:
- 训练集:用于训练参数
- 验证集:用于调参、选模型、观察是否过拟合
- 测试集:最后只用一次,用来估计最终泛化效果
如果数据量不大,也可以先用训练集和测试集两份数据。但要记住:测试集不应该被反复用于调参,否则它也会间接参与训练决策。
七、分类任务常见指标
分类问题不能只看准确率。下面几个指标经常一起出现:
- Accuracy:预测正确的比例
- Precision:预测为正类的样本中,有多少是真的正类
- Recall:真实正类中,有多少被模型找出来
- F1-score:precision 和 recall 的综合指标
如果是疾病筛查,漏掉真实阳性可能很严重,这时 recall 往往很重要。如果是自动封号,误伤正常用户代价很高,这时 precision 可能更重要。
八、混淆矩阵怎么读
混淆矩阵会把预测结果和真实标签交叉统计:
预测负类 预测正类
真实负类 TN FP
真实正类 FN TP
TP:正类预测对了TN:负类预测对了FP:负类被误判成正类FN:正类被误判成负类
看混淆矩阵的好处是,你不只知道模型错了多少,还能知道它主要错在哪一种方向。
九、评估模型时的检查清单
每次评估模型时,可以按下面几个问题检查:
- 测试集是否从训练中隔离出来
- 类别是否严重不平衡
- 是否只看了 accuracy
- 是否和简单基线模型比较过
- 错误样本是否被人工看过一部分
- 训练表现和测试表现差距是否过大
模型训练的核心不是把某个指标刷高,而是建立一个可信的判断过程。你需要知道模型为什么看起来有效,也要知道它在哪些情况下可能失效。
十、判断一次训练是否靠谱
一个比较靠谱的训练记录,至少应该包含这些信息:
- 训练集、验证集和测试集的划分方式
- 使用的模型、主要参数和随机种子
- 训练指标和测试指标,而不是只给一个最终分数
- 错误样本分析,尤其是代价最高的错误类型
- 和简单基线模型的比较
这些记录看起来琐碎,但它们能让你几天后重新检查实验时,仍然知道结果从哪里来。
十一、下一步读什么
英文
Model Training and Evaluation: Loss, Overfitting, and Accuracy
在独立页面打开When people train their first model, they often focus only on accuracy. To understand whether a model is reliable, you need to know what training adjusts, what a loss function measures, why overfitting happens, and why test data must stay separate.
This article explains the basics of model training and evaluation: parameters, loss functions, epochs, overfitting, validation data, test data, and common classification metrics.
If the previous article explained how to organize a machine learning project, this one explains how to decide whether the training process is trustworthy.
1. What Does Training Adjust?
A model can be viewed as a function with parameters:
prediction = model(input_features, parameters)
Before training, parameters may be random or initialized with default values. Training adjusts those parameters so model output becomes closer to the true labels.
A very simple linear model looks like this:
y = w1 * x1 + w2 * x2 + b
Here, w1, w2, and b are parameters. Training tries to find better values for them.
2. What Is a Loss Function?
The model needs a way to measure how wrong a prediction is. That measurement is the loss function.
For regression, a simple loss can be squared error:
loss = (y_true - y_pred) ** 2
For classification, cross-entropy loss is common. You do not need to derive the formula at the beginning, but the intuition matters:
A confidently wrong prediction receives a large loss. A prediction close to the correct answer receives a smaller loss.
During training, the algorithm tries to reduce the overall loss.
3. The Intuition Behind Gradient Descent
Many models use gradient descent or a variant of it to update parameters. Think of it as walking downhill:
- The current parameters produce a loss value
- The algorithm estimates which direction reduces loss
- The parameters move a small step in that direction
- The process repeats many times
An important hyperparameter is the learning rate. If it is too small, training is slow. If it is too large, training can bounce around or fail to converge.
new_weight = old_weight - learning_rate * gradient
This is not the full mathematical story, but it explains why training loops repeat parameter updates.
4. Epoch, Batch, and Iteration
Deep learning training often uses these terms:
- Epoch: one full pass through the training set
- Batch: a small group of samples used for one update step
- Iteration: one parameter update
If the training set has 1000 samples and the batch size is 100, one epoch contains 10 iterations.
Traditional machine learning libraries may not expose these terms directly, but the basic idea is similar: the model uses training data to adjust parameters.
5. Why Overfitting Happens
Overfitting means the model performs well on training data but much worse on new data.
Common causes include:
- The model is complex enough to memorize noise and details in the training set
- The training data is too small to represent the real problem
- The features contain information that should not be available, also called data leakage
- The model trains for too long without validation monitoring
The danger is that training metrics can look excellent while real-world performance is poor.
6. Training, Validation, and Test Data
For reliable evaluation, data is often split into three parts:
- Training set: used to fit parameters
- Validation set: used to tune settings, select models, and watch for overfitting
- Test set: used at the end to estimate final generalization
For small practice projects, a training/test split can be enough. But remember: the test set should not be used repeatedly for tuning, or it becomes part of the decision process.
7. Common Classification Metrics
Classification should not be judged by accuracy alone. These metrics often appear together:
- Accuracy: the proportion of correct predictions
- Precision: among predicted positives, how many are truly positive
- Recall: among true positives, how many were found
- F1-score: a combined measure of precision and recall
For medical screening, missing a real positive case may be costly, so recall may matter more. For automatic account blocking, falsely blocking normal users may be costly, so precision may matter more.
8. Reading a Confusion Matrix
A confusion matrix compares predicted labels with true labels:
predicted negative predicted positive
true negative TN FP
true positive FN TP
TP: a positive sample predicted correctlyTN: a negative sample predicted correctlyFP: a negative sample incorrectly predicted as positiveFN: a positive sample incorrectly predicted as negative
The advantage of a confusion matrix is that it shows not only how many mistakes happened, but also which direction those mistakes went.
9. Evaluation Checklist
When evaluating a model, check these questions:
- Was the test set isolated from training?
- Are the classes heavily imbalanced?
- Did you look beyond accuracy?
- Was the model compared with a simple baseline?
- Did you inspect some wrong predictions manually?
- Is the gap between training performance and test performance too large?
The point of training is not merely to push one metric upward. The point is to build a trustworthy evaluation process and understand when the model is likely to fail.
10. What a Trustworthy Training Record Includes
A useful training record should include at least these details:
- How training, validation, and test data were split
- The model, important parameters, and random seed
- Training metrics and test metrics, not just one final score
- Error analysis, especially for the most costly error types
- Comparison against a simple baseline model
These notes may look small, but they make the experiment auditable when you return to it later.
11. What to Read Next
The previous article is Machine Learning Workflow. After training and evaluation are clear, continue with Neural Network Basics to connect parameters with multi-layer function composition.
很多人第一次训练模型时,会把注意力放在“准确率是多少”。但要真正理解模型是否可靠,需要先知道训练到底在调整什么,损失函数表示什么,过拟合为什么会发生,以及测试集为什么必须保留。
这篇文章围绕模型训练与评估展开,重点讲清楚几个基础概念:参数、损失函数、训练轮次、过拟合、验证集、测试集和常见分类指标。
如果说上一篇机器学习流程解决的是“项目怎么组织”,这一篇解决的就是“训练过程是否可信”。
一、模型训练到底在训练什么
一个模型可以理解成一个带参数的函数:
预测结果 = model(输入特征, 参数)
训练前,参数可能是随机的,也可能有默认初始值。训练的目标就是不断调整这些参数,让模型输出更接近真实标签。
以一个非常简单的线性模型为例:
y = w1 * x1 + w2 * x2 + b
这里的 w1、w2 和 b 就是参数。训练过程会尝试找到更合适的参数组合。
二、损失函数是什么
模型需要一个标准来判断“预测得有多错”。这个标准就是损失函数。
对回归问题,常见损失可以是预测值和真实值之间的平方差:
loss = (y_true - y_pred) ** 2
对分类问题,常用的是交叉熵损失。你不需要一开始就推导公式,但要理解它的直觉:
模型越自信地给出错误答案,损失越大;模型越接近正确答案,损失越小。
训练模型时,算法会尝试让整体损失下降。
三、梯度下降的直觉
很多模型使用梯度下降或它的变体来更新参数。可以把它想成一个下山过程:
- 当前参数对应一个损失值
- 计算往哪个方向移动能让损失下降
- 参数朝这个方向移动一点
- 重复很多次
这里有一个重要超参数叫学习率。学习率太小,下降很慢;学习率太大,可能来回震荡甚至无法收敛。
new_weight = old_weight - learning_rate * gradient
这不是全部细节,但足够帮助你理解训练循环为什么需要反复执行。
四、epoch、batch 和 iteration
训练深度学习模型时,经常会看到三个词:
- epoch:完整看完一遍训练集
- batch:每次拿一小批样本计算损失和梯度
- iteration:一次参数更新
如果训练集有 1000 条样本,batch size 是 100,那么一个 epoch 里会有 10 次 iteration。
传统机器学习库不一定直接暴露这些词,但背后的思想相似:模型需要通过训练数据反复调整参数。
五、为什么会过拟合
过拟合指的是:模型在训练集上表现很好,但在新数据上表现明显变差。
常见原因包括:
- 模型太复杂,记住了训练数据里的细节和噪声
- 训练数据太少,无法代表真实场景
- 特征里包含了不该使用的信息,也就是数据泄漏
- 训练太久,但没有监控验证集表现
过拟合的危险在于:你只看训练集指标时会觉得模型很好,但上线或遇到新样本后效果很差。
六、训练集、验证集和测试集
为了更可靠地评估模型,通常会拆出三类数据:
- 训练集:用于训练参数
- 验证集:用于调参、选模型、观察是否过拟合
- 测试集:最后只用一次,用来估计最终泛化效果
如果数据量不大,也可以先用训练集和测试集两份数据。但要记住:测试集不应该被反复用于调参,否则它也会间接参与训练决策。
七、分类任务常见指标
分类问题不能只看准确率。下面几个指标经常一起出现:
- Accuracy:预测正确的比例
- Precision:预测为正类的样本中,有多少是真的正类
- Recall:真实正类中,有多少被模型找出来
- F1-score:precision 和 recall 的综合指标
如果是疾病筛查,漏掉真实阳性可能很严重,这时 recall 往往很重要。如果是自动封号,误伤正常用户代价很高,这时 precision 可能更重要。
八、混淆矩阵怎么读
混淆矩阵会把预测结果和真实标签交叉统计:
预测负类 预测正类
真实负类 TN FP
真实正类 FN TP
TP:正类预测对了TN:负类预测对了FP:负类被误判成正类FN:正类被误判成负类
看混淆矩阵的好处是,你不只知道模型错了多少,还能知道它主要错在哪一种方向。
九、评估模型时的检查清单
每次评估模型时,可以按下面几个问题检查:
- 测试集是否从训练中隔离出来
- 类别是否严重不平衡
- 是否只看了 accuracy
- 是否和简单基线模型比较过
- 错误样本是否被人工看过一部分
- 训练表现和测试表现差距是否过大
模型训练的核心不是把某个指标刷高,而是建立一个可信的判断过程。你需要知道模型为什么看起来有效,也要知道它在哪些情况下可能失效。
十、判断一次训练是否靠谱
一个比较靠谱的训练记录,至少应该包含这些信息:
- 训练集、验证集和测试集的划分方式
- 使用的模型、主要参数和随机种子
- 训练指标和测试指标,而不是只给一个最终分数
- 错误样本分析,尤其是代价最高的错误类型
- 和简单基线模型的比较
这些记录看起来琐碎,但它们能让你几天后重新检查实验时,仍然知道结果从哪里来。
十一、下一步读什么
搜索问题
常见问题
这篇文章适合谁读?
这篇文章适合想用 入门 难度理解“模型训练与评估入门:损失函数、过拟合和准确率怎么理解”的读者,预计阅读时间约 9 分钟,重点覆盖 Model Training, Metrics, Evaluation。
读完后下一步应该看什么?
推荐下一步阅读“过拟合和欠拟合怎么解决:机器学习模型调优实战指南”,这样可以把当前知识点接到更完整的学习路线里。
这篇文章有没有可运行代码或配套资源?
这篇文章以解释为主,文末相关阅读会继续指向更接近实战的代码和资源页面。
这篇文章和整个网站的学习路线有什么关系?
它会通过文章上下文、学习路线、资源库和项目时间线连接到同一主题下的其他内容。
文章上下文
人工智能项目
从 AI、机器学习、训练评估、神经网络到 Python 小实战、手写数字识别、CIFAR-10 CNN、对抗性流量防御和 AI 安全攻防,按顺序建立基础。
项目时间线
已发布文章
- 人工智能基础学习路线:先理解什么是 AI、机器学习和深度学习 面向有编程基础的读者,梳理 AI、机器学习、深度学习的关系,并给出可执行的人工智能基础学习路线。
- 机器学习完整流程:从数据、特征到模型预测 从工程视角拆解机器学习完整流程:定义问题、理解数据、处理特征、训练模型、预测和评估。
- 机器学习算法怎么选:分类、回归、聚类和推荐场景对照表 用任务类型、数据规模、解释性和部署成本选择机器学习算法,覆盖逻辑回归、决策树、随机森林、K-means 和表格数据基线模型。
- 特征工程入门实战:用 scikit-learn 处理缺失值、类别变量和数值标准化 用 scikit-learn Pipeline 和 ColumnTransformer 完成特征工程,处理缺失值、类别变量、数值标准化,并避免数据泄漏。
- 模型训练与评估入门:损失函数、过拟合和准确率怎么理解 讲清楚模型训练中的参数、损失函数、梯度下降、过拟合,以及准确率、召回率、F1 等分类评估指标。
- 过拟合和欠拟合怎么解决:机器学习模型调优实战指南 用训练分数和验证分数判断过拟合与欠拟合,并通过模型复杂度、正则化、交叉验证和特征工程调整机器学习模型。
- 神经网络基础:从感知机到多层网络 从一个神经元讲起,解释权重、偏置、激活函数、前向传播、反向传播和典型神经网络训练循环。
- Python 人工智能小实战:用 scikit-learn 完成一个分类任务 使用 scikit-learn 内置教学数据集跑通一个分类任务,覆盖数据加载、拆分、标准化、训练、预测、评估和实验记录。
- 手写数字识别项目入门:先读懂 train.csv、test.csv 和标签结构 从项目文件结构入手,读懂手写数字训练集、测试集、标签列和 784 维像素输入,为后续 C 分类器和实验台打基础。
- 用 C 实现手写数字 Softmax 分类器:从 784 维像素到 submission.csv 结合当前项目源码,讲清楚 softmax 多分类、损失函数、梯度更新、混淆矩阵输出,以及 submission.csv 的生成过程。
- 手写数字实验记录:怎么把离线分类项目接进浏览器实验台 解释浏览器实验台为什么采用轻量预训练模型、它和离线 C 项目的关系,以及如何用样本浏览和手绘输入理解预测结果。
- CIFAR-10 Tiny CNN 教程:用 C 语言实现小型卷积神经网络图像分类 用单文件 C 程序完成 CIFAR-10 小型 CNN 图像分类,讲解数据格式、网络结构、训练命令、loss、accuracy、常见错误和改进方向。
- 构建高熵流量防御:基于 Python 的连接层白噪声混淆与对抗性机器学习实践 以 mld_chaffing_v2.py 虚幻镜项目为例,讲解加密元数据泄漏、信息熵、分布距离、混淆矩阵、空闲窗口微脉冲和性能测试取舍。
- AI 安全威胁建模:用 NIST AML、MITRE ATLAS 和 OWASP 建立攻防地图 用 NIST Adversarial ML、MITRE ATLAS 和 OWASP LLM Top 10 建立 AI 安全威胁模型,覆盖资产、攻击面、证据和剩余风险。
- 对抗样本与鲁棒评估:从 FGSM 公式到 scikit-learn 数字分类实验 从 FGSM 公式解释对抗样本,用 scikit-learn digits toy 实验评估 clean accuracy、perturbed accuracy 和扰动预算。
- 数据投毒与后门攻击防御:污染率、触发器和训练管线隔离 用 toy digits 实验解释数据投毒、后门触发器、attack success rate、数据来源审计和训练管线隔离。
- 模型隐私与模型窃取风险:成员推断、模型抽取和输出接口防护 用本地 toy 实验解释成员推断、模型抽取、membership AUC、surrogate fidelity、输出最小化和查询治理。
- LLM/RAG/Agent 安全:Prompt Injection、工具权限和边界感知防护 从 RAG 和 Agent 架构解释 prompt injection、外部数据降权、工具 allowlist、人工审批和边界感知防护。
- 人工智能 NLP 基础:词袋模型与 TF-IDF 详解 介绍自然语言处理中最基础的文本表示方法:词袋模型(Bag of Words)与 TF-IDF,理解它们的工作原理及优缺点。
- 循环神经网络 (RNN) 基础:处理序列数据的记忆力 理解 RNN 的核心思想、隐藏状态的作用,以及它在处理自然语言序列任务时的优势与挑战。
- Transformer 与自注意力机制:AI 领域的革命性突破 深入浅出地讲解 Transformer 架构的核心:自注意力机制(Self-Attention)及其运作方式。
- 用 C 从零实现 CIFAR-10 Tiny CNN:卷积、池化和反向传播 基于实际 cifar10_tiny_cnn.c 项目,讲解 CIFAR-10 数据格式、3x3 卷积、ReLU、最大池化、全连接层、softmax、反向传播和本地运行方式。
已公开资源
- Python AI 小实战代码说明 文章内包含可直接复制运行的 scikit-learn 分类脚本。
- digit_softmax_classifier.c 手写数字 softmax 分类器的 C 语言源码。
- train.csv.zip 手写数字训练集压缩包,包含 42000 条带标签样本。
- test.csv.zip 手写数字测试集压缩包,包含 28000 条待预测样本。
- sample_submission.csv 官方提交格式示例,可直接对照最终输出字段。
- submission.csv 当前 C 项目跑出的预测结果文件。
- digit-playground-model.json 浏览器实验台使用的轻量 softmax 演示模型与样本。
- digit-sample-grid.svg 从训练集中抽取的小型手写数字预览网格。
- 手写数字项目打包下载 包含源码、压缩数据、提交文件、浏览器模型和样本预览图。
- cifar10_tiny_cnn.c 源码 单文件 C 语言 tiny CNN,包含 CIFAR-10 读取、卷积、池化、softmax 和反向传播。
- model_weights.bin 样例权重 一次本地小样本运行生成的模型权重文件。
- test_predictions.csv 预测样例 CIFAR-10 tiny CNN 输出的测试预测样例。
- CNN 项目说明 PDF 配套 CNN 项目说明材料。
- 虚幻镜脱敏代码骨架 去除控制口令、真实节点和目标列表后的 mld_chaffing_v2.py 控制流程说明。
- 虚幻镜压力测试记录模板 用于记录 CPU、内存、线程峰值、微脉冲速率、延迟和错误数的脱敏 CSV 模板。
- 虚幻镜分类器评估模板 用于记录 TP、FN、FP、TN、accuracy、precision、recall、F1、ROC-AUC、熵和 JS 散度的 CSV 模板。
- 虚幻镜资源说明 说明公开资源为何只提供脱敏代码、测试模板和架构笔记。
- AI Security Lab 说明 说明 AI 安全攻防系列的安全边界、安装命令和 quick-run 实验。
- AI Security Lab 完整实验包 包含安全 toy scripts、结果 CSV、风险登记表、攻防矩阵和架构图。
- AI 安全风险登记表 面向 AI 威胁建模和上线评审的 CSV 风险登记模板。
- AI 攻防矩阵 把攻击面、toy demo、指标和防护控制映射到一张 CSV 表。
- AI Security Lab 架构图 展示威胁建模、鲁棒评估、数据完整性、模型隐私和 RAG 防护之间的关系。
- FGSM digits 鲁棒评估脚本 本地 digits 分类器的 FGSM-style 扰动和准确率下降实验。
- 数据投毒与后门 toy 脚本 用 digits 数据演示污染率、触发器和 attack success rate。
- 模型隐私与抽取 toy 脚本 输出 membership AUC、target accuracy、surrogate fidelity 和 surrogate accuracy。
- RAG prompt injection guard toy 脚本 用确定性 toy agent 演示外部数据降权和工具权限阻断。
- 深度学习专题分享图 用于分享深度学习 / CNN 专题页的 1200x630 SVG 图。
- 从零实现机器学习分享图 用于分享 K-means、Iris 和机器学习流程专题页的 1200x630 SVG 图。
- 学生 AI 项目分享图 用于分享手写数字、C 分类器和浏览器实验台专题页的 1200x630 SVG 图。
- CNN 卷积扫描动画 Remotion 生成的 8 秒短动画,展示 3x3 卷积核如何扫描输入并形成特征图。
当前学习路线
- 人工智能基础学习路线 学习路线节点
- 机器学习完整流程 学习路线节点
- 机器学习算法怎么选 学习路线节点
- 特征工程入门实战 学习路线节点
- 模型训练与评估入门 学习路线节点
- 过拟合和欠拟合怎么解决 学习路线节点
- 神经网络基础 学习路线节点
- Transformer 自注意力机制 学习路线节点
- LLM 可视化教学台 学习路线节点
- Python 人工智能小实战 学习路线节点
- 手写数字数据结构入门 学习路线节点
- 用 C 实现手写数字 Softmax 分类器 学习路线节点
- 手写数字实验台说明 学习路线节点
- CIFAR-10 Tiny CNN 教程 学习路线节点
- 高熵流量防御实验 学习路线节点
- AI 安全威胁建模 学习路线节点
- 对抗样本与鲁棒评估 学习路线节点
- 数据投毒与后门防御 学习路线节点
- 模型隐私与模型抽取防护 学习路线节点
- LLM/RAG/Agent 安全 学习路线节点
下一步计划
- 补充更多图像分类和误差分析案例
- 把常见指标整理成速查表
- 继续补充 AI 安全防御实验记录
