如何在 python 中对递归函数进行单元测试?

How can I unit test a recursive functions in python?

我想知道如何对递归函数是否被正确调用进行单元测试。例如这个函数:

def test01(number):
        if(len(number) == 1):
            return 1
        else:
            return 1+test01(number[1:])

递归计算一个数字有多少位(假设数字类型是字符串) 所以,我想测试函数 test01 是否被递归调用。直接这么实现就好了,如果这样实现就不行了:

def test01(number):
    return len(number)

编辑: 出于教育目的,递归方法是强制性的,因此 UnitTest 过程将自动检查编程练习。有没有办法检查该函数是否被多次调用?如果可能的话,我可以进行 2 项测试,一项断言正确的输出,另一项检查是否针对同一输入多次调用该函数。

提前感谢您的帮助

通常,单元测试应该至少检查您的函数是否正常工作,并尝试测试其中的所有代码路径

因此,您的单元测试应该尝试多次走主要路径,然后找到出口路径,实现全覆盖

您可以使用 3rd-party coverage module 查看是否所有代码路径都被占用

pip install coverage

python -m coverage erase       # coverage is additive, so clear out old runs
python -m coverage run -m unittest discover tests/unit_tests
python -m coverage report -m   # report, showing missed lines

根据标签猜测 我假设您想使用 unittest 来测试递归调用。这是此类检查的示例:

from unittest import TestCase
import my_module

class RecursionTest(TestCase):
    def setUp(self):
        self.counter = 0  # counts the number of calls

    def checked_fct(self, fct):  # wrapper function that increases a counter on each call
        def wrapped(*args, **kwargs):
            self.counter += 1
            return fct(*args, **kwargs)

        return wrapped

    def test_recursion(self):
        # replace your function with the checked version
        with mock.patch('my_module.test01',
                        self.checked_fct(my_module.test01)):  # assuming test01 lives in my_module.py
            result = my_module.test01('444')  # call the function
            self.assertEqual(result, 3)  # check for the correct result
            self.assertGreater(self.counter, 1)  # ensure the function has been called more than once

注意: 我使用了 import my_module 而不是 from my_module import test01 这样第一个调用也被模拟了——否则调用次数会太少.

根据您的设置情况,您可以手动添加更多测试,或为每个测试自动生成测试代码,或使用 pytest 参数化,或执行其他操作来自动执行测试。

Curtis Schlak 最近教了我这个策略。

它利用了Abstract Syntax Trees and the inspect module

尽我所能,
肖恩

import unittest
import ast
import inspect
from so import test01


class Test(unittest.TestCase):

    # Check to see if function calls itself recursively
    def test_has_recursive_call(self):

        # Boolean switch
        has_recursive_call = False

        # converts function into a string
        src = inspect.getsource(test01)

        # splits the source code into tokens
        # based on the grammar
        # transformed into an Abstract Syntax Tree
        tree = ast.parse(src)

        # walk tree
        for node in ast.walk(tree):

            # check for function call
            # and if the func called was "test01"
            if (
                type(node) is ast.Call
                and node.func.id == "test01"
            ):

                # flip Boolean switch to true
                has_recursive_call = True

        # assert: has_recursive_call should be true
        self.assertTrue(
            has_recursive_call,
            msg="The function does not "
                "make a recursive call",
        )

        print("\nThe function makes a recursive call")


if __name__ == "__main__":

    unittest.main()