From 4e4f430fae29f564b504239441112478c4f1dde2 Mon Sep 17 00:00:00 2001 From: Sean Stewart Date: Tue, 23 Jan 2024 16:55:08 -0500 Subject: [PATCH] First pass at adding pooling support for asyncio. --- .../lib/mysql/connector/aio/__init__.py | 229 +------ .../lib/mysql/connector/aio/abstracts.py | 199 +++++- .../lib/mysql/connector/aio/connection.py | 52 +- .../lib/mysql/connector/aio/pooling.py | 625 ++++++++++++++++++ .../tests/test_aio_pooling.py | 348 ++++++++++ 5 files changed, 1224 insertions(+), 229 deletions(-) create mode 100644 mysql-connector-python/lib/mysql/connector/aio/pooling.py create mode 100644 mysql-connector-python/tests/test_aio_pooling.py diff --git a/mysql-connector-python/lib/mysql/connector/aio/__init__.py b/mysql-connector-python/lib/mysql/connector/aio/__init__.py index e63924af..878aaf5d 100644 --- a/mysql-connector-python/lib/mysql/connector/aio/__init__.py +++ b/mysql-connector-python/lib/mysql/connector/aio/__init__.py @@ -28,228 +28,19 @@ """MySQL Connector/Python - MySQL driver written in Python.""" -__all__ = ["CMySQLConnection", "MySQLConnection", "connect"] - -import random - -from typing import Any - -from ..constants import DEFAULT_CONFIGURATION -from ..errors import Error, InterfaceError, ProgrammingError -from ..pooling import ERROR_NO_CEXT +__all__ = [ + "CMySQLConnection", + "MySQLConnection", + "MySQLConnectionAbstract", + "MySQLConnectionPool", + "connect", + "PooledMySQLConnection", +] + +from .pooling import connect, MySQLConnectionPool, PooledMySQLConnection from .abstracts import MySQLConnectionAbstract from .connection import MySQLConnection - -try: - import dns.exception - import dns.resolver -except ImportError: - HAVE_DNSPYTHON = False -else: - HAVE_DNSPYTHON = True - - try: from .connection_cext import CMySQLConnection except ImportError: CMySQLConnection = None - - -async def connect(*args: Any, **kwargs: Any) -> MySQLConnectionAbstract: - """Creates or gets a MySQL connection object. 
- - In its simpliest form, `connect()` will open a connection to a - MySQL server and return a `MySQLConnectionAbstract` subclass - object such as `MySQLConnection` or `CMySQLConnection`. - - When any connection pooling arguments are given, for example `pool_name` - or `pool_size`, a pool is created or a previously one is used to return - a `PooledMySQLConnection`. - - Args: - *args: N/A. - **kwargs: For a complete list of possible arguments, see [1]. If no arguments - are given, it uses the already configured or default values. - - Returns: - A `MySQLConnectionAbstract` subclass instance (such as `MySQLConnection` or - a `CMySQLConnection`) instance. - - Examples: - A connection with the MySQL server can be established using either the - `mysql.connector.connect()` method or a `MySQLConnectionAbstract` subclass: - ``` - >>> from mysql.connector.aio import MySQLConnection, HAVE_CEXT - >>> - >>> cnx1 = await mysql.connector.aio.connect(user='joe', database='test') - >>> cnx2 = MySQLConnection(user='joe', database='test') - >>> await cnx2.connect() - >>> - >>> cnx3 = None - >>> if HAVE_CEXT: - >>> from mysql.connector.aio import CMySQLConnection - >>> cnx3 = CMySQLConnection(user='joe', database='test') - ``` - - References: - [1]: https://dev.mysql.com/doc/connector-python/en/connector-python-connectargs.html - """ - # DNS SRV - dns_srv = kwargs.pop("dns_srv") if "dns_srv" in kwargs else False - - if not isinstance(dns_srv, bool): - raise InterfaceError("The value of 'dns-srv' must be a boolean") - - if dns_srv: - if not HAVE_DNSPYTHON: - raise InterfaceError( - "MySQL host configuration requested DNS " - "SRV. This requires the Python dnspython " - "module. 
Please refer to documentation" - ) - if "unix_socket" in kwargs: - raise InterfaceError( - "Using Unix domain sockets with DNS SRV lookup is not allowed" - ) - if "port" in kwargs: - raise InterfaceError( - "Specifying a port number with DNS SRV lookup is not allowed" - ) - if "failover" in kwargs: - raise InterfaceError( - "Specifying multiple hostnames with DNS SRV look up is not allowed" - ) - if "host" not in kwargs: - kwargs["host"] = DEFAULT_CONFIGURATION["host"] - - try: - srv_records = dns.resolver.query(kwargs["host"], "SRV") - except dns.exception.DNSException: - raise InterfaceError( - f"Unable to locate any hosts for '{kwargs['host']}'" - ) from None - - failover = [] - for srv in srv_records: - failover.append( - { - "host": srv.target.to_text(omit_final_dot=True), - "port": srv.port, - "priority": srv.priority, - "weight": srv.weight, - } - ) - - failover.sort(key=lambda x: (x["priority"], -x["weight"])) - kwargs["failover"] = [ - {"host": srv["host"], "port": srv["port"]} for srv in failover - ] - - # Failover - if "failover" in kwargs: - return await _get_failover_connection(**kwargs) - - # Use C Extension by default - use_pure = kwargs.get("use_pure", False) - if "use_pure" in kwargs: - del kwargs["use_pure"] # Remove 'use_pure' from kwargs - if not use_pure and CMySQLConnection is None: - raise ImportError(ERROR_NO_CEXT) - - if CMySQLConnection and not use_pure: - cnx = CMySQLConnection(*args, **kwargs) - else: - cnx = MySQLConnection(*args, **kwargs) - await cnx.connect() - return cnx - - -async def _get_failover_connection(**kwargs: Any) -> MySQLConnectionAbstract: - """Return a MySQL connection and try to failover if needed. - - An InterfaceError is raise when no MySQL is available. ValueError is - raised when the failover server configuration contains an illegal - connection argument. Supported arguments are user, password, host, port, - unix_socket and database. ValueError is also raised when the failover - argument was not provided. 
- - Returns MySQLConnection instance. - """ - config = kwargs.copy() - try: - failover = config["failover"] - except KeyError: - raise ValueError("failover argument not provided") from None - del config["failover"] - - support_cnx_args = set( - [ - "user", - "password", - "host", - "port", - "unix_socket", - "database", - "pool_name", - "pool_size", - "priority", - ] - ) - - # First check if we can add all use the configuration - priority_count = 0 - for server in failover: - diff = set(server.keys()) - support_cnx_args - if diff: - arg = "s" if len(diff) > 1 else "" - lst = ", ".join(diff) - raise ValueError( - f"Unsupported connection argument {arg} in failover: {lst}" - ) - if hasattr(server, "priority"): - priority_count += 1 - - server["priority"] = server.get("priority", 100) - if server["priority"] < 0 or server["priority"] > 100: - raise InterfaceError( - "Priority value should be in the range of 0 to 100, " - f"got : {server['priority']}" - ) - if not isinstance(server["priority"], int): - raise InterfaceError( - "Priority value should be an integer in the range of 0 to " - f"100, got : {server['priority']}" - ) - - if 0 < priority_count < len(failover): - raise ProgrammingError( - "You must either assign no priority to any " - "of the routers or give a priority for " - "every router" - ) - - server_directory = {} - server_priority_list = [] - for server in sorted(failover, key=lambda x: x["priority"], reverse=True): - if server["priority"] not in server_directory: - server_directory[server["priority"]] = [server] - server_priority_list.append(server["priority"]) - else: - server_directory[server["priority"]].append(server) - - for priority in server_priority_list: - failover_list = server_directory[priority] - for _ in range(len(failover_list)): - last = len(failover_list) - 1 - index = random.randint(0, last) - server = failover_list.pop(index) - new_config = config.copy() - new_config.update(server) - new_config.pop("priority", None) - try: - return 
await connect(**new_config) - except Error: - # If we failed to connect, we try the next server - pass - - raise InterfaceError("Unable to connect to any of the target hosts") diff --git a/mysql-connector-python/lib/mysql/connector/aio/abstracts.py b/mysql-connector-python/lib/mysql/connector/aio/abstracts.py index 0bdcd047..ab062190 100644 --- a/mysql-connector-python/lib/mysql/connector/aio/abstracts.py +++ b/mysql-connector-python/lib/mysql/connector/aio/abstracts.py @@ -88,7 +88,7 @@ InterfaceError, InternalError, NotSupportedError, - ProgrammingError, + ProgrammingError, OperationalError, ) from ..types import ( BinaryProtocolType, @@ -206,6 +206,7 @@ def __init__( self._host: str = host self._port: int = port self._database: str = database + self._charset_id: int = 45 self._password1: str = password1 self._password2: str = password2 self._password3: str = password3 @@ -234,6 +235,7 @@ def __init__( self._tls_ciphersuites: Optional[List[str]] = [] self._auth_plugin: Optional[str] = auth_plugin self._auth_plugin_class: Optional[str] = None + self._pool_config_version: Optional[Any] = None self._handshake: Optional[HandShakeType] = None self._loop: Optional[asyncio.AbstractEventLoop] = ( loop or asyncio.get_event_loop() @@ -688,6 +690,16 @@ def can_consume_results(self, value: bool) -> None: assert isinstance(value, bool) self._consume_results = value + @property + def pool_config_version(self) -> Any: + """Returns the pool configuration version.""" + return self._pool_config_version + + @pool_config_version.setter + def pool_config_version(self, value: Any) -> None: + """Sets the pool configuration version""" + self._pool_config_version = value + @property def in_transaction(self) -> bool: """MySQL session has started a transaction.""" @@ -1005,12 +1017,6 @@ async def set_charset_collation( await self.cmd_query( f"SET NAMES '{self._charset.name}' COLLATE '{self._charset.collation}'" ) - try: - # Required for C Extension - 
self.set_character_set_name(self._charset.name) - except AttributeError: - # Not required for pure Python connection - pass if self.converter: self.converter.set_charset(self._charset.name) @@ -1330,6 +1336,139 @@ async def commit(self) -> None: async def rollback(self) -> None: """Rollback current transaction.""" + async def start_transaction( + self, + consistent_snapshot: bool = False, + isolation_level: Optional[str] = None, + readonly: Optional[bool] = None, + ) -> None: + """Starts a transaction. + + This method explicitly starts a transaction sending the + START TRANSACTION statement to the MySQL server. You can optionally + set whether there should be a consistent snapshot, which + isolation level you need or which access mode i.e. READ ONLY or + READ WRITE. + + Args: + consistent_snapshot: If `True`, Connector/Python sends WITH CONSISTENT + SNAPSHOT with the statement. MySQL ignores this for + isolation levels for which that option does not apply. + isolation_level: Permitted values are 'READ UNCOMMITTED', 'READ COMMITTED', + 'REPEATABLE READ', and 'SERIALIZABLE'. If the value is + `None`, no isolation level is sent, so the default level + applies. + readonly: Can be `True` to start the transaction in READ ONLY mode or + `False` to start it in READ WRITE mode. If readonly is omitted, + the server's default access mode is used. + + Raises: + ProgrammingError: When a transaction is already in progress. + ValueError: When `isolation_level` specifies an unknown level, or when + `readonly` is given for a MySQL server older than 5.6.5. + + Examples: + For example, to start a transaction with isolation level `SERIALIZABLE`, + you would do the following: + ``` + >>> cnx = await mysql.connector.aio.connect(...) 
+ >>> await cnx.start_transaction(isolation_level='SERIALIZABLE') + ``` + """ + if self.in_transaction: + raise ProgrammingError("Transaction already in progress") + + if isolation_level: + level = isolation_level.strip().replace("-", " ").upper() + levels = [ + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + "SERIALIZABLE", + ] + + if level not in levels: + raise ValueError(f'Unknown isolation level "{isolation_level}"') + + await self._execute_query(f"SET TRANSACTION ISOLATION LEVEL {level}") + + if readonly is not None: + server_version = self.get_server_version() + if server_version < (5, 6, 5): + raise ValueError( + f"MySQL server version {server_version} does not " + "support this feature" + ) + + if readonly: + access_mode = "READ ONLY" + else: + access_mode = "READ WRITE" + await self._execute_query(f"SET TRANSACTION {access_mode}") + + query = "START TRANSACTION" + if consistent_snapshot: + query += " WITH CONSISTENT SNAPSHOT" + await self.cmd_query(query) + + async def reset_session( + self, + user_variables: Optional[Dict[str, Any]] = None, + session_variables: Optional[Dict[str, Any]] = None, + ) -> None: + """Clears the current active session. + + This method resets the session state, if the MySQL server is 5.7.3 + or later active session will be reset without re-authenticating. + For other server versions session will be reset by re-authenticating. + + It is possible to provide a sequence of variables and their values to + be set after clearing the session. This is possible for both user + defined variables and session variables. + + Args: + user_variables: User variables map. + session_variables: System variables map. + + Raises: + OperationalError: If not connected. + InternalError: If there are unread results and InterfaceError on errors. 
+ + Examples: + ``` + >>> user_variables = {'var1': '1', 'var2': '10'} + >>> session_variables = {'wait_timeout': 100000, 'sql_mode': 'TRADITIONAL'} + >>> await cnx.reset_session(user_variables, session_variables) + ``` + """ + if not await self.is_connected(): + raise OperationalError("MySQL Connection not available") + + try: + await self.cmd_reset_connection() + except (NotSupportedError, NotImplementedError): + if self._compress: + raise NotSupportedError( + "Reset session is not supported with compression for " + "MySQL server version 5.7.2 or earlier" + ) from None + await self.cmd_change_user( + self._user, + self._password, + self._database, + self._charset_id, + ) + + if user_variables or session_variables: + cur = await self.cursor() + if user_variables: + for key, value in user_variables.items(): + await cur.execute(f"SET @`{key}` = {value}") + if session_variables: + for key, value in session_variables.items(): + await cur.execute(f"SET SESSION `{key}` = {value}") + await cur.close() + @abstractmethod async def cmd_reset_connection(self) -> bool: """Resets the session state without re-authenticating. @@ -1392,6 +1531,50 @@ async def cmd_stmt_fetch(self, statement_id: int, rows: int = 1) -> None: statement id and the number of rows to fetch. """ + @abstractmethod + async def cmd_change_user( + self, + username: str = "", + password: str = "", + database: str = "", + charset: int = 45, + password1: str = "", + password2: str = "", + password3: str = "", + oci_config_file: str = "", + oci_config_profile: str = "", + ) -> Optional[Dict[str, Any]]: + """Changes the current logged in user. + + It also causes the specified database to become the default (current) + database. It is also possible to change the character set using the + charset argument. + + Args: + username: New account's username. + password: New account's password. + database: Database to become the default (current) database. + charset: Client charset (see [1]), only the lower 8-bits. 
+ password1: New account's password factor 1 - it's used instead + of `password` if set (higher precedence). + password2: New account's password factor 2. + password3: New account's password factor 3. + oci_config_file: OCI configuration file location (path-like string). + oci_config_profile: OCI configuration profile location (path-like string). + + Returns: + ok_packet: Dictionary containing the OK packet information. + + Examples: + ``` + >>> await cnx.cmd_change_user(username='', password='', database='', charset=33) + ``` + + References: + [1]: https://dev.mysql.com/doc/dev/mysql-server/latest/\ + page_protocol_basic_character_set.html#a_protocol_character_set + """ + @abstractmethod async def cmd_stmt_prepare( self, statement: bytes @@ -1550,7 +1733,7 @@ async def __aiter__(self) -> Iterator[RowType]: """ return self # type: ignore[return-value] - async def __next__(self) -> RowType: + async def __anext__(self) -> RowType: """ Used for iterating over the result set. Calles self.fetchone() to get the next row. diff --git a/mysql-connector-python/lib/mysql/connector/aio/connection.py b/mysql-connector-python/lib/mysql/connector/aio/connection.py index bacb62db..b4b17caa 100644 --- a/mysql-connector-python/lib/mysql/connector/aio/connection.py +++ b/mysql-connector-python/lib/mysql/connector/aio/connection.py @@ -171,7 +171,7 @@ def _add_default_conn_attrs(self) -> None: "_os": platform["version"], } - self._connection_attrs.update((default_conn_attrs)) + self._connection_attrs.update(default_conn_attrs) async def _execute_query(self, query: str) -> ResultType: """Execute a query. 
@@ -601,6 +601,53 @@ async def is_connected(self) -> bool: return False return True + async def reset_session( + self, + user_variables: Optional[Dict[str, Any]] = None, + session_variables: Optional[Dict[str, Any]] = None, + ) -> None: + """Clears the current active session + + This method resets the session state, if the MySQL server is 5.7.3 + or later active session will be reset without re-authenticating. + For other server versions session will be reset by re-authenticating. + + It is possible to provide a sequence of variables and their values to + be set after clearing the session. This is possible for both user + defined variables and session variables. + This method takes two arguments user_variables and session_variables + which are dictionaries. + + Raises OperationalError if not connected, InternalError if there are + unread results and InterfaceError on errors. + """ + if not await self.is_connected(): + raise OperationalError("MySQL Connection not available.") + + if not await self.cmd_reset_connection(): + try: + await self.cmd_change_user( + self._user, + self._password, + self._database, + self._charset_id, + self._password1, + self._password2, + self._password3, + self._oci_config_file, + self._oci_config_profile, + ) + except ProgrammingError: + await self.reconnect() + + cur = await self.cursor() + if user_variables: + for key, value in user_variables.items(): + await cur.execute(f"SET @`{key}` = %s", (value,)) + if session_variables: + for key, value in session_variables.items(): + await cur.execute(f"SET SESSION `{key}` = %s", (value,)) + async def ping( self, reconnect: bool = False, attempts: int = 1, delay: int = 0 ) -> None: @@ -640,7 +687,7 @@ async def shutdown(self) -> None: try: await self._socket.close_connection() - except Exception: # pylint: disable=broad-exception-caught + except (OSError, socket.error): pass # Getting an exception would mean we are disconnected. 
async def close(self) -> None: @@ -1331,6 +1378,7 @@ async def cmd_change_user( await self.cmd_init_db(database) self._charset = charsets.get_by_id(charset) + self._charset_id = charset self._charset_name = self._charset.name # return ok_pkt diff --git a/mysql-connector-python/lib/mysql/connector/aio/pooling.py b/mysql-connector-python/lib/mysql/connector/aio/pooling.py new file mode 100644 index 00000000..6c327ff4 --- /dev/null +++ b/mysql-connector-python/lib/mysql/connector/aio/pooling.py @@ -0,0 +1,625 @@ +# Copyright (c) 2013, 2022, Oracle and/or its affiliates. All rights reserved. +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License, version 2.0, as +# published by the Free Software Foundation. +# +# This program is also distributed with certain software (including +# but not limited to OpenSSL) that is licensed under separate terms, +# as designated in a particular file or component or in included license +# documentation. The authors of MySQL hereby grant you an +# additional permission to link the program and your derivative works +# with the separately licensed software that they have included with +# MySQL. +# +# Without limiting anything contained in the foregoing, this file, +# which is part of MySQL Connector/Python, is also subject to the +# Universal FOSS Exception, version 1.0, a copy of which can be found at +# http://oss.oracle.com/licenses/universal-foss-exception. +# +# This program is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +# See the GNU General Public License, version 2.0, for more details. 
+# +# You should have received a copy of the GNU General Public License +# along with this program; if not, write to the Free Software Foundation, Inc., +# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA + +"""Implementing pooling of connections to MySQL servers.""" +from __future__ import annotations + +import random +import re +import asyncio + +from types import TracebackType +from typing import TYPE_CHECKING, Any, Dict, NoReturn, Optional, Tuple, Type, Union +from uuid import uuid4 + +try: + import dns.exception + import dns.asyncresolver +except ImportError: + HAVE_DNSPYTHON = False +else: + HAVE_DNSPYTHON = True + +from .connection import MySQLConnection +from .. import pooling as _pooling +from ..errors import ( + Error, + InterfaceError, + NotSupportedError, + PoolError, + ProgrammingError, +) + +try: + from .connection_cext import CMySQLConnection +except ImportError: + CMySQLConnection = None + + +if TYPE_CHECKING: + from .abstracts import MySQLConnectionAbstract + CMySQLConnection = MySQLConnectionAbstract + +CNX_POOL_MAXSIZE = 32 +CNX_POOL_MAXNAMESIZE = 64 +CNX_POOL_NAMEREGEX = re.compile(r"[^a-zA-Z0-9._:\-*$#]") +ERROR_NO_CEXT = _pooling.ERROR_NO_CEXT +MYSQL_CNX_CLASS: Union[type, Tuple[type, ...]] = ( + MySQLConnection if CMySQLConnection is None else (MySQLConnection, CMySQLConnection) +) + + +async def _get_failover_connection(**kwargs: Any) -> MySQLConnectionAbstract: + """Return a MySQL connection and try to failover if needed. + + An InterfaceError is raise when no MySQL is available. ValueError is + raised when the failover server configuration contains an illegal + connection argument. Supported arguments are user, password, host, port, + unix_socket and database. ValueError is also raised when the failover + argument was not provided. + + Returns MySQLConnection instance. 
+ """ + config = kwargs.copy() + try: + failover = config["failover"] + except KeyError: + raise ValueError("failover argument not provided") from None + del config["failover"] + + support_cnx_args = { + "user", + "password", + "host", + "port", + "unix_socket", + "database", + "pool_name", + "pool_size", + "priority", + } + + # First check if we can add all use the configuration + priority_count = 0 + for server in failover: + diff = set(server.keys()) - support_cnx_args + if diff: + arg = "s" if len(diff) > 1 else "" + lst = ", ".join(diff) + raise ValueError( + f"Unsupported connection argument {arg} in failover: {lst}" + ) + if hasattr(server, "priority"): + priority_count += 1 + + server["priority"] = server.get("priority", 100) + if server["priority"] < 0 or server["priority"] > 100: + raise InterfaceError( + "Priority value should be in the range of 0 to 100, " + f"got : {server['priority']}" + ) + if not isinstance(server["priority"], int): + raise InterfaceError( + "Priority value should be an integer in the range of 0 to " + f"100, got : {server['priority']}" + ) + + if 0 < priority_count < len(failover): + raise ProgrammingError( + "You must either assign no priority to any " + "of the routers or give a priority for " + "every router" + ) + + server_directory = {} + server_priority_list = [] + for server in sorted(failover, key=lambda x: x["priority"], reverse=True): + if server["priority"] not in server_directory: + server_directory[server["priority"]] = [server] + server_priority_list.append(server["priority"]) + else: + server_directory[server["priority"]].append(server) + + for priority in server_priority_list: + failover_list = server_directory[priority] + for _ in range(len(failover_list)): + last = len(failover_list) - 1 + index = random.randint(0, last) + server = failover_list.pop(index) + new_config = config.copy() + new_config.update(server) + new_config.pop("priority", None) + try: + return await connect(**new_config) + except Error: + # If 
we failed to connect, we try the next server + pass + + raise InterfaceError("Unable to connect to any of the target hosts") + + +async def connect( + *args: Any, **kwargs: Any +) -> Union[MySQLConnectionAbstract, CMySQLConnection]: + """Creates a MySQL connection object. + + In its simplest form, `connect()` will open a connection to a + MySQL server and return a `MySQLConnectionAbstract` subclass + object such as `MySQLConnection` or `CMySQLConnection`. + + Args: + *args: N/A. + **kwargs: For a complete list of possible arguments, see [1]. If no arguments + are given, it uses the already configured or default values. + + Returns: + A `MySQLConnectionAbstract` subclass instance (such as `MySQLConnection`). + + Examples: + A connection with the MySQL server can be established using either the + `mysql.connector.aio.connect()` method or a `MySQLConnectionAbstract` subclass: + ``` + >>> from mysql.connector.aio import MySQLConnection + >>> + >>> cnx1 = await mysql.connector.aio.connect(user='joe', database='test') + >>> cnx2 = MySQLConnection(user='joe', database='test') + >>> + ``` + + References: + [1]: https://dev.mysql.com/doc/connector-python/en/connector-python-connectargs.html + """ + # DNS SRV + dns_srv = kwargs.pop("dns_srv") if "dns_srv" in kwargs else False + + if not isinstance(dns_srv, bool): + raise InterfaceError("The value of 'dns-srv' must be a boolean") + + if dns_srv: + if not HAVE_DNSPYTHON: + raise InterfaceError( + "MySQL host configuration requested DNS " + "SRV. This requires the Python dnspython " + "module. 
Please refer to documentation" + ) + if "unix_socket" in kwargs: + raise InterfaceError( + "Using Unix domain sockets with DNS SRV lookup is not allowed" + ) + if "port" in kwargs: + raise InterfaceError( + "Specifying a port number with DNS SRV lookup is not allowed" + ) + if "failover" in kwargs: + raise InterfaceError( + "Specifying multiple hostnames with DNS SRV look up is not allowed" + ) + if "host" not in kwargs: + kwargs["host"] = _pooling.DEFAULT_CONFIGURATION["host"] + + try: + srv_records = await dns.asyncresolver.query(kwargs["host"], "SRV") + except dns.exception.DNSException: + raise InterfaceError( + f"Unable to locate any hosts for '{kwargs['host']}'" + ) from None + + failover = [] + for srv in srv_records: + failover.append( + { + "host": srv.target.to_text(omit_final_dot=True), + "port": srv.port, + "priority": srv.priority, + "weight": srv.weight, + } + ) + + failover.sort(key=lambda x: (x["priority"], -x["weight"])) + kwargs["failover"] = [ + {"host": srv["host"], "port": srv["port"]} for srv in failover + ] + + # Failover + if "failover" in kwargs: + return await _get_failover_connection(**kwargs) + + # Option files + if "read_default_file" in kwargs: + kwargs["option_files"] = kwargs["read_default_file"] + kwargs.pop("read_default_file") + + if "option_files" in kwargs: + new_config = _pooling.read_option_files(**kwargs) + return await connect(**new_config) + + # Use C Extension by default + use_pure = kwargs.get("use_pure", False) + if "use_pure" in kwargs: + del kwargs["use_pure"] # Remove 'use_pure' from kwargs + if not use_pure and CMySQLConnection is None: + raise ImportError(ERROR_NO_CEXT) + + if CMySQLConnection and not use_pure: + cnx = CMySQLConnection(*args, **kwargs) + else: + cnx = MySQLConnection(*args, **kwargs) + await cnx.connect() + return cnx + + +class PooledMySQLConnection: + """Class holding a MySQL Connection in a pool + + PooledMySQLConnection is used by MySQLConnectionPool to return an + instance holding a MySQL 
connection. It works like a MySQLConnection + except for methods like close() and config(). + + The close()-method will add the connection back to the pool rather + than disconnecting from the MySQL server. + + Configuring the connection has to be done through the MySQLConnectionPool + method set_config(). Using config() on a pooled connection will raise a + PoolError. + + Attributes: + pool_name (str): Returns the name of the connection pool to which the + connection belongs. + """ + + def __init__(self, pool: MySQLConnectionPool, cnx: MySQLConnectionAbstract) -> None: + """Constructor. + + Args: + pool: A `MySQLConnectionPool` instance. + cnx: A `MySQLConnectionAbstract` subclass instance. + """ + if not isinstance(pool, MySQLConnectionPool): + raise AttributeError("pool should be a MySQLConnectionPool") + if not isinstance(cnx, MYSQL_CNX_CLASS): + raise AttributeError("cnx should be a MySQLConnection") + self._cnx_pool: MySQLConnectionPool = pool + self._cnx: MySQLConnectionAbstract = cnx + + async def __aenter__(self) -> PooledMySQLConnection: + return self + + async def __aexit__( + self, + exc_type: Type[BaseException], + exc_value: BaseException, + traceback: TracebackType, + ) -> None: + await self.close() + + def __getattr__(self, attr: Any) -> Any: + """Calls attributes of the MySQLConnection instance""" + return getattr(self._cnx, attr) + + async def close(self) -> None: + """Do not close, but adds connection back to pool. + + For a pooled connection, close() does not actually close it but returns it + to the pool and makes it available for subsequent connection requests. If the + pool configuration parameters are changed, a returned connection is closed + and reopened with the new configuration before being returned from the pool + again in response to a connection request. 
class MySQLConnectionPool:
    """Pool of asyncio MySQL connections kept in an `asyncio.Queue`.

    Connections are created by `open()` / `add_connection()` and handed out,
    wrapped in a `PooledMySQLConnection`, by `get_connection()`. Every
    connection the pool creates is stamped with the pool's configuration
    version, so connections created before a `set_config()` call are replaced
    the next time they are requested.
    """

    def __init__(
        self,
        pool_size: int = 5,
        pool_name: Optional[str] = None,
        pool_reset_session: bool = True,
        pool_connect_timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        """Constructor.

        Prepare a MySQL connection pool holding at most `pool_size`
        connections. No connection is opened here; call `open()` to fill the
        pool. The remaining keyword arguments are the connection arguments
        used when the pool is opened.

        Args:
            pool_size: The pool size. If this argument is not given, the default is 5.
            pool_name: The pool name. If this argument is not given, Connector/Python
                       automatically generates the name, composed from whichever of
                       the host, port, user, and database connection arguments are
                       given in kwargs, in that order.
            pool_reset_session: Whether to reset session variables when the connection
                                is returned to the pool.
            pool_connect_timeout: Stored on the pool as given. None of the methods of
                                  this class consults it.
            **kwargs: Optional additional connection arguments, as described in [1].

        Raises:
            AttributeError: When `pool_size` or the pool name is not valid.

        Examples:
            ```
            >>> dbconfig = {
            >>>     "database": "test",
            >>>     "user":     "joe",
            >>> }
            >>> cnxpool = mysql.connector.aio.pooling.MySQLConnectionPool(
            >>>     pool_name="mypool", pool_size=3, **dbconfig
            >>> )
            >>> await cnxpool.open()
            ```

        References:
            [1]: https://dev.mysql.com/doc/connector-python/en/connector-python-connectargs.html
        """
        self._pool_size: Optional[int] = None
        self._pool_name: Optional[str] = None
        self._pool_connect_timeout: Optional[int] = pool_connect_timeout
        self._reset_session = pool_reset_session
        self._set_pool_size(pool_size)
        self._set_pool_name(pool_name or _pooling.generate_pool_name(**kwargs))
        # Connection arguments given at construction time. They cannot be
        # applied here because set_config() is a coroutine; open() applies them.
        self._initial_config: Dict[str, Any] = dict(kwargs)
        self._cnx_config: Dict[str, Any] = {}
        self._cnx_queue: asyncio.Queue[MySQLConnection] = asyncio.Queue(
            self._pool_size
        )
        self._pool_lock: asyncio.Lock = asyncio.Lock()
        self._config_version = uuid4()

    async def open(self) -> None:
        """Open the connection pool and fill it with connections.

        Raises:
            PoolError: When no connection configuration is available, neither
                from the constructor nor from an earlier `set_config()` call.
        """
        # NOTE(review): the queue and the lock are re-created on every call,
        # presumably so they belong to the event loop that runs open(); any
        # connection still sitting in the previous queue is dropped without
        # being closed - confirm open() is only meant to be called once.
        self._cnx_queue = asyncio.Queue(self._pool_size)
        self._pool_lock = asyncio.Lock()
        # A configuration installed through set_config() before open() wins
        # over the arguments captured by the constructor.
        pending = {} if self._cnx_config else self._initial_config
        await self._fill_pool(**pending)

    async def _fill_pool(self, **kwargs: Any) -> None:
        """Apply `kwargs` as configuration (if any) and add `pool_size` connections."""
        if kwargs:
            await self.set_config(**kwargs)
        for _ in range(self._pool_size):
            await self.add_connection()

    @property
    def pool_name(self) -> str:
        """Returns the name of the connection pool."""
        return self._pool_name

    @property
    def pool_size(self) -> int:
        """Returns number of connections managed by the pool."""
        return self._pool_size

    @property
    def reset_session(self) -> bool:
        """Returns whether to reset session."""
        return self._reset_session

    async def set_config(self, **kwargs: Any) -> None:
        """Set the connection configuration for `MySQLConnectionAbstract` subclass instances.

        The arguments are validated by opening (and immediately closing) a
        connection with them. On success they replace the stored configuration
        as a whole and a new configuration version is generated, which makes
        `get_connection()` replace connections created with the previous one.
        Calling this method without arguments does nothing.

        Args:
            **kwargs: Connection arguments - for a complete list of possible
                      arguments, see [1].

        Raises:
            PoolError: When `connect()` rejects the arguments with an
                       `AttributeError`.

        References:
            [1]: https://dev.mysql.com/doc/connector-python/en/connector-python-connectargs.html
        """
        if not kwargs:
            return

        async with self._pool_lock:
            # NOTE(review): only AttributeError is translated into PoolError;
            # confirm that is what the aio connection classes raise for an
            # unknown argument (a TypeError would propagate unchanged).
            try:
                probe = await connect(**kwargs)
            except AttributeError as err:
                raise PoolError(f"Connection configuration not valid: {err}") from err
            # The probe only exists to validate the arguments; do not leak it.
            await self._close_quietly(probe)
            self._cnx_config = kwargs
            self._config_version = uuid4()

    def _set_pool_size(self, pool_size: int) -> None:
        """Set the size of the pool

        This method sets the size of the pool but it will not resize the pool.

        Raises an AttributeError when the pool_size is not valid. Invalid size
        is 0, negative or higher than pooling.CNX_POOL_MAXSIZE.
        """
        if pool_size <= 0 or pool_size > CNX_POOL_MAXSIZE:
            raise AttributeError(
                "Pool size should be higher than 0 and lower or equal to "
                f"{CNX_POOL_MAXSIZE}"
            )
        self._pool_size = pool_size

    def _set_pool_name(self, pool_name: str) -> None:
        r"""Set the name of the pool.

        This method checks the validity and sets the name of the pool.

        Raises an AttributeError when pool_name contains illegal characters
        ([^a-zA-Z0-9._\-*$#]) or is longer than pooling.CNX_POOL_MAXNAMESIZE.
        """
        if CNX_POOL_NAMEREGEX.search(pool_name):
            raise AttributeError(f"Pool name '{pool_name}' contains illegal characters")
        if len(pool_name) > CNX_POOL_MAXNAMESIZE:
            raise AttributeError(f"Pool name '{pool_name}' is too long")
        self._pool_name = pool_name

    def _queue_connection(self, cnx: MySQLConnectionAbstract) -> None:
        """Put connection back in the queue

        This method is putting a connection back in the queue. It will not
        acquire a lock as the methods using _queue_connection() will have it
        set.

        Raises `PoolError` on errors.
        """
        if not isinstance(cnx, MYSQL_CNX_CLASS):
            raise PoolError(
                "Connection instance not subclass of MySQLConnectionAbstract"
            )

        try:
            self._cnx_queue.put_nowait(cnx)
        except asyncio.QueueFull as err:
            raise PoolError("Failed adding connection; queue is full") from err

    @staticmethod
    async def _close_quietly(cnx: MySQLConnectionAbstract) -> None:
        """Close a connection that is being discarded, ignoring connector errors."""
        try:
            await cnx.close()
        except Error:
            # The connection is thrown away either way; a failing close on an
            # already broken connection must not mask the caller's real work.
            pass

    async def _create_connection(self) -> MySQLConnectionAbstract:
        """Open a new connection with the stored configuration and stamp it.

        The caller must hold `self._pool_lock`.

        Raises:
            PoolError: When no configuration is set.
            NotSupportedError: When session reset is enabled together with
                compression on a server older than 5.7.3.
        """
        if not self._cnx_config:
            raise PoolError("Connection configuration not available")

        cnx = await connect(**self._cnx_config)  # type: ignore[assignment]
        if (
            self._reset_session
            and self._cnx_config.get("compress")
            and cnx.get_server_version() < (5, 7, 3)
        ):
            # Do not leave the just-opened connection dangling.
            await self._close_quietly(cnx)
            raise NotSupportedError(
                "Pool reset session is not supported with "
                "compression for MySQL server version 5.7.2 "
                "or earlier"
            )
        cnx.pool_config_version = self._config_version
        return cnx

    async def add_connection(
        self, cnx: Optional[MySQLConnectionAbstract] = None
    ) -> None:
        """Adds a connection to the pool.

        This method instantiates a `MySQLConnection` using the configuration
        passed when initializing the `MySQLConnectionPool` instance or using
        the `set_config()` method.
        If cnx is a `MySQLConnection` instance, it will be added to the
        queue.

        Args:
            cnx: The `MySQLConnectionAbstract` subclass object to be added to
                 the pool. If this argument is missing (aka `None`), the pool
                 creates a new connection and adds it.

        Raises:
            PoolError: When no configuration is set, when no more
                       connection can be added (maximum reached) or when `cnx`
                       is not a supported connection instance.
            NotSupportedError: When session reset is requested together with
                       compression on a server older than 5.7.3.
        """
        async with self._pool_lock:
            if not self._cnx_config:
                raise PoolError("Connection configuration not available")

            if self._cnx_queue.full():
                raise PoolError("Failed adding connection; queue is full")

            if cnx is None:
                cnx = await self._create_connection()
            elif not isinstance(cnx, MYSQL_CNX_CLASS):
                raise PoolError(
                    "Connection instance not subclass of MySQLConnectionAbstract"
                )

            self._queue_connection(cnx)

    async def get_connection(self) -> PooledMySQLConnection:
        """Gets a connection from the pool.

        This method returns a PooledMySQLConnection instance which
        has a reference to the pool that created it, and the next available
        MySQL connection.

        When the MySQL connection is not connected, or was created with an
        older pool configuration, it is closed and replaced by a new one.

        Returns:
            A `PooledMySQLConnection` instance.

        Raises:
            PoolError: When the pool has no connection available.
        """
        try:
            cnx = self._cnx_queue.get_nowait()
        except asyncio.QueueEmpty as err:
            raise PoolError("Failed getting connection; pool exhausted") from err

        # NOTE(review): is_connected() may be a coroutine method on the aio
        # connection classes - confirm. Both forms are handled so that an
        # un-awaited coroutine object is never mistaken for "connected".
        connected = cnx.is_connected()
        if asyncio.iscoroutine(connected):
            connected = await connected

        if connected and self._config_version == cnx.pool_config_version:
            return PooledMySQLConnection(self, cnx)

        # Stale or dead connection. The replacement is created directly
        # instead of going through the queue: add_connection() returns None,
        # and the queue is FIFO, so reading it back could hand out a different,
        # still unchecked connection.
        await self._close_quietly(cnx)
        try:
            async with self._pool_lock:
                fresh = await self._create_connection()
        except BaseException:
            # Keep the pool at its size: the closed connection goes back so a
            # later call can retry replacing it.
            if not self._cnx_queue.full():
                self._cnx_queue.put_nowait(cnx)
            raise

        return PooledMySQLConnection(self, fresh)

    async def _remove_connections(self) -> int:
        """Close all connections

        This method empties the queue and disconnects every connection in it.
        Connections currently handed out are not touched.

        Used mostly for tests.

        Returns:
            The number of connections that were disconnected without error.
        """
        async with self._pool_lock:
            closed = 0
            pending = self._cnx_queue
            while pending.qsize():
                try:
                    cnx = pending.get_nowait()
                    await cnx.disconnect()
                    closed += 1
                except asyncio.QueueEmpty:
                    return closed
                except PoolError:
                    raise
                except Error:
                    # Any other error when closing means connection is closed
                    pass

            return closed
# The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have included with
# MySQL.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA

"""Unittests for mysql.connector.aio.pooling"""

import asyncio
import unittest
import uuid

from queue import Queue

try:
    from mysql.connector.aio.connection_cext import CMySQLConnection
except ImportError:
    CMySQLConnection = None

import tests

from mysql.connector import errors
from mysql.connector.aio import pooling
from mysql.connector.aio.connection import MySQLConnection
from mysql.connector.constants import ClientFlag

MYSQL_CNX_CLASS = (
    MySQLConnection
    if CMySQLConnection is None
    else (MySQLConnection, CMySQLConnection)
)

# NOTE(review): the `async def test_*` methods below are only awaited when the
# test case derives from unittest.IsolatedAsyncioTestCase. If
# tests.MySQLConnectorTests is a plain unittest.TestCase, these tests return an
# un-awaited coroutine and "pass" without running - confirm the base class.


class PooledMySQLConnectionTests(tests.MySQLConnectorTests):
    """Tests for the `PooledMySQLConnection` wrapper."""

    def test___init__(self):
        dbconfig = tests.get_mysql_config()
        if tests.MYSQL_VERSION < (5, 7):
            dbconfig["client_flags"] = [-ClientFlag.CONNECT_ARGS]
        cnxpool = pooling.MySQLConnectionPool(pool_size=1, **dbconfig)
        self.assertRaises(TypeError, pooling.PooledMySQLConnection)
        cnx = MySQLConnection(**dbconfig)
        pcnx = pooling.PooledMySQLConnection(cnxpool, cnx)
        self.assertEqual(cnxpool, pcnx._cnx_pool)
        self.assertEqual(cnx, pcnx._cnx)

        self.assertRaises(AttributeError, pooling.PooledMySQLConnection, None, None)
        self.assertRaises(AttributeError, pooling.PooledMySQLConnection, cnxpool, None)

    def test___getattr__(self):
        dbconfig = tests.get_mysql_config()
        if tests.MYSQL_VERSION < (5, 7):
            dbconfig["client_flags"] = [-ClientFlag.CONNECT_ARGS]
        cnxpool = pooling.MySQLConnectionPool(pool_size=1, pool_name="test")
        cnx = MySQLConnection(**dbconfig)
        pcnx = pooling.PooledMySQLConnection(cnxpool, cnx)

        exp_attrs = {
            "_connection_timeout": dbconfig["connection_timeout"],
            "_database": dbconfig["database"],
            "_host": dbconfig["host"],
            "_password": dbconfig["password"],
            "_port": dbconfig["port"],
            "_unix_socket": dbconfig["unix_socket"],
        }
        for attr, value in exp_attrs.items():
            self.assertEqual(
                value,
                getattr(pcnx, attr),
                "Attribute {0} of reference connection not correct".format(attr),
            )

        self.assertEqual(pcnx.connect, cnx.connect)

    async def test_close(self):
        dbconfig = tests.get_mysql_config()
        if tests.MYSQL_VERSION < (5, 7):
            dbconfig["client_flags"] = [-ClientFlag.CONNECT_ARGS]
        cnxpool = pooling.MySQLConnectionPool(pool_size=1, **dbconfig)

        cnxpool._original_cnx = None

        # PooledMySQLConnection.close() awaits pool.add_connection(), so the
        # stand-in has to be a coroutine function as well.
        async def dummy_add_connection(self, cnx=None):
            self._original_cnx = cnx

        cnxpool.add_connection = dummy_add_connection.__get__(
            cnxpool, pooling.MySQLConnectionPool
        )

        # The aio connection classes do not connect on construction, and
        # close() resets the session when the pool asks for it, so the wrapped
        # connection must be opened first.
        cnx = MySQLConnection(**dbconfig)
        await cnx.connect()
        pcnx = pooling.PooledMySQLConnection(cnxpool, cnx)

        await pcnx.close()
        self.assertEqual(cnx, cnxpool._original_cnx)

        # The dummy pool never took ownership of the connection.
        await cnx.close()

    async def test_config(self):
        dbconfig = tests.get_mysql_config()
        cnxpool = pooling.MySQLConnectionPool(pool_size=1, **dbconfig)
        # The constructor opens no connection; the pool is filled by open().
        await cnxpool.open()
        cnx = await cnxpool.get_connection()

        self.assertRaises(errors.PoolError, cnx.config, user="spam")


class MySQLConnectionPoolTests(tests.MySQLConnectorTests):
    """Tests for `MySQLConnectionPool`."""

    async def test_open(self):
        dbconfig = tests.get_mysql_config()
        if tests.MYSQL_VERSION < (5, 7):
            dbconfig["client_flags"] = [-ClientFlag.CONNECT_ARGS]

        with self.assertRaises(errors.PoolError):
            pool = pooling.MySQLConnectionPool()
            await pool.open()

        with self.assertRaises(AttributeError):
            pool = pooling.MySQLConnectionPool(
                pool_name="test",
                pool_size=-1,
            )
            await pool.open()

        with self.assertRaises(AttributeError):
            pool = pooling.MySQLConnectionPool(
                pool_name="test",
                pool_size=0,
            )
            await pool.open()

        with self.assertRaises(AttributeError):
            pool = pooling.MySQLConnectionPool(
                pool_name="test",
                pool_size=(pooling.CNX_POOL_MAXSIZE + 1),
            )
            await pool.open()

        cnxpool = pooling.MySQLConnectionPool(pool_name="test")
        self.assertEqual(5, cnxpool._pool_size)
        self.assertEqual("test", cnxpool._pool_name)
        self.assertEqual({}, cnxpool._cnx_config)
        # The aio pool keeps its connections in an asyncio.Queue.
        self.assertTrue(isinstance(cnxpool._cnx_queue, asyncio.Queue))
        self.assertTrue(isinstance(cnxpool._config_version, uuid.UUID))
        self.assertTrue(cnxpool._reset_session)

        cnxpool = pooling.MySQLConnectionPool(pool_size=10, pool_name="test")
        self.assertEqual(10, cnxpool._pool_size)

        cnxpool = pooling.MySQLConnectionPool(pool_size=10, **dbconfig)
        # Configuration is stored and connections are created by open(), not
        # by the constructor.
        await cnxpool.open()
        self.assertEqual(
            dbconfig,
            cnxpool._cnx_config,
            "Connection configuration not saved correctly",
        )
        self.assertEqual(10, cnxpool._cnx_queue.qsize())
        self.assertTrue(isinstance(cnxpool._config_version, uuid.UUID))

        cnxpool = pooling.MySQLConnectionPool(
            pool_size=1, pool_name="test", pool_reset_session=False
        )
        self.assertFalse(cnxpool._reset_session)

    def test_pool_name(self):
        """Test MySQLConnectionPool.pool_name property"""
        pool_name = "ham"
        cnxpool = pooling.MySQLConnectionPool(pool_name=pool_name)
        self.assertEqual(pool_name, cnxpool.pool_name)

    def test_pool_size(self):
        """Test MySQLConnectionPool.pool_size property"""
        pool_size = 4
        cnxpool = pooling.MySQLConnectionPool(pool_name="test", pool_size=pool_size)
        self.assertEqual(pool_size, cnxpool.pool_size)

    def test_reset_session(self):
        """Test MySQLConnectionPool.reset_session property"""
        cnxpool = pooling.MySQLConnectionPool(
            pool_name="test", pool_reset_session=False
        )
        self.assertFalse(cnxpool.reset_session)
        cnxpool._reset_session = True
        self.assertTrue(cnxpool.reset_session)

    def test__set_pool_size(self):
        cnxpool = pooling.MySQLConnectionPool(pool_name="test")
        self.assertRaises(AttributeError, cnxpool._set_pool_size, -1)
        self.assertRaises(AttributeError, cnxpool._set_pool_size, 0)
        self.assertRaises(
            AttributeError,
            cnxpool._set_pool_size,
            pooling.CNX_POOL_MAXSIZE + 1,
        )

        cnxpool._set_pool_size(pooling.CNX_POOL_MAXSIZE - 1)
        self.assertEqual(pooling.CNX_POOL_MAXSIZE - 1, cnxpool._pool_size)

    def test__set_pool_name(self):
        cnxpool = pooling.MySQLConnectionPool(pool_name="test")

        self.assertRaises(AttributeError, cnxpool._set_pool_name, "pool name")
        self.assertRaises(AttributeError, cnxpool._set_pool_name, "pool%%name")
        self.assertRaises(
            AttributeError,
            cnxpool._set_pool_name,
            "long_pool_name" * pooling.CNX_POOL_MAXNAMESIZE,
        )

    async def test_add_connection(self):
        cnxpool = pooling.MySQLConnectionPool(pool_name="test")
        # add_connection() is a coroutine function: the error only surfaces
        # when the coroutine is awaited.
        with self.assertRaises(errors.PoolError):
            await cnxpool.add_connection()

        dbconfig = tests.get_mysql_config()
        if tests.MYSQL_VERSION < (5, 7):
            dbconfig["client_flags"] = [-ClientFlag.CONNECT_ARGS]
        cnxpool = pooling.MySQLConnectionPool(pool_size=2, pool_name="test")
        # open() is deliberately not called: without a configuration it raises
        # PoolError, and with one it would fill the pool, leaving no room for
        # the add_connection() calls exercised below.
        await cnxpool.set_config(**dbconfig)

        await cnxpool.add_connection()
        pcnx = pooling.PooledMySQLConnection(
            cnxpool, cnxpool._cnx_queue.get_nowait()
        )
        self.assertTrue(isinstance(pcnx._cnx, MYSQL_CNX_CLASS))
        self.assertEqual(cnxpool, pcnx._cnx_pool)
        self.assertEqual(cnxpool._config_version, pcnx._cnx._pool_config_version)

        cnx = pcnx._cnx
        await pcnx.close()
        # We should get the same connection back
        self.assertEqual(cnx, cnxpool._cnx_queue.get_nowait())
        await cnxpool.add_connection(cnx)

        # reach max connections
        await cnxpool.add_connection()
        with self.assertRaises(errors.PoolError):
            await cnxpool.add_connection()

        # fail connecting
        await cnxpool._remove_connections()
        cnxpool._cnx_config["port"] = 9999999
        cnxpool._cnx_config["unix_socket"] = "/ham/spam/foobar.socket"
        with self.assertRaises(errors.Error):
            await cnxpool.add_connection()

        with self.assertRaises(errors.PoolError):
            await cnxpool.add_connection(cnx=str)

    async def test_set_config(self):
        dbconfig = tests.get_mysql_config()
        if tests.MYSQL_VERSION < (5, 7):
            dbconfig["client_flags"] = [-ClientFlag.CONNECT_ARGS]
        # open() is not called here: the pool has no configuration yet, so
        # filling it would raise PoolError before set_config() is exercised.
        cnxpool = pooling.MySQLConnectionPool(pool_name="test")

        # No configuration changes
        config_version = cnxpool._config_version
        await cnxpool.set_config()
        self.assertEqual(config_version, cnxpool._config_version)
        self.assertEqual({}, cnxpool._cnx_config)

        # Valid configuration changes
        config_version = cnxpool._config_version
        await cnxpool.set_config(**dbconfig)
        self.assertEqual(dbconfig, cnxpool._cnx_config)
        self.assertNotEqual(config_version, cnxpool._config_version)

        # Invalid configuration changes
        config_version = cnxpool._config_version
        wrong_dbconfig = dbconfig.copy()
        wrong_dbconfig["spam"] = "ham"
        with self.assertRaises(errors.PoolError):
            await cnxpool.set_config(**wrong_dbconfig)
        self.assertEqual(dbconfig, cnxpool._cnx_config)
        self.assertEqual(config_version, cnxpool._config_version)

    async def test_get_connection(self):
        dbconfig = tests.get_mysql_config()
        if tests.MYSQL_VERSION < (5, 7):
            dbconfig["client_flags"] = [-ClientFlag.CONNECT_ARGS]

        # A pool that was never filled has nothing to hand out.
        cnxpool = pooling.MySQLConnectionPool(pool_size=2, pool_name="test")
        with self.assertRaises(errors.PoolError):
            await cnxpool.get_connection()

        cnxpool = pooling.MySQLConnectionPool(pool_size=1, **dbconfig)
        await cnxpool.open()

        # Get connection from pool
        pcnx = await cnxpool.get_connection()
        self.assertTrue(isinstance(pcnx, pooling.PooledMySQLConnection))
        with self.assertRaises(errors.PoolError):
            await cnxpool.get_connection()
        self.assertEqual(pcnx._cnx._pool_config_version, cnxpool._config_version)
        prev_config_version = pcnx._pool_config_version
        prev_thread_id = pcnx.connection_id
        await pcnx.close()

        # Change configuration. set_config() replaces the stored configuration
        # as a whole, so the credentials have to be passed again.
        config_version = cnxpool._config_version
        await cnxpool.set_config(**{**dbconfig, "autocommit": True})
        self.assertNotEqual(config_version, cnxpool._config_version)

        pcnx = await cnxpool.get_connection()
        self.assertNotEqual(pcnx._cnx._pool_config_version, prev_config_version)
        self.assertNotEqual(prev_thread_id, pcnx.connection_id)
        self.assertEqual(1, pcnx.autocommit)
        await pcnx.close()

        # Get connection from pool using a context manager
        # NOTE(review): assumes PooledMySQLConnection implements the async
        # context-manager protocol and yields itself - confirm.
        async with await cnxpool.get_connection() as pcnx:
            self.assertTrue(isinstance(pcnx, pooling.PooledMySQLConnection))

    async def test__remove_connections(self):
        dbconfig = tests.get_mysql_config()
        if tests.MYSQL_VERSION < (5, 7):
            dbconfig["client_flags"] = [-ClientFlag.CONNECT_ARGS]
        cnxpool = pooling.MySQLConnectionPool(pool_size=2, pool_name="test", **dbconfig)
        await cnxpool.open()
        pcnx = await cnxpool.get_connection()
        self.assertEqual(1, await cnxpool._remove_connections())
        await pcnx.close()
        self.assertEqual(1, await cnxpool._remove_connections())
        self.assertEqual(0, await cnxpool._remove_connections())

        with self.assertRaises(errors.PoolError):
            await cnxpool.get_connection()