import org.json.JSONObject;

import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.sql.*;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.*;
import java.util.concurrent.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/**
 * A utility for exercising prepared-statement behavior in MariaDB Connector/J
 * under different connection parameters.
 *
 * <p>Requires a reachable MariaDB server, the Connector/J JAR on the classpath,
 * and the org.json library (used to serialize the monitoring samples).
 *
 * <p>The {@code connector.versions} setting is only a label written to the console
 * and the log: the driver actually used is whichever Connector/J JAR is on the
 * classpath. To compare versions (e.g. 2.7.5 and 3.4.1), run the tool once per
 * version with only that version's JAR on the classpath.
 *
 * <p>Configuration is read from {@code config.properties} in the working directory.
 * Results are always printed to the console and can optionally be logged to a file.
 */
public class MariaDBPreparedStmtTest {

    // --- Configuration Variables (loaded from config.properties by loadConfig()) ---

    /** JDBC URL ({@code db.url}). */
    private static String DB_URL;
    /** Database user ({@code db.user}). */
    private static String DB_USER;
    /** Database password ({@code db.password}); empty when the property is absent. */
    private static String DB_PASSWORD;
    /** Worker threads per test case, each opening its own connection ({@code test.threads}). */
    private static int NUM_THREADS;
    /** Iterations per worker; each iteration runs every entry of {@code testTypes} ({@code test.statements.per.thread}). */
    private static int STATEMENTS_PER_THREAD;
    /** Labels for the connector versions under test ({@code connector.versions}); not used to select a driver. */
    private static String[] connectorVersions;
    /** Maximum wait for one test case's workers before forcing shutdown ({@code test.duration.seconds}). */
    private static int TEST_DURATION_SECONDS;
    /** Global status variables sampled by the monitor thread ({@code monitor.status.variables}). */
    private static String[] monitorStatusVariables;
    /** Connection-property name to candidate values ({@code test.config.options}), in config order. */
    private static Map<String, List<String>> testConfigOptions;
    /** Every combination of {@code testConfigOptions} values; one test case per entry and connector version. */
    private static List<Map<String, String>> testMatrix;
    /** Whether results are also written to a log file ({@code log.to.file}). */
    private static boolean LOG_TO_FILE;
    /** Whether the log file name gets a timestamp suffix ({@code log.filename.timestamp}). */
    private static boolean LOG_FILENAME_TIMESTAMP;
    /** Statement kinds each worker executes per iteration, in order ({@code test.types}). */
    private static String[] testTypes;
    /** Use {@link PreparedStatement} rather than string-built SQL ({@code use.prepared.statements}). */
    private static boolean USE_PREPARED_STATEMENTS;
    /** Run workers with auto-commit off and explicit commits ({@code use.transactions}). */
    private static boolean USE_TRANSACTIONS;
    /** Iterations per commit when transactions are on; 0 commits only once at the end ({@code transaction.size}). */
    private static int TRANSACTION_SIZE;
    /** Milliseconds each worker sleeps after finishing its statements ({@code thread.sleep.time}). */
    private static int THREAD_SLEEP_TIME;
    /** Milliseconds between monitor samples ({@code monitor.sleep.time}). */
    private static int MONITOR_SLEEP_TIME;

    private static final String LOG_FILE_PREFIX = "test_results";
    private static final String LOG_FILE_EXTENSION = ".log";

    /** Statement kinds a worker knows how to execute; anything else in {@code test.types} is rejected. */
    private static final Set<String> SUPPORTED_TEST_TYPES = Collections.unmodifiableSet(
            new HashSet<>(Arrays.asList("insert", "select", "update", "delete")));

    /** One {@code key=[v1,v2,...]} entry of {@code test.config.options}. */
    private static final Pattern OPTION_ENTRY =
            Pattern.compile("([^=\\[\\],\\s]+)\\s*=\\s*\\[([^\\]]*)\\]");

    /**
     * Allowed status-variable names. They are spliced into a {@code SHOW GLOBAL STATUS LIKE '...'}
     * literal, so quotes and backslashes must be excluded; {@code %} and {@code _} are LIKE wildcards.
     */
    private static final Pattern STATUS_VARIABLE_NAME = Pattern.compile("[A-Za-z0-9_%]+");

    /** Extra time granted to workers after a forced shutdown before the test case moves on. */
    private static final int SHUTDOWN_GRACE_SECONDS = 10;

    // --- SQL Queries ---
    private static final String CREATE_TABLE_SQL = "CREATE TABLE IF NOT EXISTS test_table (id INT PRIMARY KEY, name VARCHAR(50))";
    private static final String DROP_TABLE_SQL = "DROP TABLE IF EXISTS test_table";
    private static final String INSERT_SQL = "INSERT INTO test_table (id, name) VALUES (?, ?)";
    private static final String SELECT_SQL = "SELECT name FROM test_table WHERE id = ?";
    private static final String UPDATE_SQL = "UPDATE test_table SET name = ? WHERE id = ?";
    private static final String DELETE_SQL = "DELETE FROM test_table WHERE id = ?";

    /**
     * Orchestrates the test run: loads the configuration, builds the test matrix and runs
     * one multi-threaded test case per (connector version label, option combination) pair.
     *
     * @param args Command line arguments (not used)
     */
    public static void main(String[] args) {
        // Load configuration from file. Validation problems surface as IllegalArgumentException.
        try {
            loadConfig();
        } catch (IOException | IllegalArgumentException e) {
            System.err.println("Error loading configuration file: " + e.getMessage());
            return;
        }

        // Print all configuration variables as a header.
        printConfigHeader();

        // Generate the test matrix dynamically from the config file.
        generateTestMatrix();

        System.out.println("Starting multi-threaded prepared statement test...");
        if (LOG_TO_FILE) {
            System.out.println("Results will be logged to a file.");
        } else {
            System.out.println("File logging is disabled. Results will only be printed to the console.");
        }
        System.out.println("-----------------------------------------------------------------");

        String logFileName = LOG_FILE_PREFIX + (LOG_FILENAME_TIMESTAMP ? "_" + LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyyMMdd_HHmmss")) : "") + LOG_FILE_EXTENSION;

        // A null resource is legal in try-with-resources; close() is simply skipped.
        try (PrintWriter logWriter = LOG_TO_FILE ? new PrintWriter(new FileWriter(logFileName)) : null) {
            if (logWriter != null) {
                writeLogHeader(logWriter);
            }

            int totalTests = connectorVersions.length * testMatrix.size();
            int currentTestNumber = 1;

            for (String version : connectorVersions) {
                for (Map<String, String> combination : testMatrix) {
                    // Credentials first, then the option combination (which may override them).
                    Properties props = new Properties();
                    props.setProperty("user", DB_USER);
                    props.setProperty("password", DB_PASSWORD);
                    combination.forEach(props::setProperty);

                    if (logWriter != null) {
                        logWriter.printf(
                            "\n# Test Case Start\n" +
                            "connector_version=%s\n" +
                            "test_parameters=%s\n",
                            version,
                            formatCombination(combination, ",")
                        );
                    }

                    System.out.printf(
                        "\nTest [%d/%d] -> %s (Version: %s)\n",
                        currentTestNumber,
                        totalTests,
                        formatCombination(combination, ", "),
                        version
                    );

                    runMultiThreadedTest(props, Optional.ofNullable(logWriter));
                    if (logWriter != null) {
                        logWriter.println("# Test Case End");
                    }
                    currentTestNumber++;
                }
            }

            // PrintWriter swallows write failures; surface them instead of silently losing results.
            if (logWriter != null && logWriter.checkError()) {
                System.err.println("Error writing to log file: " + logFileName);
            }
        } catch (IOException e) {
            System.err.println("Error writing to log file: " + e.getMessage());
        }

        System.out.println("\n-----------------------------------------------------------------");
        System.out.println("Multi-threaded prepared statement test finished.");
    }

    /**
     * Writes the commented configuration header at the top of the log file.
     *
     * @param logWriter destination log writer
     */
    private static void writeLogHeader(PrintWriter logWriter) {
        logWriter.println("# --- TEST CONFIGURATION START ---");
        logWriter.printf("# Database URL: %s\n", DB_URL);
        logWriter.printf("# Database User: %s\n", DB_USER);
        logWriter.printf("# Threads: %d\n", NUM_THREADS);
        logWriter.printf("# Statements Per Thread: %d\n", STATEMENTS_PER_THREAD);
        logWriter.printf("# Test Duration: %d seconds\n", TEST_DURATION_SECONDS);
        logWriter.printf("# Monitored Variables: %s\n", String.join(", ", monitorStatusVariables));
        logWriter.printf("# Connector Versions: %s\n", String.join(", ", connectorVersions));
        logWriter.printf("# Test Types: %s\n", String.join(", ", testTypes));
        logWriter.printf("# Use Prepared Statements: %b\n", USE_PREPARED_STATEMENTS);
        logWriter.printf("# Use Transactions: %b\n", USE_TRANSACTIONS);
        logWriter.printf("# Transaction Size: %d\n", TRANSACTION_SIZE);
        logWriter.printf("# Thread Sleep Time: %dms\n", THREAD_SLEEP_TIME);
        logWriter.printf("# Monitor Sleep Time: %dms\n", MONITOR_SLEEP_TIME);
        logWriter.println("# --- TEST CONFIGURATION END ---");
    }

    /**
     * Renders an option combination as {@code key=value} pairs.
     *
     * @param combination option name to value
     * @param delimiter   separator placed between pairs
     * @return the joined pairs, in the map's iteration order
     */
    private static String formatCombination(Map<String, String> combination, String delimiter) {
        return combination.entrySet().stream()
                .map(e -> e.getKey() + "=" + e.getValue())
                .collect(Collectors.joining(delimiter));
    }

    /**
     * Executes one multi-threaded test case for a given set of connection properties.
     *
     * <p>The test table is recreated on a setup connection, a monitor thread samples the
     * configured global status variables, and {@code NUM_THREADS} workers each run their
     * statements on a dedicated connection. Workers get at most {@code TEST_DURATION_SECONDS}
     * before being interrupted. The table is dropped afterwards on a fresh connection.
     *
     * @param props     The connection properties for the test.
     * @param logWriter The PrintWriter to write monitor samples to, if file logging is enabled.
     */
    private static void runMultiThreadedTest(Properties props, Optional<PrintWriter> logWriter) {
        ExecutorService executor = Executors.newFixedThreadPool(NUM_THREADS);
        // Counted down once the workers are done; tells the monitor thread to stop.
        CountDownLatch latch = new CountDownLatch(1);
        Thread monitorThread = new Thread(() -> monitorStatus(latch, logWriter), "status-monitor");

        // The setup connection stays open (idle) for the duration of the test case.
        try (Connection conn = DriverManager.getConnection(DB_URL, props)) {
            try (Statement stmt = conn.createStatement()) {
                System.out.println("  Main thread: Dropping and creating test table...");
                stmt.execute(DROP_TABLE_SQL);
                stmt.execute(CREATE_TABLE_SQL);
            }

            monitorThread.start();

            // Each worker gets a unique index so the ids it touches never overlap another worker's.
            for (int i = 0; i < NUM_THREADS; i++) {
                final int threadIndex = i;
                executor.execute(() -> runWorker(props, threadIndex));
            }

            executor.shutdown();
            awaitWorkers(executor);
        } catch (SQLException e) {
            System.err.println("  Initial setup failed: " + e.getMessage());
        } finally {
            // No-op once the pool has terminated; releases it if setup failed before shutdown().
            executor.shutdownNow();
            latch.countDown();
            try {
                // Returns immediately if the monitor was never started (setup failure).
                monitorThread.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
            try (Connection cleanupConn = DriverManager.getConnection(DB_URL, props);
                 Statement cleanupStmt = cleanupConn.createStatement()) {
                System.out.println("  Main thread: Cleaning up test table...");
                cleanupStmt.execute(DROP_TABLE_SQL);
            } catch (SQLException e) {
                System.err.println("  Cleanup failed: " + e.getMessage());
            }
        }
    }

    /**
     * Waits for an already shut-down pool. On timeout the workers are interrupted and given a
     * bounded grace period, so they do not keep running into this test case's cleanup or the
     * next test case.
     *
     * @param executor pool on which {@code shutdown()} has already been called
     */
    private static void awaitWorkers(ExecutorService executor) {
        try {
            if (!executor.awaitTermination(TEST_DURATION_SECONDS, TimeUnit.SECONDS)) {
                System.out.println("  Test timed out. Forcing shutdown.");
                executor.shutdownNow();
                // Threads blocked inside JDBC calls may not react to the interrupt promptly.
                if (!executor.awaitTermination(SHUTDOWN_GRACE_SECONDS, TimeUnit.SECONDS)) {
                    System.err.println("  Some worker threads are still running after forced shutdown.");
                }
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Body of one worker thread: opens a dedicated connection and runs
     * {@code STATEMENTS_PER_THREAD} iterations of every configured test type.
     *
     * <p>Failures are reported on stderr instead of being lost in the executor.
     *
     * @param props       connection properties for this test case
     * @param threadIndex zero-based worker index, used to derive ids unique to this worker
     */
    private static void runWorker(Properties props, int threadIndex) {
        try (Connection workerConn = DriverManager.getConnection(DB_URL, props)) {
            workerConn.setAutoCommit(!USE_TRANSACTIONS);

            for (int j = 0; j < STATEMENTS_PER_THREAD; j++) {
                int uniqueId = (threadIndex * STATEMENTS_PER_THREAD) + j;
                for (String type : testTypes) {
                    executeTestStatement(workerConn, type, uniqueId);
                }
                // TRANSACTION_SIZE == 0 means "commit only once at the end" (and avoids % 0).
                if (USE_TRANSACTIONS && TRANSACTION_SIZE > 0 && (j + 1) % TRANSACTION_SIZE == 0) {
                    workerConn.commit();
                }
            }
            if (USE_TRANSACTIONS) {
                workerConn.commit();
            }
        } catch (SQLException e) {
            System.err.println("  Thread " + Thread.currentThread().getId() + " failed: " + e.getMessage());
        } catch (RuntimeException e) {
            System.err.println("  Thread " + Thread.currentThread().getId() + " failed: " + e);
        }

        try {
            Thread.sleep(THREAD_SLEEP_TIME);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Dispatches one statement of the given kind for the given row id.
     *
     * @param conn worker connection
     * @param type one of {@code SUPPORTED_TEST_TYPES}
     * @param id   row id to operate on
     * @throws SQLException             if the statement fails
     * @throws IllegalArgumentException if {@code type} is not supported
     */
    private static void executeTestStatement(Connection conn, String type, int id) throws SQLException {
        switch (type) {
            case "insert":
                executeInsert(conn, id, "name_" + id);
                break;
            case "select":
                executeSelect(conn, id);
                break;
            case "update":
                executeUpdate(conn, id, "updated_name_" + id);
                break;
            case "delete":
                executeDelete(conn, id);
                break;
            default:
                // Unreachable for validated configs; fail loudly rather than silently skipping.
                throw new IllegalArgumentException("Unsupported test type: " + type);
        }
    }

    /**
     * Body of the monitor thread: samples the configured global status variables every
     * {@code MONITOR_SLEEP_TIME} ms until {@code stopSignal} is counted down. It prints one dot
     * per sample and, if a log writer is present, writes a {@code data=<json>} line.
     *
     * @param stopSignal latch counted down when the test case finishes
     * @param logWriter  optional destination for the samples
     */
    private static void monitorStatus(CountDownLatch stopSignal, Optional<PrintWriter> logWriter) {
        long startTime = System.currentTimeMillis();
        try {
            do {
                Map<String, Double> statusData = new HashMap<>();
                statusData.put("timestamp", (System.currentTimeMillis() - startTime) / 1000.0);
                for (String var : monitorStatusVariables) {
                    statusData.put(var, (double) getGlobalStatusVariable(var));
                }

                logWriter.ifPresent(w -> w.printf("data=%s\n", new JSONObject(statusData).toString()));
                System.out.print(".");
                // Waiting on the latch (instead of sleeping) lets the monitor stop as soon as
                // the test case ends, rather than up to one full interval later.
            } while (!stopSignal.await(MONITOR_SLEEP_TIME, TimeUnit.MILLISECONDS));
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Inserts one row, either through a {@link PreparedStatement} or through string-built SQL,
     * depending on {@code USE_PREPARED_STATEMENTS}. The string-built path exists so the two
     * modes can be compared; its inputs are generated internally.
     */
    private static void executeInsert(Connection conn, int id, String name) throws SQLException {
        if (USE_PREPARED_STATEMENTS) {
            try (PreparedStatement ps = conn.prepareStatement(INSERT_SQL)) {
                ps.setInt(1, id);
                ps.setString(2, name);
                ps.executeUpdate();
            }
        } else {
            try (Statement stmt = conn.createStatement()) {
                String sql = String.format("INSERT INTO test_table (id, name) VALUES (%d, '%s')", id, name);
                stmt.executeUpdate(sql);
            }
        }
    }

    /**
     * Selects one row by id (prepared or string-built, per {@code USE_PREPARED_STATEMENTS}).
     * The result is not read; the result set is closed together with its statement.
     */
    private static void executeSelect(Connection conn, int id) throws SQLException {
        if (USE_PREPARED_STATEMENTS) {
            try (PreparedStatement ps = conn.prepareStatement(SELECT_SQL)) {
                ps.setInt(1, id);
                ps.executeQuery();
            }
        } else {
            try (Statement stmt = conn.createStatement()) {
                String sql = String.format("SELECT name FROM test_table WHERE id = %d", id);
                stmt.executeQuery(sql);
            }
        }
    }

    /** Updates one row's name by id (prepared or string-built, per {@code USE_PREPARED_STATEMENTS}). */
    private static void executeUpdate(Connection conn, int id, String name) throws SQLException {
        if (USE_PREPARED_STATEMENTS) {
            try (PreparedStatement ps = conn.prepareStatement(UPDATE_SQL)) {
                ps.setString(1, name);
                ps.setInt(2, id);
                ps.executeUpdate();
            }
        } else {
            try (Statement stmt = conn.createStatement()) {
                String sql = String.format("UPDATE test_table SET name = '%s' WHERE id = %d", name, id);
                stmt.executeUpdate(sql);
            }
        }
    }

    /** Deletes one row by id (prepared or string-built, per {@code USE_PREPARED_STATEMENTS}). */
    private static void executeDelete(Connection conn, int id) throws SQLException {
        if (USE_PREPARED_STATEMENTS) {
            try (PreparedStatement ps = conn.prepareStatement(DELETE_SQL)) {
                ps.setInt(1, id);
                ps.executeUpdate();
            }
        } else {
            try (Statement stmt = conn.createStatement()) {
                String sql = String.format("DELETE FROM test_table WHERE id = %d", id);
                stmt.executeUpdate(sql);
            }
        }
    }

    /**
     * Loads and validates configuration variables from {@code config.properties}.
     *
     * @throws IOException              if the file cannot be read
     * @throws IllegalArgumentException if a required property is missing or a value is invalid
     */
    private static void loadConfig() throws IOException {
        Properties prop = new Properties();
        try (FileInputStream fis = new FileInputStream("config.properties")) {
            prop.load(fis);
        }

        DB_URL = requireProperty(prop, "db.url");
        DB_USER = requireProperty(prop, "db.user");
        // Properties.setProperty() rejects null, so an absent password becomes "".
        DB_PASSWORD = prop.getProperty("db.password", "");
        NUM_THREADS = parseIntProperty(prop, "test.threads", null, 1);
        STATEMENTS_PER_THREAD = parseIntProperty(prop, "test.statements.per.thread", null, 0);
        connectorVersions = splitList(requireProperty(prop, "connector.versions"));
        if (connectorVersions.length == 0) {
            throw new IllegalArgumentException("Property 'connector.versions' lists no versions");
        }
        TEST_DURATION_SECONDS = parseIntProperty(prop, "test.duration.seconds", null, 0);
        monitorStatusVariables = splitList(requireProperty(prop, "monitor.status.variables"));
        for (String var : monitorStatusVariables) {
            if (!STATUS_VARIABLE_NAME.matcher(var).matches()) {
                throw new IllegalArgumentException(
                        "Invalid status variable name in 'monitor.status.variables': " + var);
            }
        }
        LOG_TO_FILE = Boolean.parseBoolean(prop.getProperty("log.to.file"));
        LOG_FILENAME_TIMESTAMP = Boolean.parseBoolean(prop.getProperty("log.filename.timestamp"));
        // Trim + normalize case so "select, Insert" is not silently reduced to just "select".
        testTypes = Arrays.stream(splitList(prop.getProperty("test.types", "select,insert")))
                .map(type -> type.toLowerCase(Locale.ROOT))
                .toArray(String[]::new);
        for (String type : testTypes) {
            if (!SUPPORTED_TEST_TYPES.contains(type)) {
                throw new IllegalArgumentException("Unsupported entry in 'test.types': " + type
                        + " (supported: insert, select, update, delete)");
            }
        }
        USE_PREPARED_STATEMENTS = Boolean.parseBoolean(prop.getProperty("use.prepared.statements"));
        USE_TRANSACTIONS = Boolean.parseBoolean(prop.getProperty("use.transactions"));
        TRANSACTION_SIZE = parseIntProperty(prop, "transaction.size", "10", 0);
        THREAD_SLEEP_TIME = parseIntProperty(prop, "thread.sleep.time", "50", 0);
        MONITOR_SLEEP_TIME = parseIntProperty(prop, "monitor.sleep.time", "1000", 0);

        testConfigOptions = parseTestConfigOptions(prop.getProperty("test.config.options"));
    }

    /**
     * Returns the trimmed value of a property that must be present and non-blank.
     *
     * @throws IllegalArgumentException if the property is missing or blank
     */
    private static String requireProperty(Properties prop, String key) {
        String value = prop.getProperty(key);
        if (value == null || value.trim().isEmpty()) {
            throw new IllegalArgumentException("Missing required property '" + key + "'");
        }
        return value.trim();
    }

    /**
     * Parses an integer property.
     *
     * @param defaultValue value used when the key is absent, or {@code null} if the key is required
     * @param min          smallest accepted value
     * @throws IllegalArgumentException if the property is missing (when required), not an
     *                                  integer, or below {@code min}
     */
    private static int parseIntProperty(Properties prop, String key, String defaultValue, int min) {
        String raw = defaultValue == null
                ? requireProperty(prop, key)
                : prop.getProperty(key, defaultValue).trim();
        int value;
        try {
            value = Integer.parseInt(raw);
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException(
                    "Property '" + key + "' must be an integer but was '" + raw + "'", e);
        }
        if (value < min) {
            throw new IllegalArgumentException(
                    "Property '" + key + "' must be >= " + min + " but was " + value);
        }
        return value;
    }

    /** Splits a comma-separated list, trimming entries and dropping empty ones. */
    private static String[] splitList(String raw) {
        return Arrays.stream(raw.split(","))
                .map(String::trim)
                .filter(s -> !s.isEmpty())
                .toArray(String[]::new);
    }

    /**
     * Parses the test configuration options from a string.
     * E.g., "key1=[val1,val2],key2=[val3,val4]". Whitespace around keys, values, brackets and
     * separators is tolerated, and empty values are dropped.
     *
     * @param optionsString The raw string from the config file; {@code null} or blank means no options.
     * @return A map of option names to a list of their values, in config order.
     * @throws IllegalArgumentException if the string is malformed or an option has no values
     */
    private static Map<String, List<String>> parseTestConfigOptions(String optionsString) {
        Map<String, List<String>> options = new LinkedHashMap<>();
        if (optionsString == null || optionsString.trim().isEmpty()) {
            return options;
        }

        Matcher matcher = OPTION_ENTRY.matcher(optionsString);
        int consumed = 0;
        while (matcher.find()) {
            // Anything between two entries other than commas/whitespace is a syntax error.
            requireSeparator(optionsString, consumed, matcher.start());
            String key = matcher.group(1);
            List<String> values = Arrays.stream(matcher.group(2).split(","))
                    .map(String::trim)
                    .filter(v -> !v.isEmpty())
                    .collect(Collectors.toList());
            if (values.isEmpty()) {
                throw new IllegalArgumentException("test.config.options: no values given for '" + key + "'");
            }
            options.put(key, values);
            consumed = matcher.end();
        }
        requireSeparator(optionsString, consumed, optionsString.length());
        return options;
    }

    /**
     * Checks that {@code text[from, to)} consists only of commas and whitespace.
     *
     * @throws IllegalArgumentException otherwise
     */
    private static void requireSeparator(String text, int from, int to) {
        String gap = text.substring(from, to);
        for (char c : gap.toCharArray()) {
            if (c != ',' && !Character.isWhitespace(c)) {
                throw new IllegalArgumentException("Malformed test.config.options near '" + gap.trim()
                        + "' (expected key=[v1,v2,...] entries separated by commas)");
            }
        }
    }

    /**
     * Generates the test matrix: the cartesian product of all option values. With no options
     * configured this yields a single empty combination (one test case per connector version).
     */
    private static void generateTestMatrix() {
        testMatrix = new ArrayList<>();
        List<String> keys = new ArrayList<>(testConfigOptions.keySet());
        // LinkedHashMap keeps each combination's keys in config order for readable output.
        generateCombinations(keys, 0, new LinkedHashMap<>());
    }

    /**
     * Recursive helper: fixes a value for {@code keys[keyIndex]} and recurses; at the end of the
     * key list the current combination is copied into {@code testMatrix}.
     */
    private static void generateCombinations(List<String> keys, int keyIndex, Map<String, String> currentCombination) {
        if (keyIndex == keys.size()) {
            testMatrix.add(new LinkedHashMap<>(currentCombination));
            return;
        }

        String currentKey = keys.get(keyIndex);
        for (String value : testConfigOptions.get(currentKey)) {
            currentCombination.put(currentKey, value);
            generateCombinations(keys, keyIndex + 1, currentCombination);
            currentCombination.remove(currentKey);
        }
    }

    /**
     * Prints all the configuration variables at the beginning of the test (password excluded).
     */
    private static void printConfigHeader() {
        System.out.println("Configuration Loaded:");
        System.out.println("---------------------------------");
        System.out.println("  Database URL: " + DB_URL);
        System.out.println("  Database User: " + DB_USER);
        System.out.println("  Test Duration: " + TEST_DURATION_SECONDS + " seconds");
        System.out.println("  Threads: " + NUM_THREADS);
        System.out.println("  Statements Per Thread: " + STATEMENTS_PER_THREAD);
        System.out.print("  Connector Versions: ");
        for (String version : connectorVersions) {
            System.out.print(version + " ");
        }
        System.out.println();
        System.out.print("  Monitored Status Variables: ");
        for (String var : monitorStatusVariables) {
            System.out.print(var + " ");
        }
        System.out.println();
        System.out.println("  Test Types: " + String.join(", ", testTypes));
        System.out.println("  Use Prepared Statements: " + USE_PREPARED_STATEMENTS);
        System.out.println("  Use Transactions: " + USE_TRANSACTIONS);
        System.out.println("  Transaction Size: " + TRANSACTION_SIZE);
        System.out.println("  Thread Sleep Time: " + THREAD_SLEEP_TIME + "ms");
        System.out.println("  Monitor Sleep Time: " + MONITOR_SLEEP_TIME + "ms");
        System.out.println("  Log to File: " + LOG_TO_FILE);
        System.out.println("  Log Filename Timestamp: " + LOG_FILENAME_TIMESTAMP);
        System.out.println("---------------------------------");
    }

    /**
     * Queries the MariaDB server for the value of a global status variable.
     * This method opens a separate, temporary connection (without the test-case properties)
     * for every call.
     *
     * @param variableName The name (or LIKE pattern) of the global status variable to query;
     *                     validated against {@code STATUS_VARIABLE_NAME} by loadConfig().
     * @return The value of the first matching status row, or -1 if there is no match, the value
     *         is not numeric, or an error occurs.
     */
    private static long getGlobalStatusVariable(String variableName) {
        // A plain Statement is used deliberately: a server-side prepare here could itself show up
        // in the prepared-statement counters this tool monitors. The name has been validated, so
        // it cannot break out of the quoted literal.
        try (Connection conn = DriverManager.getConnection(DB_URL, DB_USER, DB_PASSWORD);
             Statement stmt = conn.createStatement();
             ResultSet rs = stmt.executeQuery("SHOW GLOBAL STATUS LIKE '" + variableName + "'")) {
            if (rs.next()) {
                return rs.getLong("Value");
            }
        } catch (SQLException e) {
            // Best effort: a failed sample is recorded as -1 instead of stopping the monitor.
            return -1;
        }
        return -1;
    }
}

