#include "stdafx.h"
#include "MariaDB.h"
#include "Exception.h"
#include "StrUtils.h"
#include "Core/Convert.h"

#include <new>

namespace sql {

	// Initial buffer size (presumably in bytes) used for string result columns, as an initial
	// guess. If most strings fit here, we avoid going back and forth to the library to re-fetch
	// the column. Values that do not fit are detected as truncated and re-fetched into a larger
	// buffer when reading the row. Too large a value would waste memory for every string column.
	static const size_t DEFAULT_STRING_SIZE = 32;

	// Connection to a MySQL server. All connection setup is done by MariaDBBase; this
	// constructor adds nothing of its own.
	MySQL::MySQL(Host c, MAYBE(Str *) user, MAYBE(Str *) password, Str *database)
		: MariaDBBase(c, user, password, database)
	{}

	// Connection to a MariaDB server. All connection setup is done by MariaDBBase; this
	// constructor adds nothing of its own.
	MariaDB::MariaDB(Host c, MAYBE(Str *) user, MAYBE(Str *) password, Str *database)
		: MariaDBBase(c, user, password, database)
	{}

	features::DBFeatures MariaDB::features() const {
		// Start from the feature set shared with MySQL, then add the RETURNING clauses that only
		// MariaDB supports for INSERT and DELETE.
		features::DBFeatures result = MariaDBBase::features();
		result = result | features::insertReturning;
		result = result | features::deleteReturning;
		return result;
	}

	// Creates the native driver handle and connects to the database described by 'c'.
	// 'user' and 'password' may be null, in which case null is passed on to the client library.
	// Throws SQLError if the connection fails; the native handle is released before throwing.
	MariaDBBase::MariaDBBase(Host c, MAYBE(Str *) user, MAYBE(Str *) password, Str *database) {
		handle = createMariaDBDriver(engine());
		api = handle->methods->api;

		// Request utf8mb4 (full UTF-8, including 4-byte sequences) explicitly instead of relying on
		// whatever default the client library or server would pick.
		(*api->mysql_options)(handle, MYSQL_SET_CHARSET_NAME, "utf8mb4");

		try {
			// Connection target: either a host + port, or a local socket/pipe name. When the Host
			// is neither, everything stays null/0 and mysql_real_connect falls back to its defaults.
			const char *host = null;
			const char *pipe = null;
			unsigned int port = 0;

			if (Address *addr = c.isSocket()) {
				host = addr->withPort(0)->toS()->utf8_str();
				port = addr->port();
			} else if (Str *local = c.isLocal()) {
				pipe = local->utf8_str();
			}

			if (!(*api->mysql_real_connect)(
					handle,
					host,
					user ? user->utf8_str() : null,
					password ? password->utf8_str() : null,
					database->utf8_str(),
					port,
					pipe,
					0)) {
				throwError();
				// throwError() returns silently when the library reports error code 0. A failed
				// connection must never go unnoticed, so fall back to a generic error.
				throw new (this) SQLError(new (this) Str(S("Failed to connect to the database.")));
			}
		} catch (...) {
			// Release the driver handle before propagating the error.
			close();
			throw;
		}
	}

	// Ensures the native driver is released even if close() was never called explicitly.
	MariaDBBase::~MariaDBBase() {
		close();
	}

	// Closes the connection: runs the generic DBConnection cleanup first, then destroys the
	// native driver handle (if still present) and clears it so repeated calls are harmless.
	void MariaDBBase::close() {
		DBConnection::close();

		if (!handle)
			return;

		destroyMariaDBDriver(handle);
		handle = null;
	}

	void MariaDBBase::throwError() {
		if (!handle)
			return;

		unsigned int code = (*api->mysql_errno)(handle);
		if (code == 0)
			return;

		const char *msg = (*api->mysql_error)(handle);
		throw new (this) SQLError(new (this) Str(toWChar(engine(), msg)), Maybe<Nat>(code));
	}


	// Creates a new prepared statement on this connection.
	Statement *MariaDBBase::prepare(Str *query) {
		Stmt *created = new (this) Stmt(this, query);
		return created;
	}

	// The shared MySQL/MariaDB base advertises no optional features.
	features::DBFeatures MariaDBBase::features() const {
		features::DBFeatures none = features::none;
		return none;
	}

	// Returns the names of all tables in the current database (first column of SHOW TABLES).
	Array<Str *> *MariaDBBase::tables() {
		Statement *query = prepare(new (this) Str(S("SHOW TABLES;")));
		Array<Str *> *result = new (this) Array<Str *>();

		try {
			Statement::Result queryResult = query->execute();
			for (Maybe<Row> row = queryResult.next(); row.any(); row = queryResult.next()) {
				result->push(row.value().getStr(0));
			}
		} catch (...) {
			// Do not leave the native statement open when executing or reading fails.
			query->finalize();
			throw;
		}

		query->finalize();

		return result;
	}

	// Returns the value of LAST_INSERT_ID() as an Int. The statement is prepared on first use
	// and cached in 'lastRowIdQuery' for later calls.
	Int MariaDBBase::queryLastRowId() {
		if (!lastRowIdQuery) {
			Str *sql = new (this) Str(S("SELECT LAST_INSERT_ID();"));
			lastRowIdQuery = new (this) Stmt(this, sql);
		}

		Statement::Result rows = lastRowIdQuery->execute();
		Maybe<Row> first = rows.next();
		Int insertId = first.value().getInt(0);
		rows.finalize();
		return insertId;
	}

	// Executes a raw UTF-8 query without a prepared statement. clearFetch() is called first,
	// as in Stmt's constructor, before the connection is used for something new.
	// Throws SQLError (via throwError) if the library reports a failure.
	void MariaDBBase::query(const char *query) {
		clearFetch();

		int failed = (*api->mysql_query)(handle, query);
		if (failed)
			throwError();
	}

	// Convenience overload: executes the query in 'q'.
	void MariaDBBase::query(Str *q) {
		const char *utf8 = q->utf8_str();
		query(utf8);
	}

	// Convenience overload: renders 'q' with this connection's visitor, then executes it.
	void MariaDBBase::query(QueryStr *q) {
		QueryStr::Visitor *v = visitor();
		query(q->generate(v)->utf8_str());
	}

	// Starts a new transaction on the connection.
	void MariaDBBase::beginTransaction() {
		const char *sql = "START TRANSACTION";
		query(sql);
	}

	// Ends the current transaction: commits or rolls back depending on 'end'. Any other value
	// of 'end' sends nothing to the database.
	void MariaDBBase::endTransaction(Transaction::End end) {
		const char *sql = null;
		if (end == Transaction::endCommit)
			sql = "COMMIT";
		else if (end == Transaction::endRollback)
			sql = "ROLLBACK";

		if (sql)
			query(sql);
	}

	// Searches 'original' for 'lookFor' as a whole, space-separated word (compared
	// case-insensitively). On a match, the word and the spaces after it are cut out. If it was
	// the last word, the spaces before it are cut instead. The result is stored back in
	// 'original' and true is returned. Otherwise 'original' is untouched and false is returned.
	static bool findAndRemove(Str *&original, const wchar *lookFor) {
		const wchar *start = original->c_str();

		for (const wchar *word = start; *word; ) {
			// Skip the spaces before the next word.
			while (*word == ' ')
				++word;

			const wchar *wordEnd = word;
			while (*wordEnd != '\0' && *wordEnd != ' ')
				++wordEnd;

			if (!compareNoCase(word, wordEnd, lookFor)) {
				word = wordEnd;
				continue;
			}

			// Swallow the spaces following the match...
			const wchar *cutEnd = wordEnd;
			while (*cutEnd == ' ')
				++cutEnd;

			// ...or, when nothing follows, the spaces preceding it, so no trailing space remains.
			const wchar *cutBegin = word;
			if (*cutEnd == '\0') {
				while (cutBegin > start && cutBegin[-1] == ' ')
					--cutBegin;
			}

			Str *front = new (original) Str(start, cutBegin);
			Str *back = new (original) Str(cutEnd);
			original = *front + back;
			return true;
		}

		return false;
	}

	// Reads the schema of 'table' using SHOW COLUMNS and SHOW INDEX. Returns null if the table
	// does not exist (ER_NO_SUCH_TABLE); any other error is re-thrown after the statements are
	// finalized.
	MAYBE(Schema *) MariaDBBase::schema(Str *table) {
		Statement *columnQuery = null;
		Statement *indexQuery = null;

		try {
			// Prepare inside the try-block, so that a missing table reported already while
			// preparing is also turned into a null result, and so that nothing is leaked.
			{
				QueryStrBuilder *colBuilder = new (this) QueryStrBuilder();
				colBuilder->put(S("SHOW COLUMNS FROM "));
				colBuilder->name(table);
				colBuilder->put(S(";"));
				columnQuery = prepare(colBuilder->build());

				QueryStrBuilder *indexBuilder = new (this) QueryStrBuilder();
				indexBuilder->put(S("SHOW INDEX FROM "));
				indexBuilder->name(table);
				indexBuilder->put(S(";"));
				indexQuery = prepare(indexBuilder->build());
			}

			Array<Schema::Column *> *columns = new (this) Array<Schema::Column *>();

			Statement::Result queryResult = columnQuery->execute();
			// Columns are: field, type, null, key, default, extra
			for (Maybe<Row> row = queryResult.next(); row.any(); row = queryResult.next()) {
				Row &v = row.value();

				Schema::Column *col = new (this) Schema::Column(v.getStr(0), QueryType::parse(v.getStr(1)));

				// Do we not know the type?
				if (col->type.empty())
					col->unknown = v.getStr(1);

				// Is it null?
				if (!compareNoCase(v.getStr(2), S("YES")))
					col->attributes |= Schema::notNull;

				// Check what kind of key it is.
				{
					Str *key = v.getStr(3);
					if (compareNoCase(key, S("PRI")))
						col->attributes |= Schema::primaryKey;
					else if (compareNoCase(key, S("UNI")))
						col->attributes |= Schema::unique;
					// Note: May also be MUL, which means that it is a part of an index.
				}

				if (!v.isNull(4)) {
					// Note: Any string values are already properly quoted by MariaDB.
					col->defaultValue = TO_S(this, v.getStr(4));
				}

				// The "extra" column may contain AUTO_INCREMENT together with other attributes.
				// AUTO_INCREMENT becomes an attribute; whatever remains is kept as unknown text.
				{
					Str *extra = v.getStr(5);
					if (findAndRemove(extra, S("AUTO_INCREMENT")))
						col->attributes |= Schema::autoIncrement;

					if (extra->any()) {
						// Use the stripped string: the raw column would re-introduce AUTO_INCREMENT,
						// which is already represented by the attribute above.
						if (col->unknown)
							col->unknown = TO_S(this, col->unknown << S(" ") << extra);
						else
							col->unknown = extra;
					}
				}

				columns->push(col);
			}

			Array<Schema::Index *> *indices = new (this) Array<Schema::Index *>();
			Map<Str *, Nat> *nameMap = new (this) Map<Str *, Nat>();

			queryResult = indexQuery->execute();
			// Columns are: table, non_unique, key_name, seq_in_index, column_name, collation, cardinality, ...
			for (Maybe<Row> row = queryResult.next(); row.any(); row = queryResult.next()) {
				Row &v = row.value();

				Str *name = v.getStr(2);
				// Don't add the index for the primary key.
				if (*name == S("PRIMARY"))
					continue;

				// One row per (index, column): group the rows by index name.
				Nat id = nameMap->get(name, indices->count());
				if (id >= indices->count()) {
					indices->push(new (this) Schema::Index(name, new (this) Array<Str *>()));
					nameMap->put(name, id);
				}

				indices->at(id)->columns->push(v.getStr(4));
			}

			columnQuery->finalize();
			indexQuery->finalize();

			return new (this) Schema(table, columns, indices);
		} catch (SQLError *e) {
			if (columnQuery)
				columnQuery->finalize();
			if (indexQuery)
				indexQuery->finalize();

			// Check if the error was that the table does not exist.
			if (e->code.any() && e->code.value() == ER_NO_SUCH_TABLE)
				return null;

			// Otherwise, re-throw the exception.
			throw;
		} catch (...) {
			// Any other failure: release the native statements, then propagate.
			if (columnQuery)
				columnQuery->finalize();
			if (indexQuery)
				indexQuery->finalize();
			throw;
		}
	}

	// Applies migration 'm' in this order:
	//   1. drop indices;
	//   2. drop tables;
	//   3. migrate existing tables;
	//   4. create new tables together with their indices;
	//   5. add indices.
	// All identifiers are emitted through QueryStrBuilder::name, so they are quoted (and
	// escaped) by the visitor exactly like in every other generated query.
	// NOTE(review): MySQL/MariaDB implicitly commit around most DDL statements, so this
	// transaction likely does not make the migration atomic - confirm that this is acceptable.
	void MariaDBBase::migrate(Migration *m) {
		Transaction transaction(this);

		// Remove indices.
		for (Nat i = 0; i < m->indexRemove->count(); i++) {
			Migration::Index *index = m->indexRemove->at(i);
			QueryStrBuilder *drop = new (this) QueryStrBuilder();
			drop->put(S("DROP INDEX "));
			drop->name(index->name);
			drop->put(S(" ON "));
			drop->name(index->table);
			query(drop->build());
		}

		// Remove tables.
		for (Nat i = 0; i < m->tableRemove->count(); i++) {
			QueryStrBuilder *drop = new (this) QueryStrBuilder();
			drop->put(S("DROP TABLE "));
			drop->name(m->tableRemove->at(i));
			query(drop->build());
		}

		// Migrate tables.
		for (Nat i = 0; i < m->tableMigrate->count(); i++) {
			migrateTable(m->tableMigrate->at(i));
		}

		// Add tables.
		for (Nat i = 0; i < m->tableAdd->count(); i++) {
			Schema *table = m->tableAdd->at(i);

			QueryStrBuilder *create = new (this) QueryStrBuilder();
			table->toSQL(create);
			query(create->build());

			// Don't forget the associated indices!
			Array<Schema::Index *> *indices = table->indices();
			for (Nat j = 0; j < indices->count(); j++) {
				QueryStrBuilder *createIndex = new (this) QueryStrBuilder();
				indices->at(j)->toSQL(createIndex, table->name());
				query(createIndex->build());
			}
		}

		// Add indices.
		for (Nat i = 0; i < m->indexAdd->count(); i++) {
			QueryStrBuilder *createIndex = new (this) QueryStrBuilder();
			m->indexAdd->at(i)->toSQL(createIndex);
			query(createIndex->build());
		}

		transaction.commit();
	}

	// Emits a ", " separator before every part except the first one. 'first' tracks whether
	// anything has been emitted yet and is cleared by the first call.
	static void addComma(QueryStrBuilder *to, Bool &first) {
		if (first)
			first = false;
		else
			to->put(S(", "));
	}

	// Migrates a single existing table.
	// All column changes (drop, modify, add) are collected into one ALTER TABLE statement, in
	// the hope that the database applies them as a unit. Having several alterations in one
	// statement is a MySQL extension. Primary key changes are sent in a second statement.
	void MariaDBBase::migrateTable(Migration::Table *m) {
		QueryStrBuilder *alter = new (this) QueryStrBuilder();
		alter->put(S("ALTER TABLE "));
		alter->name(m->table);
		alter->put(S(" "));

		Bool noParts = true;

		// Dropped columns.
		for (Nat i = 0; i < m->colRemove->count(); i++) {
			addComma(alter, noParts);
			alter->put(S("DROP COLUMN "));
			alter->name(m->colRemove->at(i));
		}

		// Changed columns. MySQL/MariaDB only need the complete new column definition; the
		// database works out what actually changed.
		for (Nat i = 0; i < m->colMigrate->count(); i++) {
			Migration::ColAttrs *changed = m->colMigrate->at(i);
			addComma(alter, noParts);
			alter->put(S("MODIFY COLUMN "));
			changed->toSchema()->toSQL(alter);
		}

		// New columns.
		for (Nat i = 0; i < m->colAdd->count(); i++) {
			addComma(alter, noParts);
			alter->put(S("ADD COLUMN "));
			m->colAdd->at(i)->toSQL(alter);
		}

		// Only send the statement if it contains at least one alteration.
		if (!noParts)
			query(alter->build());

		if (!m->updatePrimaryKeys)
			return;

		// Re-create the primary key in a separate statement.
		QueryStrBuilder *keys = new (this) QueryStrBuilder();
		keys->put(S("ALTER TABLE "));
		keys->name(m->table);

		if (m->dropPrimaryKeys) {
			// The primary key is always named PRIMARY as per the MySQL manual. IF EXISTS, since
			// the key columns may already have been dropped above.
			keys->put(S(" DROP INDEX IF EXISTS `PRIMARY`"));
		}

		if (m->primaryKeys->any()) {
			if (m->dropPrimaryKeys)
				keys->put(S(","));

			keys->put(S(" ADD PRIMARY KEY ("));
			for (Nat i = 0; i < m->primaryKeys->count(); i++) {
				if (i > 0)
					keys->put(S(", "));
				keys->name(m->primaryKeys->at(i));
			}
			keys->put(S(")"));
		}

		query(keys->build());
	}


	/**
	 * Statement.
	 */

	// Releases bind arrays created by allocBinds:
	//   - destroys the 'count' Value objects in 'values';
	//   - frees both arrays;
	//   - resets the pointers to null and 'count' to 0.
	// Null pointers are skipped, so calling this when nothing is allocated is harmless.
	static void freeBinds(Nat &count, MYSQL_BIND *&binds, Value *&values) {
		if (values) {
			Value *end = values + count;
			for (Value *at = values; at != end; ++at)
				at->~Value();

			free(values);
			values = null;
		}

		if (binds) {
			free(binds);
			binds = null;
		}

		count = 0;
	}

	// Allocates 'count' zero-initialized MYSQL_BIND structures and 'count' Value objects, where
	// Value number i is constructed to manage binds[i]. Release with freeBinds.
	// Throws std::bad_alloc (leaving both pointers null) if an allocation fails, instead of
	// placement-constructing into a null pointer.
	static void allocBinds(Nat count, MYSQL_BIND *&binds, Value *&values) {
		binds = static_cast<MYSQL_BIND *>(calloc(count, sizeof(MYSQL_BIND)));
		values = static_cast<Value *>(calloc(count, sizeof(Value)));

		if (count > 0 && (!binds || !values)) {
			free(binds);
			free(values);
			binds = null;
			values = null;
			throw std::bad_alloc();
		}

		for (Nat i = 0; i < count; i++)
			new (values + i) Value(binds + i);
	}

	// Counts prepared statements; incremented after every successful prepare. It is not read
	// anywhere in this file (presumably a debugging aid).
	static Nat stmts = 0;

	// Prepares 'query' on 'owner'. Allocates parameter bindings if the statement has any
	// placeholders. Throws SQLError on failure; in that case the native resources acquired so
	// far are released through finalize().
	MariaDBBase::Stmt::Stmt(MariaDBBase *owner, Str *query)
		: SequentialStatement(owner), lastId(-1),
		  paramCount(0), paramBinds(null), paramValues(null),
		  resultCount(0), resultBinds(null), resultValues(null) {

		// Convert first: any GC caused by this allocation then happens before native resources
		// are acquired.
		const char *utf8Query = query->utf8_str();

		owner->clearFetch();

		stmt = (*owner->api->mysql_stmt_init)(owner->handle);
		if (!stmt) {
			owner->throwError();
			// throwError() returns silently when the connection reports error code 0. Never
			// continue with a null statement handle.
			throw new (this) SQLError(new (this) Str(S("Failed to initialize a prepared statement.")));
		}

		try {
			// Length -1: let the connector compute the length of the NUL-terminated string.
			// NOTE(review): relies on the MariaDB Connector/C convention - confirm for the version in use.
			if ((*owner->api->mysql_stmt_prepare)(stmt, utf8Query, -1)) {
				throwError();
			}

			stmts++;

			paramCount = (*owner->api->mysql_stmt_param_count)(stmt);
			if (paramCount > 0)
				allocBinds(paramCount, paramBinds, paramValues);

		} catch (...) {
			finalize();
			throw;
		}
	}

	// Technically a part of the DB connection, but logically belongs to the Stmt: releases the
	// parameter and result bindings of 'stmt' and closes its native statement handle.
	// Safe to call more than once, since all released pointers are reset.
	void MariaDBBase::finalizeStmt(SequentialStatement *stmt) {
		// Stmt derives from SequentialStatement, so this is a base-to-derived downcast.
		// static_cast applies any required pointer adjustment. reinterpret_cast would only be
		// correct if the base subobject happened to be located at offset zero.
		Stmt *s = static_cast<Stmt *>(stmt);

		freeBinds(s->paramCount, s->paramBinds, s->paramValues);
		freeBinds(s->resultCount, s->resultBinds, s->resultValues);

		if (s->stmt) {
			(*api->mysql_stmt_close)(s->stmt);
			s->stmt = null;
		}
	}

	void MariaDBBase::Stmt::throwError() {
		if (!stmt)
			return;

		// This is mostly a safeguard - we likely already have exclusive access to the connection.
		owner()->clearFetch(this);

		unsigned int code = (*owner()->api->mysql_stmt_errno)(stmt);
		if (code == 0)
			return;

		const char *error = (*owner()->api->mysql_stmt_error)(stmt);
		throw new (this) SQLError(new (this) Str(toWChar(engine(), error)), Maybe<Nat>(code));
	}

	// Parameter binding. Each overload stores the value in the Value for parameter 'pos'.
	// Positions at or beyond the parameter count of the prepared statement are ignored silently.

	void MariaDBBase::Stmt::bind(Nat pos, Str *str) {
		if (pos >= paramCount)
			return;
		paramValues[pos].setString(str);
	}

	// Booleans are stored as the integers 1 and 0.
	void MariaDBBase::Stmt::bind(Nat pos, Bool b) {
		if (pos >= paramCount)
			return;
		paramValues[pos].setInt(b ? 1 : 0);
	}

	void MariaDBBase::Stmt::bind(Nat pos, Int i) {
		if (pos >= paramCount)
			return;
		paramValues[pos].setInt(i);
	}

	void MariaDBBase::Stmt::bind(Nat pos, Long l) {
		if (pos >= paramCount)
			return;
		paramValues[pos].setInt(l);
	}

	void MariaDBBase::Stmt::bind(Nat pos, Float f) {
		if (pos >= paramCount)
			return;
		paramValues[pos].setFloat(f);
	}

	void MariaDBBase::Stmt::bind(Nat pos, Double d) {
		if (pos >= paramCount)
			return;
		paramValues[pos].setFloat(d);
	}

	void MariaDBBase::Stmt::bindNull(Nat pos) {
		if (pos >= paramCount)
			return;
		paramValues[pos].setNull();
	}

	// Executes the statement with the currently bound parameters.
	// Returns true when the statement produced a result set; bindings for its columns are then
	// set up for nextRowSeq. Otherwise, returns false after disposing the result and recording
	// the last insert id in 'lastId'. Throws SQLError on failure.
	Bool MariaDBBase::Stmt::executeSeq() {
		// Make sure to reset the statement in case we could not do it in dispose (we should, unless
		// there is a bug).
		(*owner()->api->mysql_stmt_reset)(stmt);

		// Check the result of binding, so that a failure is reported here instead of the
		// statement being executed without proper parameters.
		if (paramBinds) {
			if ((*owner()->api->mysql_stmt_bind_param)(stmt, paramBinds))
				throwError();
		}

		if ((*owner()->api->mysql_stmt_execute)(stmt))
			throwError();

		lastChanges = Nat((*owner()->api->mysql_stmt_affected_rows)(stmt));

		MYSQL_RES *metadata = (*owner()->api->mysql_stmt_result_metadata)(stmt);
		if (!metadata) {
			// No result, the metadata returned null.
			// Dispose the result now already so that we can get the last row id:
			disposeResult();
			lastId = owner()->queryLastRowId();

			// No results to fetch.
			return false;
		}

		// Defensive: release result bindings left over from an earlier execution, if any.
		freeBinds(resultCount, resultBinds, resultValues);

		resultCount = (*owner()->api->mysql_num_fields)(metadata);

		MYSQL_FIELD *columns = (*owner()->api->mysql_fetch_fields)(metadata);
		allocBinds(resultCount, resultBinds, resultValues);

		for (Nat i = 0; i < resultCount; i++) {
			switch (columns[i].type) {
			case MYSQL_TYPE_STRING:
			case MYSQL_TYPE_VAR_STRING:
			case MYSQL_TYPE_VARCHAR:
			case MYSQL_TYPE_BLOB:
			case MYSQL_TYPE_TINY_BLOB:
			case MYSQL_TYPE_MEDIUM_BLOB:
			case MYSQL_TYPE_LONG_BLOB:
			case MYSQL_TYPE_ENUM:
			case MYSQL_TYPE_SET:
				// Textual columns. Allocate some size; the real size is requested when a value
				// turns out to be truncated.
				resultValues[i].setString(DEFAULT_STRING_SIZE);
				break;

			case MYSQL_TYPE_FLOAT:
			case MYSQL_TYPE_DOUBLE:
				resultValues[i].setFloat(0);
				break;

			default:
				if (IS_NUM(columns[i].type)) {
					// Treat all integers as 'Long'.
					if (columns[i].flags & UNSIGNED_FLAG)
						resultValues[i].setUInt(0);
					else
						resultValues[i].setInt(0);
				}
				break;
			}
		}

		(*owner()->api->mysql_free_result)(metadata);

		// Bind the result now.
		if ((*owner()->api->mysql_stmt_bind_result)(stmt, resultBinds))
			throwError();

		// Let the world know that we have results to fetch!
		return true;
	}

	// Fetches the next row of the current result set. Returns an empty Maybe when there is no
	// result set or no more rows. Throws SQLError on failure or on an unsupported column type.
	Maybe<Row> MariaDBBase::Stmt::nextRowSeq() {
		if (!resultValues)
			return Maybe<Row>();

		// Fetch the next row.
		int result = (*owner()->api->mysql_stmt_fetch)(stmt);
		if (result == 1) {
			throwError();
			// throwError() returns silently when no error code is reported. Never build a row
			// from a failed fetch.
			throw new (this) SQLError(new (this) Str(S("Failed to fetch the next row.")));
		} else if (result == MYSQL_NO_DATA) {
			// End of the query!
			return Maybe<Row>();
		}

		// Note: If we get MYSQL_DATA_TRUNCATED, we know that something was truncated. We don't need
		// to check specifically for this here, since we always check string lengths.

		// Extract the results.
		Row::Builder builder = Row::builder(engine(), resultCount);

		// Set when a string buffer was replaced and the statement needs the new bindings.
		bool rebind = false;

		for (Nat i = 0; i < resultCount; i++) {
			Value &v = resultValues[i];

			if (v.isNull()) {
				builder.pushNull();
			} else if (v.isInt()) {
				builder.push(Long(v.getInt()));
			} else if (v.isUInt()) {
				builder.push(Long(v.getUInt())); // TODO: Maybe support unsigned values also?
			} else if (v.isFloat()) {
				builder.push(v.getFloat());
			} else if (v.isString()) {
				size_t sz = v.isTruncated();
				if (sz) {
					// The value did not fit: grow the buffer and fetch the column again.
					v.setString(sz);
					if ((*owner()->api->mysql_stmt_fetch_column)(stmt, resultBinds + i, i, 0))
						throwError();

					// The statement keeps its own copy of the bindings passed to
					// mysql_stmt_bind_result, which may still refer to the old buffer. Rebind
					// after this row so that later fetches write into the new buffer.
					rebind = true;
				}

				builder.push(v.getString(engine()));
			} else {
				StrBuf *msg = new (this) StrBuf();
				*msg << S("Unknown column type for column ") << i << S("!");
				throw new (this) SQLError(msg->toS());
			}
		}

		if (rebind) {
			// Per the C API documentation, new bindings take effect on the next fetch.
			if ((*owner()->api->mysql_stmt_bind_result)(stmt, resultBinds))
				throwError();
		}

		return Maybe<Row>(Row(builder));
	}

	// Releases the result bindings and, if the native statement still exists, frees its
	// result set. The statement is also reset while the connection reports this statement as
	// the one currently fetching.
	void MariaDBBase::Stmt::disposeResultSeq() {
		freeBinds(resultCount, resultBinds, resultValues);

		if (!stmt)
			return;

		// Freeing the result is probably fine to do unconditionally.
		(*owner()->api->mysql_stmt_free_result)(stmt);

		// Resetting has to be done in sequence.
		if (owner()->isFetching(this))
			(*owner()->api->mysql_stmt_reset)(stmt);
	}


	/**
	 * Visitor.
	 */

	// Returns a new visitor that renders queries in the MySQL/MariaDB dialect.
	QueryStr::Visitor *MariaDBBase::visitor() const {
		Visitor *dialect = new (this) Visitor();
		return dialect;
	}

	MariaDBBase::Visitor::Visitor() {
	}

	// Emits 'name' as a backtick-quoted identifier. Backticks inside the name are doubled, as
	// MySQL/MariaDB require. Otherwise they would terminate the quoted identifier early,
	// breaking the query or allowing arbitrary SQL to be injected through an identifier.
	void MariaDBBase::Visitor::name(StrBuf *to, Str *name) {
		*to << S("`");

		const wchar *begin = name->c_str();
		const wchar *segment = begin;
		for (const wchar *at = begin; *at; at++) {
			if (*at == '`') {
				// Emit everything up to and including this backtick, then one more to escape it.
				*to << new (this) Str(segment, at + 1) << S("`");
				segment = at + 1;
			}
		}

		if (segment == begin)
			*to << name; // Common case: nothing needed escaping.
		else
			*to << new (this) Str(segment);

		*to << S("`");
	}

	// Emits the MySQL/MariaDB column type for 'type'. 'size' appears to be a size in bytes for
	// numbers and a maximum length for text. Throws SQLError for unsupported types.
	void MariaDBBase::Visitor::type(StrBuf *to, QueryType type) {
		Maybe<Nat> size = type.size();

		if (type.sameType(QueryType::text())) {
			// Bounded text becomes VARCHAR(n), unbounded text becomes TEXT.
			if (size.any()) {
				*to << S("VARCHAR") << S("(") << size.value() << S(")");
			} else {
				*to << S("TEXT");
			}
		} else if (type.sameType(QueryType::integer())) {
			// Pick the smallest integer type for the size; INTEGER when no size is given.
			if (size.any()) {
				if (size.value() <= 1)
					*to << S("TINYINT");
				else if (size.value() <= 2)
					*to << S("SMALLINT");
				else if (size.value() <= 3)
					*to << S("MEDIUMINT");
				else if (size.value() > 4)
					*to << S("BIGINT");
				else
					*to << S("INTEGER");
			} else {
				*to << S("INTEGER");
			}
		} else if (type.sameType(QueryType::real())) {
			// In MySQL/MariaDB, REAL is a synonym for DOUBLE unless the REAL_AS_FLOAT SQL mode is
			// enabled. Emitting REAL for small sizes would therefore silently create an 8-byte
			// column, so request FLOAT explicitly for sizes of at most 4 bytes.
			if (!size.any()) {
				*to << S("REAL");
			} else if (size.value() > 4) {
				*to << S("DOUBLE");
			} else {
				*to << S("FLOAT");
			}
		} else {
			throw new (this) SQLError(TO_S(this, S("Unsupported type: ") << type << S(".")));
		}
	}

	// Emits the MySQL/MariaDB spelling of the auto-increment column attribute.
	void MariaDBBase::Visitor::autoIncrement(StrBuf *to) {
		StrBuf &out = *to;
		out << S("AUTO_INCREMENT");
	}

}
