#!/usr/bin/env python3

# obligatory
import os
import sys
import unittest

# standard library
import decimal

# third-party
import mariadb


class TestDatabase(unittest.TestCase):
    """Round-trip numeric query parameters through a MariaDB server."""

    def test_decimal(self):
        """Check that int and Decimal parameters bind in every placeholder style.

        For each parameter type (``int`` and ``decimal.Decimal``), the value
        ``1`` is bound using the qmark (``?``), format (``%s``) and pyformat
        (``%(name)s``) placeholder styles. It is bound both on its own and
        next to a second ``int`` parameter. The server echoes the parameters
        back with ``SELECT``. Each case asserts that exactly one row with the
        expected number of columns comes back, and that the first column
        equals 1.

        Connection settings come from the option file ``~/.my.cnf`` in the
        user's home directory.
        """
        # expanduser only expands a leading '~'; without it the path would be
        # resolved relative to the current working directory.
        option_file = os.path.expanduser('~/.my.cnf')
        with mariadb.connect(default_file=option_file) as connection:
            with connection.cursor() as cursor:
                for parameter_type in (int, decimal.Decimal):
                    argument = parameter_type(1)
                    # (parameter_count, parameter_style, query, parameters)
                    cases = (
                        (1, '?', 'select ?', [argument]),
                        (1, '%s', 'select %s', [argument]),
                        (1, '%(name)s', 'select %(value)s',
                         dict(value=argument)),
                        (2, '?', 'select ?, ?', [argument, 1]),
                        (2, '%s', 'select %s, %s', [argument, 1]),
                        (2, '%(name)s', 'select %(value)s, %(dummy)s',
                         dict(value=argument, dummy=1)),
                    )
                    for parameter_count, parameter_style, query, parameters in cases:
                        with self.subTest(parameter_type=parameter_type,
                                          parameter_count=parameter_count,
                                          parameter_style=parameter_style):
                            cursor.execute(query, parameters)
                            rows = cursor.fetchall()
                            # Report a wrong result shape as a test failure
                            # rather than an unpacking error.
                            self.assertEqual(len(rows), 1)
                            [row] = rows
                            self.assertEqual(len(row), parameter_count)
                            # The parameter under test is always the first
                            # column; Decimal('1') == 1 holds, so one check
                            # covers both types.
                            self.assertEqual(row[0], 1)


if __name__ == '__main__':
    # unittest.main() exits the interpreter itself (exit=True by default) with
    # a status reflecting the test outcome; it returns a TestProgram, not an
    # exit code, so it must not be wrapped in sys.exit().
    unittest.main()

