Python使用pytest_mock在函数中模拟多个查询

我正在为其中具有多个sql查询的函数编写单元测试用例。我正在使用psycopg2模块并尝试模拟cursor

app.py

import psycopg2

def my_function():
    # all connection related code goes here ...

    query = "SELECT name,phone FROM customer WHERE name='shanky'"
    cursor.execute(query)
    columns = [i[0] for i in cursor.description]
    customer_response = []
    for row in cursor.fetchall():
        customer_response.append(dict(zip(columns,row)))

    query = "SELECT name,id FROM product WHERE name='soap'"
    cursor.execute(query)
    columns = [i[0] for i in cursor.description]
    product_response = []
    for row in cursor.fetchall():
        product_response.append(dict(zip(columns,row)))

    return product_response

test.py

from pytest_mock import mocker
import psycopg2

def test_my_function(mocker):
    from my_module import app
    mocker.patch('psycopg2.connect')

    #first query
    mocked_cursor_one = psycopg2.connect.return_value.cursor.return_value
    mocked_cursor_one.description = [['name'],['phone']]
    mocked_cursor_one.fetchall.return_value = [('shanky','347539593')]
    mocked_cursor_one.execute.call_args == "SELECT name,phone FROM customer WHERE name='shanky'"

    #second query
    mocked_cursor_two = psycopg2.connect.return_value.cursor.return_value
    mocked_cursor_two.description = [['name'],['id']]
    mocked_cursor_two.fetchall.return_value = [('nirma',12313)]
    mocked_cursor_two.execute.call_args == "SELECT name,id FROM product WHERE name='soap'"

    ret = app.my_function()
    assert ret == {'name' : 'nirma','id' : 12313}

但是模拟程序总是采用最后一个模拟对象(第二个查询)。我已经尝试了多次破解,但是没有成功。如何在一个函数中模拟多个查询并成功通过单元测试用例?是否可以以这种方式编写单元测试用例,或者我需要将查询拆分为不同的功能?

abcd123abc123 回答:Python使用pytest_mock在函数中模拟多个查询

尝试side_effect的{​​{1}}参数:

mocker.patch

根据the docsfrom unittest.mock import MagicMock from pytest_mock import mocker import psycopg2 def test_my_function(mocker): from my_module import app mocker.patch('psycopg2.connect',side_effect=[MagicMock(),MagicMock()]) #first query mocked_cursor_one = psycopg2.connect().cursor.return_value # note that we actually call psyocpg2.connect -- it's important mocked_cursor_one.description = [['name'],['phone']] mocked_cursor_one.fetchall.return_value = [('shanky','347539593')] mocked_cursor_one.execute.call_args == "SELECT name,phone FROM customer WHERE name='shanky'" #second query mocked_cursor_two = psycopg2.connect().cursor.return_value mocked_cursor_two.description = [['name'],['id']] mocked_cursor_two.fetchall.return_value = [('nirma',12313)] mocked_cursor_two.execute.call_args == "SELECT name,id FROM product WHERE name='soap'" assert mocked_cursor_one is not mocked_cursor_two # show that they are different ret = app.my_function() assert ret == {'name' : 'nirma','id' : 12313} 允许您在每次调用修补对象时更改返回值:

  

如果传递一个Iterable,它将用于检索一个迭代器,该迭代器必须在每次调用时产生一个值。该值可以是要引发的异常实例,也可以是从对模拟的调用中返回的值

,

正如我在前面的评论中提到的,使单元测试可移植的最佳方法是开发一个完整的模拟数据库行为的方法。 我已经在MySQL上做到了,但是对于所有数据库它几乎都一样。

首先,我喜欢在正在使用的程序包上使用包装器类,它有助于在一个地方快速更改数据库,而不是在代码中各处更改数据库。

这是我用作包装器的信物:

现在,您需要模拟该MySQL类:

# _database.py
# -----------------------------------------------------------------------------
# Database Metaclass
# -----------------------------------------------------------------------------
"""Metaclass for Database implementation.
"""
# -----------------------------------------------------------------------------


import logging


logger = logging.getLogger(__name__)


class Database:
    """Database Metaclass"""

    def __init__(self,connect_func,**kwargs):
        self.connection = connect_func(**kwargs)

    def execute(self,statement,fetchall=True):
        """Execute a statement.

        Execute the statement passed as arugment.

        Args:
            statement (str): SQL Query or Command to execute.

        Returns:
            set: List of returned objects by the cursor.
        """
        cursor = self.connection.cursor()
        logger.debug(f"Executing: {statement}")
        cursor.execute(statement)
        if fetchall:
            return cursor.fetchall()
        else:
            return cursor.fetchone()

    def __del__(self):
        """Close connection on object deletion."""
        self.connection.close()

还有mysql模块:

# mysql.py
# -*- coding: utf-8 -*-
# -----------------------------------------------------------------------------
# MySQL Database Class
# -----------------------------------------------------------------------------
"""Class for MySQL Database connection."""
# -----------------------------------------------------------------------------


import logging
import mysql.connector

from . import _database


logger = logging.getLogger(__name__)


class MySQL(_database.Database):
    """Snowflake Database Class Wrapper.

    Attributes:
        connection (obj): Object returned from mysql.connector.connect
    """

    def __init__(self,autocommit=True,**kwargs):
        super().__init__(connect_func=mysql.connector.connect,**kwargs)
        self.connection.autocommit = autocommit

实例化为:db = MySQL(user='...',password='...',...)

这是数据文件:

# database_mock_data.json
{
    "customer": {
        "name": [
            "shanky","nirma"
        ],"phone": [
            123123123,232342342
        ]
    },"product": {
        "name": [
            "shanky","id": [
            1,2
        ]
    }
}

mocks.py

# mocks.py
import json
import re
from . import mysql
_MOCK_DATA_PATH = 'database_mock_data.json'


class MockDatabase(MySQL):
    """
    """
    def __init__(self,**kwargs):
        self.connection = MockConnection()


class MockConnection:
    """
    Mock the connection object by returning a mock cursor.
    """
    @staticmethod
    def cursor():
        return MockCursor()


class MockCursor:
    """
    The Mocked Cursor

    A call to execute() will initiate the read on the json data file and will set
    the description object (containing the column names usually).

    You could implement an update function like `_json_sql_update()`
    """
    def __init__(self):
        self.description = []
        self.__result = None

    def execute(self,statement):
        data = _read_json_file(_MOCK_DATA_PATH)
        if statement.upper().startswith('SELECT'):
            self.__result,self.description = _json_sql_select(data,statement)

    def fetchall(self):
        return self.__result

    def fetchone(self):
        return self.__result[0]


def _json_sql_select(data,query):
    """
    Takes a dictionary and returns the values from a sql query.
    NOTE: It does not work with other where clauses than '='.
          Also,note that a where statement is expected.
    :param (dict) data: Dictionary with the following structure:
                        {
                            'tablename': {
                                'column_name_1': ['value1','value2],'column_name_2': ['value1',...
                            },...
                        }
    :param (str) query: An update sql query as:
                        `update TABLENAME set column_name_1='value'
                        where column_name_2='value1'`
    :return: List of list of values and header description
    """
    try:
        match = (re.search("select(.*)from(.*)where(.*)[;]?",query,re.IGNORECASE | re.DOTALL).groups())
    except AttributeError:
        print("Select Query pattern mismatch... {}".format(query))
        raise

    # Parse values from the select query
    tablename = match[1].strip().upper()

    columns = [col.strip().upper() for col in match[0].split(",")]
    if columns == ['*']:
        columns = data[tablename].keys()

    where = [cmd.upper().strip().replace(' ','')
             for cmd in match[2].split('and')]

    # Select values
    selected_values = []
    nb_lines = len(list(data[tablename].values())[0])
    for i in range(nb_lines):
        is_match = True
        for condition in where:
            key_condition,value_condition = (_clean_string(condition)
                                              .split('='))
            if data[tablename][key_condition][i].upper() != value_condition:
                # Set flag to yes
                is_match = False
        if is_match:
            sub_list = []
            for column in columns:
                sub_list.append(data[tablename][column][i])
            selected_values.append(sub_list)

    # Usual descriptor has nested list
    description = zip(columns,['...'] * len(columns))

    return selected_values,description


def _read_json_file(file_path):
    with open(file_path,'r') as f_in:
        data = json.load(f_in)
    return data

然后将您的测试保存在test_module_yourfunction.py

import pytest

def my_function(db,query):
    # Code goes here

@pytest.fixture
def db_connection():
    return MockDatabase()


@pytest.mark.parametrize(
    ("query","expected"),[
        ("SELECT name,phone FROM customer WHERE name='shanky'",{'name' : 'nirma','id' : 12313}),("<second query goes here>","<second result goes here>")
    ]
)
def test_my_function(db_connection,expected):
    assert my_function(db_connection,query) == expected

现在很抱歉,如果您无法复制/粘贴此代码并使之正常工作,但您会感到:)只是在寻求帮助

,

在深入研究文档之后,我可以借助unittest模拟装饰器和@Pavel Vergeev建议的side_effect来实现这一目标。我能够编写单元测试。足以测试功能的案例。

from unittest import mock
from my_module import app

@mock.patch('psycopg2.connect')
def test_my_function(mocked_db):

    mocked_cursor = mocked_db.return_value.cursor.return_value

    description_mock = mock.PropertyMock()
    type(mocked_cursor).description = description_mock

    fetchall_return_one = [('shanky','347539593')]

    fetchall_return_two = [('nirma',12313)]

    descriptions = [
        [['name'],['phone']],[['name'],['id']]
    ]

    mocked_cursor.fetchall.side_effect = [fetchall_return_one,fetchall_return_two]

    description_mock.side_effect = descriptions

    ret = app.my_function()

    # assert whether called with mocked side effect objects
    mocked_db.assert_has_calls(mocked_cursor.fetchall.side_effect)

    # assert db query count is 2
    assert mocked_db.return_value.cursor.return_value.execute.call_count == 2

    # first query
    query1 = """
            SELECT name,phone FROM customer WHERE name='shanky'
            """
    assert mocked_db.return_value.cursor.return_value.execute.call_args_list[0][0][0] == query1

    # second query
    query2 = """
            SELECT name,id FROM product WHERE name='soap'
            """
    assert mocked_db.return_value.cursor.return_value.execute.call_args_list[1][0][0] == query2

    # assert the data of response
    assert ret == {'name' : 'nirma','id' : 12313}

此外,如果查询中有动态参数,也可以通过以下方法进行断言。

assert mocked_db.return_value.cursor.return_value.execute.call_args_list[0][0][1] = (parameter_name,)

因此,当执行第一个查询时,可以获取并声明cursor.execute(query,(parameter_name,))处的call_args_list[0][0][0],并在call_args_list[0][0][1]处获取第一参数parameter_name。类似地,增加索引,可以获取并声明所有其他参数和不同的查询。

本文链接:https://www.f2er.com/3096608.html

大家都在问