/*
    foreignkeys2.c - test for bug 4518

    This file is in the public domain.
*/

#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <sql.h>
#include <sqlext.h>


/*
    You might want to change these.  You can also call the program as:

        progname [-d DSN] [-u USER] [-p PASSWD] [other-args]

    to override these values.  For example, call it as:

        progname -d myodbc3                -- use the myodbc3 dsn
        progname -u abc                    -- use the test dsn, with user abc
        progname -d test -u abc -p 0U8I2   -- ditto, with password
*/

static char *dsn = "test";
static char *user = NULL;     /* Use DSN default */
static char *pass = NULL;     /* Use DSN default */
/* Command-line words left over after main() has consumed the -d/-u/-p
   options; db_tests() uses argv[0] (when present) as the table name. */
static int argc;
static char ** argv;

/* Quick hack to provide buffers for query results - adjust as needed */
#define MAX_NAME_LEN 64
#define MAX_COLUMNS 16
#define MAX_ROW_DATA_LEN 1024

/* Max length for ODBC diagnostic messages; falls back to 1024 when the
   ODBC headers do not define SQL_MAX_MESSAGE_LENGTH. */
#ifdef SQL_MAX_MESSAGE_LENGTH
#   define MAX_MESSAGE_LENGTH SQL_MAX_MESSAGE_LENGTH
#else
#   define MAX_MESSAGE_LENGTH 1024
#endif

/* These are initialized in db_connect() and released in db_disconnect() */
static SQLHENV henv;
static SQLHDBC hdbc;
static SQLHSTMT hstmt;


/*
    Convenience macros for error checking:
        env_die         Print diagnostic info from global henv; exit if rc < 0
        dbc_die         Print diagnostic info from global hdbc; exit if rc < 0
        stmt_die        Print diagnostic info from global hstmt; exit if rc < 0

        env_warn, dbc_warn, stmt_warn
                        Print diagnostic info, but do not exit
*/

/*
    Each macro forwards the caller's __FILE__/__LINE__ so the report points
    at the failing call.  The *_die variants pass is_fatal = 1, but _error()
    only exits when rc is negative; the *_warn variants never exit.
*/
#define env_die(rc) \
        _error((rc), SQL_HANDLE_ENV, (henv), 1, __FILE__, __LINE__)
#define env_warn(rc) \
        _error((rc), SQL_HANDLE_ENV, (henv), 0, __FILE__, __LINE__)

#define dbc_die(rc) \
        _error((rc), SQL_HANDLE_DBC, (hdbc), 1, __FILE__, __LINE__)
#define dbc_warn(rc) \
        _error((rc), SQL_HANDLE_DBC, (hdbc), 0, __FILE__, __LINE__)

#define stmt_die(rc) \
        _error((rc), SQL_HANDLE_STMT, (hstmt), 1, __FILE__, __LINE__)
#define stmt_warn(rc) \
        _error((rc), SQL_HANDLE_STMT, (hstmt), 0, __FILE__, __LINE__)

/* Forward declaration; the definition follows usage() further down.
   NOTE(review): identifiers beginning with '_' are reserved at file scope
   in C; a rename (touching every macro above) would be cleaner. */
static void _error(SQLRETURN rc, SQLSMALLINT htype, SQLHANDLE handle,
        int is_fatal, const char *file, const unsigned int line);

/* Nonzero when rc is a failing ODBC return code (negative values such as
   SQL_ERROR); zero for SQL_SUCCESS, SQL_SUCCESS_WITH_INFO, SQL_NO_DATA... */
int rc_err(SQLRETURN rc)
{
    return rc < 0;
}

/* Nonzero while a fetch loop should keep going: rc is neither a failure
   code nor SQL_NO_DATA. */
int rc_continue(SQLRETURN rc)
{
    if (rc_err(rc) || rc == SQL_NO_DATA)
        return 0;
    return 1;
}


/*
    Some basic tests
*/

void db_tests(void)
{

/*
 * This code snagged from MSDN ODBC API Reference
 */

#define TAB_LEN SQL_MAX_TABLE_NAME_LEN + 1
#define COL_LEN SQL_MAX_COLUMN_NAME_LEN + 1

    UCHAR szTable[TAB_LEN];              /* Table to display */

    UCHAR szPkTable[TAB_LEN];   /* Primary key table name */
    UCHAR szFkTable[TAB_LEN];   /* Foreign key table name */
    UCHAR szPkCol[COL_LEN];     /* Primary key column */
    UCHAR szFkCol[COL_LEN];     /* Foreign key column */

    SQLINTEGER    cbPkTable, cbPkCol, cbFkTable, cbFkCol, cbKeySeq;
    SQLSMALLINT   iKeySeq;
    SQLRETURN     rc;

    /* Bind the columns that describe the primary and foreign keys. */
    /* Ignore the table schema, name, and catalog for this example. */

    rc = SQLBindCol(hstmt, 3, SQL_C_CHAR, szPkTable, TAB_LEN, &cbPkTable);
    if (rc_err(rc)) stmt_die(rc);
    rc = SQLBindCol(hstmt, 4, SQL_C_CHAR, szPkCol, COL_LEN, &cbPkCol);
    if (rc_err(rc)) stmt_die(rc);
    rc = SQLBindCol(hstmt, 5, SQL_C_SSHORT, &iKeySeq, TAB_LEN, &cbKeySeq);
    if (rc_err(rc)) stmt_die(rc);
    rc = SQLBindCol(hstmt, 7, SQL_C_CHAR, szFkTable, TAB_LEN, &cbFkTable);
    if (rc_err(rc)) stmt_die(rc);
    rc = SQLBindCol(hstmt, 8, SQL_C_CHAR, szFkCol, COL_LEN, &cbFkCol);
    if (rc_err(rc)) stmt_die(rc);

    strcpy(szTable, argc ? argv[0] : "ref1");

    /* Get all the foreign keys in szTable. */

    printf("Listing all foreign keys in table '%s':\n", szTable);
    rc = SQLForeignKeys(hstmt,
            NULL, 0,             /* Primary catalog */
            NULL, 0,             /* Primary schema */
            NULL, 0,             /* Primary table */
            NULL, 0,             /* Foreign catalog */
            NULL, 0,             /* Foreign schema */
            szTable, SQL_NTS);   /* Foreign table */

    while (rc_continue(rc = SQLFetch(hstmt))) {
        fprintf(stdout, "%-s ( %-s )--> %-s ( %-s )\n", szFkTable,
                szFkCol, szPkTable, szPkCol);
    }

    if (rc != SQL_NO_DATA) stmt_die(rc);
}


/* Prototypes for helpers defined after main(). */
void db_connect(void);
void db_disconnect(void);
void usage(const char *msg, const char *arg);

int main(int argc_main, char **argv_main)
{
    /*
     * Very basic option handling: each leading "-d", "-u" or "-p" consumes
     * the following word.  Parsing stops at the first word that does not
     * start with '-'; that word and everything after it stay in the
     * argc/argv globals for db_tests().
     */
    argc = argc_main;
    argv = argv_main;

    while (--argc > 0 && (++argv)[0][0] == '-') {
        if (--argc == 0)
            usage("Missing argument", argv[0]);

        if (argv[0][1] == 'd')
            dsn = argv[1];
        else if (argv[0][1] == 'u')
            user = argv[1];
        else if (argv[0][1] == 'p')
            pass = argv[1];
        else
            usage("Invalid argument", argv[0]);

        ++argv;     /* step over the option's value */
    }

    db_connect();
    db_tests();
    db_disconnect();

    return EXIT_SUCCESS;
}


void usage(const char *msg, const char *arg)
{
    /* Report the rejected argument, print the synopsis, and terminate with
       a failure status; this function never returns. */
    static const char synopsis[] =
        "Usage: ./this-program [-d dsn] [-u user] [-p pass] [other-args]\n";

    fprintf(stderr, "%s with argument %s\n", msg, arg);
    fputs(synopsis, stderr);
    exit(EXIT_FAILURE);
}

/*
    _error() should be called via the convenience macros.

    Print the first diagnostic record for handle, if one is available.
    If is_fatal is true, exit when rc is negative (e.g. SQL_ERROR).
*/

static void _error(SQLRETURN rc, SQLSMALLINT htype, SQLHANDLE handle,
        int is_fatal, const char *file, const unsigned int line)
{
    /*
     * Report rc for handle (of type htype) on stderr, tagged with the
     * caller's file and line.  SQL_SUCCESS prints "Success"; anything else
     * prints the first diagnostic record from SQLGetDiagRec(), or a note
     * that none could be retrieved.  Exits with EXIT_FAILURE when is_fatal
     * is set and rc is negative.
     */
    SQLCHAR sql_state[6], message[MAX_MESSAGE_LENGTH + 1];
    SQLINTEGER native_error;
    SQLSMALLINT dummy;
    SQLRETURN new_rc;

    switch (rc) {
        case SQL_SUCCESS:
            fprintf(stderr, "\nSuccess\n");
            break;
        case SQL_ERROR:
        case SQL_SUCCESS_WITH_INFO:
        case SQL_NO_DATA:
        default:
            fprintf(stderr, "\nError at line %s:%u: ", file, line);

            /* Pass the real buffer size: SQL_MAX_MESSAGE_LENGTH may be
               undefined, which is why MAX_MESSAGE_LENGTH has a fallback. */
            new_rc = SQLGetDiagRec(htype, handle, 1, sql_state, &native_error,
                    message, (SQLSMALLINT)sizeof message, &dummy);
            if (new_rc == SQL_SUCCESS || new_rc == SQL_SUCCESS_WITH_INFO)
                /* SQLINTEGER is int on some platforms and long on others,
                   so widen explicitly to match %ld. */
                fprintf(stderr, "[%s:%ld] %s\n",
                        sql_state, (long)native_error, message);
            else
                fprintf(stderr, "Can't get error message (SQLGetDiagRec "
                        "returned rc[%d]; original rc[%d])\n", new_rc, rc);

            break;
    }

    if (is_fatal && rc < 0)
        exit(EXIT_FAILURE);
}


/*
    db_connect() initializes the global henv, hdbc and hstmt variables.
    It uses the global dsn, user, and pass variables.
*/

void db_connect(void)
{
    /* Allocate env -> connection -> statement, connecting with the global
       dsn/user/pass in between.  Any failure is fatal. */
    SQLRETURN ret;

    /* Environment handle, declared as an ODBC 3.x application. */
    ret = SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &henv);
    if (rc_err(ret))
        env_die(ret);

    ret = SQLSetEnvAttr(henv, SQL_ATTR_ODBC_VERSION,
                        (SQLPOINTER)SQL_OV_ODBC3, 0);
    if (rc_err(ret))
        env_die(ret);

#ifdef VERBOSE
    printf("connecting (DSN='%s', USER='%s')\n",
           dsn, user != NULL ? user : "<default>");
#endif

    /* Connection handle; an allocation failure is diagnosed on the env. */
    ret = SQLAllocHandle(SQL_HANDLE_DBC, henv, &hdbc);
    if (rc_err(ret))
        env_die(ret);

    ret = SQLConnect(hdbc, dsn, SQL_NTS, user, SQL_NTS,  pass, SQL_NTS);
    if (rc_err(ret))
        dbc_die(ret);

    ret = SQLSetConnectAttr(hdbc, SQL_ATTR_AUTOCOMMIT,
                            (SQLPOINTER)SQL_AUTOCOMMIT_ON, 0);
    if (rc_err(ret))
        dbc_die(ret);

    /* Statement handle used by db_tests(); diagnosed on the connection. */
    ret = SQLAllocHandle(SQL_HANDLE_STMT, hdbc, &hstmt);
    if (rc_err(ret))
        dbc_die(ret);
}


/*
    db_disconnect() releases all the resources acquired by db_connect()
*/

void db_disconnect(void)
{
    /*
     * Release the handles from db_connect() children-first: statement,
     * then connection (disconnect + free), then environment.  Each failure
     * is reported via the macro for the handle being operated on, and
     * is fatal when rc is negative.
     *
     * NOTE(review): SQLFreeStmt(SQL_DROP), SQLFreeConnect() and SQLFreeEnv()
     * are ODBC 2.x calls; SQLFreeHandle() is the ODBC 3.x replacement.
     */
    SQLRETURN rc;

    rc = SQLFreeStmt(hstmt, SQL_DROP);
    if (rc_err(rc)) stmt_die(rc);   /* diagnostics live on hstmt, not hdbc */

    rc = SQLDisconnect(hdbc);
    if (rc_err(rc)) dbc_die(rc);

    rc = SQLFreeConnect(hdbc);
    if (rc_err(rc)) dbc_die(rc);

    rc = SQLFreeEnv(henv);
    if (rc_err(rc)) env_die(rc);
}
