Skip to content

AI写单元测试

单元测试是保证代码质量的重要手段,但很多开发者因为"太麻烦"而跳过。AI可以帮你自动生成测试,让写测试变成一件轻松的事。

为什么单元测试很重要?

没有测试的代码:
- 改一个地方,不知道会不会影响其他功能
- 不敢重构,怕改坏
- 每次发布都提心吊胆

有测试的代码:
- 改完跑一下测试,心里有数
- 大胆重构,测试会告诉你哪里错了
- 发布有信心

用AI生成测试的三种方式

方式一:让AI为现有代码写测试

请为以下函数编写单元测试:

<code>
def calculate_discount(price: float, member_level: str) -> float:
    """计算会员折扣价格"""
    discounts = {
        "普通会员": 0.95,
        "银卡会员": 0.90,
        "金卡会员": 0.85,
        "钻石会员": 0.80
    }
    discount = discounts.get(member_level, 1.0)
    return price * discount
</code>

要求:
1. 使用pytest框架
2. 覆盖正常情况和边界情况
3. 包含参数化测试

方式二:测试驱动开发(TDD)

先让AI帮你设计测试,再写代码:

我要实现一个密码强度验证函数,需求如下:
- 至少8个字符
- 包含大写字母
- 包含小写字母
- 包含数字
- 包含特殊字符

请先帮我写测试用例,我再根据测试写代码。

方式三:补充缺失的测试

这个项目测试覆盖率只有30%,请分析哪些关键函数缺少测试,
并为它们补充单元测试。

AI生成的测试示例

输入

python
def divide(a: float, b: float) -> float:
    """除法运算
    
    Args:
        a: 被除数
        b: 除数
        
    Returns:

        
    Raises:
        ValueError: 当除数为0时
    """
    if b == 0:
        raise ValueError("除数不能为0")
    return a / b

AI生成的测试

python
import pytest
from calculator import divide


class TestDivide:
    """divide函数的单元测试"""
    
    def test_normal_division(self):
        """测试正常除法"""
        assert divide(10, 2) == 5.0
        assert divide(9, 3) == 3.0
        assert divide(7, 2) == 3.5
    
    def test_negative_numbers(self):
        """测试负数除法"""
        assert divide(-10, 2) == -5.0
        assert divide(10, -2) == -5.0
        assert divide(-10, -2) == 5.0
    
    def test_float_result(self):
        """测试浮点数结果"""
        assert divide(1, 3) == pytest.approx(0.333333, rel=1e-5)
        assert divide(10, 4) == 2.5
    
    def test_zero_divisor(self):
        """测试除数为0的情况"""
        with pytest.raises(ValueError) as excinfo:
            divide(10, 0)
        assert "除数不能为0" in str(excinfo.value)
    
    def test_zero_dividend(self):
        """测试被除数为0的情况"""
        assert divide(0, 5) == 0.0
    
    @pytest.mark.parametrize("a,b,expected", [
        (100, 10, 10.0),
        (50, 5, 10.0),
        (33, 3, 11.0),
        (1, 1, 1.0),
    ])
    def test_parametrized(self, a: float, b: float, expected: float):
        """参数化测试"""
        assert divide(a, b) == expected

测试用例设计原则

原则一:覆盖正常情况

测试函数的主要功能是否正常工作:

python
def test_normal_case():
    """测试正常情况"""
    result = calculate_discount(100, "金卡会员")
    assert result == 85.0

原则二:覆盖边界情况

测试极端输入和边界条件:

python
def test_edge_cases():
    """测试边界情况"""
    # 价格为0
    assert calculate_discount(0, "金卡会员") == 0
    
    # 负价格
    with pytest.raises(ValueError):
        calculate_discount(-100, "金卡会员")
    
    # 不存在的会员等级
    assert calculate_discount(100, "不存在") == 100.0  # 无折扣

原则三:覆盖异常情况

测试错误处理是否正确:

python
def test_exception_handling():
    """测试异常处理"""
    with pytest.raises(ValueError):
        divide(10, 0)

原则四:使用参数化减少重复

python
@pytest.mark.parametrize("price,level,expected", [
    (100, "普通会员", 95.0),
    (100, "银卡会员", 90.0),
    (100, "金卡会员", 85.0),
    (100, "钻石会员", 80.0),
])
def test_member_levels(price, level, expected):
    """测试不同会员等级"""
    assert calculate_discount(price, level) == expected

让AI生成高质量测试的技巧

技巧一:提供完整的上下文

❌ 上下文不足
为这个函数写测试
def process(data):
    ...

✅ 上下文充分
这是一个处理用户注册数据的函数,请为它写测试:

def process_registration(data: dict) -> dict:
    """
    处理用户注册
    
    Args:
        data: 包含username, email, password的字典
        
    Returns:
        包含user_id和status的字典
        
    Raises:
        ValueError: 数据格式错误
        DuplicateError: 用户名已存在
    """
    ...

要求:
1. 测试成功注册的情况
2. 测试数据格式错误的情况
3. 测试用户名重复的情况
4. 使用mock模拟数据库操作

技巧二:指定测试框架和风格

请使用pytest框架编写测试,风格要求:
1. 测试类以Test开头
2. 每个测试方法以test_开头
3. 使用pytest.fixture管理测试数据
4. 使用pytest.raises测试异常

技巧三:要求AI解释测试覆盖

请为以下函数编写单元测试,并说明每个测试用例覆盖了什么场景:

<code>
...
</code>

测试复杂场景

场景一:测试依赖数据库的代码

使用mock隔离外部依赖:

python
from unittest.mock import Mock, patch
import pytest

def test_get_user_from_db():
    """测试从数据库获取用户"""
    # Mock数据库连接
    mock_db = Mock()
    mock_db.query.return_value = {"id": 1, "name": "张三"}
    
    with patch('myapp.db.get_connection', return_value=mock_db):
        user = get_user(1)
        assert user["name"] == "张三"

场景二:测试异步代码

python
import pytest
import asyncio

@pytest.mark.asyncio
async def test_async_function():
    """测试异步函数"""
    result = await async_fetch_data()
    assert result is not None

场景三:测试文件操作

python
import pytest
import tempfile
import os

def test_file_processing():
    """测试文件处理"""
    # 使用临时目录
    with tempfile.TemporaryDirectory() as tmpdir:
        # 创建测试文件
        test_file = os.path.join(tmpdir, "test.txt")
        with open(test_file, "w") as f:
            f.write("test content")
        
        # 测试处理函数
        result = process_file(test_file)
        assert result == "processed"

测试覆盖率

让AI帮你分析覆盖率

请分析这个项目的测试覆盖率,找出缺少测试的关键模块:
1. 列出所有模块
2. 识别哪些模块有测试
3. 找出覆盖率最低的关键模块
4. 建议优先补充测试的顺序

提高覆盖率的策略

python
# 运行覆盖率分析
# pytest --cov=myapp --cov-report=html tests/

# 让AI帮你找出未覆盖的代码
# 然后针对性补充测试

小结

AI辅助写测试的核心要点:

要点说明
提供上下文函数的功能、参数、返回值、异常
指定框架pytest、unittest等
要求覆盖面正常情况、边界情况、异常情况
使用参数化减少重复代码

最佳实践

不要一次性让AI写完所有测试。先让AI写核心用例,运行确认通过后,再逐步补充边界情况测试。

下一步

学会了写测试后,让我们继续学习 AI辅助Debug调试,让AI帮你快速定位和修复bug。