# 基于朴素贝叶斯的垃圾邮件分类器实现


背景介绍

随着电子邮件的普及,垃圾邮件(Spam)已成为困扰用户的常见问题——不仅浪费存储空间和带宽,还可能携带恶意链接或诈骗信息。朴素贝叶斯算法因其简单高效、适合处理高维文本数据的特性,成为垃圾邮件分类的经典解决方案。

本文将手动实现一个不依赖第三方机器学习库的垃圾邮件分类器,涵盖文本预处理模型训练实时预测三个核心模块,帮助读者深入理解朴素贝叶斯算法的应用逻辑。

思路分析

1. 核心流程

垃圾邮件分类器的实现分为两大阶段:
训练阶段:从本地文件夹读取标记数据(垃圾/正常邮件),通过文本预处理生成特征,计算朴素贝叶斯所需的概率参数;
预测阶段:对输入文本进行同样预处理,利用训练好的模型计算后验概率,输出分类结果和置信度。

2. 文本预处理步骤

  • 小写转换:统一文本大小写,避免“Free”和“free”被视为不同单词;
  • 标点去除:过滤标点符号(如!$),减少噪声干扰;
  • 停用词过滤:移除无意义的常用词(如“the”、“is”),聚焦关键信息;
  • 词频统计:统计每个单词在不同类别中的出现次数,为概率计算提供基础。

3. 朴素贝叶斯核心逻辑

  • 先验概率:计算垃圾邮件(P(spam))和正常邮件(P(ham))的先验概率,即两类样本在总数据中的占比;
  • 条件概率:计算每个单词在垃圾/正常邮件中的条件概率(P(word|spam)P(word|ham)),使用拉普拉斯平滑避免零概率问题;
  • 后验概率:对输入文本,结合先验概率和条件概率计算后验概率(P(spam|text)P(ham|text)),选择概率较大的类别作为结果。

代码实现

完整代码(Python)

import os
import math
import collections
import string
import pickle

def preprocess_text(text, stop_words):
    """文本预处理:小写转换、标点去除、停用词过滤"""
    # 1. 转换为小写
    text = text.lower()
    # 2. 去除标点符号
    translator = str.maketrans('', '', string.punctuation)
    text = text.translate(translator)
    # 3. 分割为单词列表
    words = text.split()
    # 4. 过滤停用词和空字符串
    filtered_words = [word for word in words if word not in stop_words and word.strip() != ""]
    return filtered_words

def load_data(folder_path, label, stop_words):
    """加载指定文件夹中的邮件数据,返回预处理后的单词列表和标签"""
    data = []
    for filename in os.listdir(folder_path):
        if filename.endswith('.txt'):
            file_path = os.path.join(folder_path, filename)
            with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                content = f.read()
                processed_words = preprocess_text(content, stop_words)
                data.append((processed_words, label))
    return data

def train_model(spam_folder, ham_folder, stop_words):
    """训练朴素贝叶斯模型,返回模型参数"""
    # 加载数据
    spam_data = load_data(spam_folder, 'spam', stop_words)
    ham_data = load_data(ham_folder, 'ham', stop_words)

    # 计算先验概率
    total_spam = len(spam_data)
    total_ham = len(ham_data)
    total_samples = total_spam + total_ham
    prior_spam = total_spam / total_samples
    prior_ham = total_ham / total_samples

    # 统计词频(垃圾邮件/正常邮件)
    spam_word_count = collections.defaultdict(int)  # 垃圾邮件中每个单词的出现次数
    ham_word_count = collections.defaultdict(int)   # 正常邮件中每个单词的出现次数
    spam_total = 0  # 垃圾邮件总单词数
    ham_total = 0   # 正常邮件总单词数
    vocab = set()   # 词汇表(所有出现过的单词)

    for words, _ in spam_data:
        for word in words:
            spam_word_count[word] += 1
            spam_total += 1
            vocab.add(word)

    for words, _ in ham_data:
        for word in words:
            ham_word_count[word] += 1
            ham_total += 1
            vocab.add(word)

    vocab_size = len(vocab)

    # 封装模型参数
    model = {
        'prior_spam': prior_spam,
        'prior_ham': prior_ham,
        'spam_word_count': spam_word_count,
        'ham_word_count': ham_word_count,
        'spam_total': spam_total,
        'ham_total': ham_total,
        'vocab_size': vocab_size,
        'stop_words': stop_words
    }

    return model, total_spam, total_ham, vocab_size

def predict(text, model):
    """预测输入文本的类别和置信度"""
    # 预处理输入文本
    processed_words = preprocess_text(text, model['stop_words'])

    # 获取模型参数
    prior_spam = model['prior_spam']
    prior_ham = model['prior_ham']
    spam_word_count = model['spam_word_count']
    ham_word_count = model['ham_word_count']
    spam_total = model['spam_total']
    ham_total = model['ham_total']
    vocab_size = model['vocab_size']

    # 计算后验概率(使用对数避免数值下溢)
    log_prob_spam = math.log(prior_spam)
    log_prob_ham = math.log(prior_ham)

    for word in processed_words:
        # 垃圾邮件条件概率(拉普拉斯平滑)
        count_spam = spam_word_count.get(word, 0)
        prob_word_spam = (count_spam + 1) / (spam_total + vocab_size)
        log_prob_spam += math.log(prob_word_spam)

        # 正常邮件条件概率(拉普拉斯平滑)
        count_ham = ham_word_count.get(word, 0)
        prob_word_ham = (count_ham + 1) / (ham_total + vocab_size)
        log_prob_ham += math.log(prob_word_ham)

    # 转换为原始概率并计算置信度
    prob_spam = math.exp(log_prob_spam)
    prob_ham = math.exp(log_prob_ham)
    total_prob = prob_spam + prob_ham
    confidence = (max(prob_spam, prob_ham) / total_prob) * 100

    # 确定分类结果
    result = "垃圾邮件" if prob_spam > prob_ham else "正常邮件"
    return result, round(confidence, 1)

def main():
    """主函数:处理用户交互"""
    # 内置停用词表(可扩展为自定义导入)
    stop_words = {
        'the', 'is', 'and', 'you', 'to', 'in', 'a', 'of', 'for', 'on', 'with', 'at', 'by',
        'this', 'that', 'it', 'we', 'are', 'be', 'will', 'can', 'has', 'have', 'had', 'was',
        'were', 'i', 'me', 'my', 'your', 'yours', 'he', 'she', 'they', 'them', 'his', 'her',
        'their', 'our', 'us', 'as', 'but', 'or', 'so', 'if', 'then', 'than', 'from', 'up',
        'down', 'out', 'into', 'over', 'under', 'again', 'further', 'once', 'here', 'there',
        'when', 'where', 'why', 'how', 'all', 'any', 'both', 'each', 'few', 'more', 'most',
        'other', 'some', 'such', 'no', 'nor', 'not', 'only', 'own', 'same', 'so', 'than',
        'too', 'very', 's', 't', 'don', 'now'
    }

    print("=== 垃圾邮件分类器 ===")
    print("1. 训练模型")
    print("2. 预测邮件")
    choice = input("请选择操作(1/2):")

    if choice == '1':
        # 训练流程
        spam_folder = input("请输入垃圾邮件文件夹路径:")
        ham_folder = input("请输入正常邮件文件夹路径:")

        if not os.path.isdir(spam_folder) or not os.path.isdir(ham_folder):
            print("错误:文件夹路径无效!")
            return

        model, total_spam, total_ham, vocab_size = train_model(spam_folder, ham_folder, stop_words)
        # 保存模型到文件
        with open('spam_model.pkl', 'wb') as f:
            pickle.dump(model, f)

        print("\n训练完成!")
        print(f"- 垃圾邮件样本数:{total_spam}")
        print(f"- 正常邮件样本数:{total_ham}")
        print(f"- 词汇表大小:{vocab_size}")
        print("模型已保存到 spam_model.pkl")

    elif choice == '2':
        # 预测流程
        if not os.path.exists('spam_model.pkl'):
            print("错误:模型未训练,请先执行训练操作!")
            return

        with open('spam_model.pkl', 'rb') as f:
            model = pickle.load(f)

        text = input("\n请输入邮件文本:")
        result, confidence = predict(text, model)
        print(f"\n分类结果:{result}")
        print(f"置信度:{confidence}%")

    else:
        print("错误:无效选择!")

if __name__ == "__main__":
    main()

代码说明

1. 关键模块解析

  • 文本预处理preprocess_text函数实现了小写转换、标点去除和停用词过滤,为后续模型训练提供干净的输入;
  • 模型训练train_model函数统计两类邮件的词频,计算先验概率和条件概率,并使用拉普拉斯平滑解决零概率问题;
  • 预测模块predict函数通过对数概率计算后验概率,避免数值下溢,并输出分类结果和置信度;
  • 交互逻辑main函数提供命令行交互,支持模型训练(保存到文件)和实时预测。

2. 拉普拉斯平滑

当某个单词在训练集中未出现时,其条件概率会为0,导致整个后验概率为0。拉普拉斯平滑通过在分子加1、分母加词汇表大小,确保概率不为零:
[ P(w|c) = \frac{count(w,c) + 1}{total(c) + vocab_size} ]
其中,count(w,c)是单词w在类别c中的出现次数,total(c)是类别c的总单词数,vocab_size是词汇表大小。

总结

本文实现的垃圾邮件分类器覆盖了朴素贝叶斯算法的核心应用场景,通过手动编码加深了对算法的理解。未来可扩展的方向包括:
支持中文:添加中文分词(如jieba)和中文停用词表;
GUI界面:使用Tkinter或PyQt构建可视化交互界面;
模型优化:引入TF-IDF加权、扩展停用词表或使用更复杂的贝叶斯变体(如多项式朴素贝叶斯);
批量预测:支持对文件夹中所有邮件进行批量分类。

该项目适合中级以下开发者学习实践,帮助快速掌握文本分类和朴素贝叶斯算法的应用。


发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注