尝试 `mocker.patch` 的 `side_effect` 参数:
根据the docs,from unittest.mock import MagicMock
from pytest_mock import mocker
import psycopg2
def test_my_function(mocker):
    """Run app.my_function with psycopg2.connect returning two mock connections.

    Assumes my_function opens one connection per query (two in total): each
    psycopg2.connect() call consumes the next item of ``side_effect``.
    """
    from my_module import app

    # Create the connection mocks up front and configure them directly.
    # Calling psycopg2.connect() here would consume both side_effect items,
    # leaving none for app.my_function (the mock would raise StopIteration).
    connection_one = MagicMock()
    connection_two = MagicMock()
    mocker.patch('psycopg2.connect', side_effect=[connection_one, connection_two])

    # first query
    mocked_cursor_one = connection_one.cursor.return_value
    mocked_cursor_one.description = [['name'], ['phone']]
    mocked_cursor_one.fetchall.return_value = [('shanky', '347539593')]

    # second query
    mocked_cursor_two = connection_two.cursor.return_value
    mocked_cursor_two.description = [['name'], ['id']]
    mocked_cursor_two.fetchall.return_value = [('nirma', 12313)]

    assert mocked_cursor_one is not mocked_cursor_two  # show that they are different

    ret = app.my_function()

    def _normalized(sql):
        """Collapse runs of whitespace so indentation/newlines don't matter."""
        return ' '.join(sql.split())

    # A bare `execute.call_args == "..."` (without assert) checks nothing;
    # verify the executed SQL for real.
    mocked_cursor_one.execute.assert_called()
    assert _normalized(mocked_cursor_one.execute.call_args[0][0]) == _normalized(
        "SELECT name,phone FROM customer WHERE name='shanky'")
    mocked_cursor_two.execute.assert_called()
    assert _normalized(mocked_cursor_two.execute.call_args[0][0]) == _normalized(
        "SELECT name,id FROM product WHERE name='soap'")

    assert ret == {'name': 'nirma', 'id': 12313}
`side_effect` 允许您在每次调用修补对象时更改返回值:
如果传入一个可迭代对象(Iterable),它将被用来获取一个迭代器,该迭代器必须在每次调用时产生一个值。该值可以是要引发的异常实例,也可以是调用模拟对象时要返回的值。
---
正如我在前面的评论中提到的,让单元测试可移植的最佳方法是实现一套能完整模拟数据库行为的机制。
我已经在MySQL上做到了,但是对于所有数据库它几乎都一样。
首先,我喜欢在正在使用的程序包上使用包装器类,它有助于在一个地方快速更改数据库,而不是在代码中各处更改数据库。
这是我用作包装器的代码:
现在,您需要模拟该MySQL类:
# _database.py
# -----------------------------------------------------------------------------
# Database Base Class
# -----------------------------------------------------------------------------
"""Base class for Database implementations.
"""
# -----------------------------------------------------------------------------
import logging
# Module-level logger; Database.execute logs each statement at DEBUG level.
logger = logging.getLogger(__name__)
class Database:
    """Base class wrapping a DB-API style connection.

    Subclasses pass their driver's ``connect`` function; the connection it
    returns is stored on ``self.connection`` and closed in ``__del__``.
    """

    def __init__(self, connect_func, **kwargs):
        """Open a connection.

        Args:
            connect_func (callable): Driver connect function, called with **kwargs.
            **kwargs: Connection parameters forwarded to ``connect_func``.
        """
        self.connection = connect_func(**kwargs)

    def execute(self, statement, fetchall=True):
        """Execute a statement and fetch its result.

        Args:
            statement (str): SQL query or command to execute.
            fetchall (bool): If True, return ``cursor.fetchall()``; otherwise
                return ``cursor.fetchone()``.

        Returns:
            Whatever the cursor's ``fetchall()`` / ``fetchone()`` returns
            (presumably a list of rows / a single row, per DB-API).
        """
        cursor = self.connection.cursor()
        # Lazy %-formatting: the message is only built if DEBUG is enabled.
        logger.debug("Executing: %s", statement)
        cursor.execute(statement)
        if fetchall:
            return cursor.fetchall()
        return cursor.fetchone()

    def __del__(self):
        """Close the connection on object deletion, if one was opened."""
        # If connect_func raised in __init__, self.connection was never set;
        # raising a second AttributeError from __del__ would only add noise.
        connection = getattr(self, 'connection', None)
        if connection is not None:
            connection.close()
还有mysql模块:
# mysql.py
# -*- coding: utf-8 -*-
# -----------------------------------------------------------------------------
# MySQL Database Class
# -----------------------------------------------------------------------------
"""Class for MySQL Database connection."""
# -----------------------------------------------------------------------------
import logging
import mysql.connector
from . import _database
# Module-level logger (not used by the code visible in this module).
logger = logging.getLogger(__name__)
class MySQL(_database.Database):
    """MySQL Database Class Wrapper.

    Opens the connection with ``mysql.connector.connect`` and then sets the
    connection's ``autocommit`` attribute.

    Attributes:
        connection (obj): Object returned from mysql.connector.connect
    """

    def __init__(self, autocommit=True, **kwargs):
        """Connect to MySQL.

        Args:
            autocommit (bool): Value assigned to ``connection.autocommit``.
            **kwargs: Connection parameters (e.g. user, password) forwarded
                to ``mysql.connector.connect``.
        """
        super().__init__(connect_func=mysql.connector.connect, **kwargs)
        self.connection.autocommit = autocommit
实例化为:db = MySQL(user='...',password='...',...)
这是数据文件:
# database_mock_data.json
{
"customer": {
"name": [
"shanky","nirma"
],"phone": [
123123123,232342342
]
},"product": {
"name": [
"shanky","nirma"
],"id": [
1,2
]
}
}
mocks.py
# mocks.py
import json
import re
from . import mysql
# JSON file read by MockCursor.execute(). Being a relative path, it is
# resolved against the current working directory, not this module's folder.
_MOCK_DATA_PATH = 'database_mock_data.json'
class MockDatabase(mysql.MySQL):
    """MySQL wrapper whose connection is an in-memory MockConnection.

    ``__init__`` does not call ``super().__init__``, so no real MySQL
    connection is opened; the inherited ``execute`` runs against a
    MockCursor instead.
    """

    def __init__(self, **kwargs):
        """Ignore connection parameters and attach a MockConnection."""
        self.connection = MockConnection()
class MockConnection:
    """
    Mock the connection object by returning a mock cursor.
    """

    @staticmethod
    def cursor():
        """Return a fresh MockCursor."""
        return MockCursor()

    @staticmethod
    def close():
        """Do nothing; present because Database.__del__ calls connection.close()."""
class MockCursor:
    """
    The Mocked Cursor

    A call to execute() will initiate the read on the json data file and will set
    the description object (containing the column names usually).
    Only SELECT statements produce results; other statements are ignored.
    You could implement an update function like `_json_sql_update()`
    """

    def __init__(self):
        self.description = []
        self.__result = None  # rows of the last SELECT; None until one runs

    def execute(self, statement):
        """Run *statement* against the JSON mock data (SELECT only)."""
        data = _read_json_file(_MOCK_DATA_PATH)
        # lstrip(): queries written as triple-quoted strings start with "\n".
        if statement.lstrip().upper().startswith('SELECT'):
            self.__result, self.description = _json_sql_select(data, statement)

    def fetchall(self):
        """Return all rows of the last SELECT (None if none was executed)."""
        return self.__result

    def fetchone(self):
        """Return the first row of the last SELECT, or None if there is none.

        Like DB-API ``fetchone``, an empty result yields None instead of raising.
        """
        if not self.__result:
            return None
        return self.__result[0]

    def close(self):
        """No-op, so code that closes its cursor also works with the mock."""
def _match_key(mapping, name):
    """Return the key of *mapping* equal to *name*, ignoring case and padding.

    Raises KeyError (like a plain ``mapping[name]``) when no key matches.
    """
    wanted = name.strip().upper()
    for key in mapping:
        if key.upper() == wanted:
            return key
    raise KeyError(name.strip())


def _unquote(value):
    """Strip surrounding whitespace and SQL quote characters from *value*."""
    return value.strip().strip("'\"")


def _json_sql_select(data, query):
    """
    Takes a dictionary and returns the values from a sql query.

    NOTE: Only '=' conditions joined by AND are supported, and a WHERE clause
    is required. Table and column names are matched case-insensitively;
    values are compared case-insensitively as strings.

    :param (dict) data: Dictionary with the following structure:
        {
            'tablename': {
                'column_name_1': ['value1', 'value2'],
                'column_name_2': ['value1', ...],
            },
            ...
        }
    :param (str) query: A select sql query such as:
        `SELECT column_name_1 FROM tablename
         WHERE column_name_2='value1'`
    :return: Tuple (rows, description): a list of lists of selected values,
        and a list of (column_name, '...') tuples, one per selected column.
    :raises AttributeError: If the query does not match the SELECT pattern.
    :raises KeyError: If the table or a column is not present in *data*.
    """
    try:
        # Lazy groups + anchored end so an optional trailing ';' is not
        # captured into the WHERE clause.
        match = (re.search(r"\bselect\b(.*?)\bfrom\b(.*?)\bwhere\b(.*?);?\s*$",
                           query, re.IGNORECASE | re.DOTALL).groups())
    except AttributeError:
        print("Select Query pattern mismatch... {}".format(query))
        raise

    # Parse values from the select query. Names are resolved against the
    # actual keys of `data`, so the JSON file may use any letter case.
    table = data[_match_key(data, match[1])]
    requested = [col.strip() for col in match[0].split(",")]
    if requested == ['*']:
        columns = list(table)
    else:
        columns = [_match_key(table, col) for col in requested]

    # Equality conditions joined by AND in any case; whole-word split so a
    # column such as "brand" is not cut apart.
    conditions = []
    for clause in re.split(r"\band\b", match[2], flags=re.IGNORECASE):
        key_condition, value_condition = clause.split('=', 1)
        conditions.append((_match_key(table, key_condition),
                           _unquote(value_condition).upper()))

    # Select values: rebuild each row from the parallel column lists and keep
    # it when every condition matches. str() lets numeric cells compare too.
    selected_values = []
    for cells in zip(*table.values()):
        row = dict(zip(table, cells))
        if all(str(row[key]).upper() == value for key, value in conditions):
            selected_values.append([row[column] for column in columns])

    # Usual descriptor has nested list; build a real list (not a one-shot
    # zip iterator) so callers can read it more than once.
    description = [(column, '...') for column in columns]
    return selected_values, description
def _read_json_file(file_path):
    """Load and return the JSON document stored at *file_path*.

    The file is decoded as UTF-8 (the encoding JSON text uses) rather than
    the platform's locale default, so results don't vary by machine.
    """
    with open(file_path, 'r', encoding='utf-8') as f_in:
        return json.load(f_in)
然后将您的测试保存在 test_module_yourfunction.py 中:
import pytest
def my_function(db, query):
    """Placeholder for the function under test; replace with the real code.

    Args:
        db: Database wrapper (e.g. a MockDatabase) to run *query* against.
        query (str): SQL query to execute.
    """
    # Code goes here. A def whose body is only a comment is a SyntaxError,
    # so fail loudly until the real implementation is pasted in.
    raise NotImplementedError("replace my_function with the code under test")
@pytest.fixture
def db_connection():
    """Provide a MockDatabase in place of a live database connection."""
    database = MockDatabase()
    return database
@pytest.mark.parametrize(
    ("query", "expected"),
    [
        # Expected row for 'shanky' in the "customer" table of the mock data
        # (assumes my_function returns a {column: value} dict — adjust as needed).
        ("SELECT name,phone FROM customer WHERE name='shanky'",
         {'name': 'shanky', 'phone': 123123123}),
        ("<second query goes here>", "<second result goes here>"),
    ],
)
def test_my_function(db_connection, query, expected):
    """Check my_function's result for each parametrized query.

    ``query`` must be a parameter: pytest rejects a parametrize argument the
    test function does not accept.
    """
    assert my_function(db_connection, query) == expected
很抱歉,这段代码可能无法直接复制/粘贴就正常运行,但您应该能明白其中的思路 :) 只是想帮上忙。
---
在深入研究文档之后,我借助 unittest 的 mock 装饰器以及 @Pavel Vergeev 建议的 side_effect 实现了这一目标,编写出了足以覆盖该函数各种情况的单元测试。
from unittest import mock
from my_module import app
@mock.patch('psycopg2.connect')
def test_my_function(mocked_db):
    """Run app.my_function with psycopg2.connect patched by ``mock.patch``.

    Every connection/cursor the code under test obtains is the same mock, so
    per-query results come from ``side_effect`` lists: the n-th
    ``fetchall()`` call and the n-th read of ``description`` get the n-th item.
    """
    mocked_cursor = mocked_db.return_value.cursor.return_value

    # `description` is read as an attribute, so side_effect only applies when
    # it is a PropertyMock attached to the mock's type.
    description_mock = mock.PropertyMock()
    type(mocked_cursor).description = description_mock
    description_mock.side_effect = [
        [['name'], ['phone']],
        [['name'], ['id']],
    ]
    mocked_cursor.fetchall.side_effect = [
        [('shanky', '347539593')],
        [('nirma', 12313)],
    ]

    ret = app.my_function()

    # `mocked_db.assert_has_calls(mocked_cursor.fetchall.side_effect)` did not
    # verify anything: side_effect is stored as an iterator that my_function
    # consumes, so it compared against leftover fetchall results (none, once
    # both were used). Count the calls instead.
    assert mocked_cursor.fetchall.call_count == 2
    # assert db query count is 2
    assert mocked_cursor.execute.call_count == 2

    def _normalized(sql):
        """Collapse runs of whitespace so indentation/newlines don't matter."""
        return ' '.join(sql.split())

    executed = mocked_cursor.execute.call_args_list
    # first query
    assert _normalized(executed[0][0][0]) == _normalized(
        "SELECT name,phone FROM customer WHERE name='shanky'")
    # second query
    assert _normalized(executed[1][0][0]) == _normalized(
        "SELECT name,id FROM product WHERE name='soap'")
    # assert the data of response
    assert ret == {'name': 'nirma', 'id': 12313}
此外,如果查询中有动态参数,也可以通过以下方法进行断言。
# `==`, not `=`: an assignment inside an assert statement is a SyntaxError.
assert mocked_db.return_value.cursor.return_value.execute.call_args_list[0][0][1] == (parameter_name,)
因此,当第一个查询以 `cursor.execute(query, (parameter_name,))` 的形式执行时,可以在 `call_args_list[0][0][0]` 处获取并断言查询语句 `query`,并在 `call_args_list[0][0][1]` 处获取参数元组 `(parameter_name,)`。类似地,增加索引即可获取并断言其余的参数和查询。
本文链接:https://www.f2er.com/3096608.html