From 735fcbf3111fbaf5a50d1d1bc25df8ce94a05fe7 Mon Sep 17 00:00:00 2001 From: Gavin Bises Date: Sat, 21 Feb 2015 21:29:59 -0500 Subject: [PATCH] Add first draft of protocol extension for registration Stub for registration command handling in server First draft of handling registration requests WIP (will be rebased) clean up bad imports (rebase this later) Finish checkUserIsBanned method Add username validity check Check servatrice registration settings WIP Finish(?) server side of registration Needs testing Fix switch case compile failure I have no idea why I have to do this WIP for registration testing python script Stub register script initial attempt Rearrange register script First try at sending reg register.py sends commands correctly now Add more debug to register.py Pack bytes the right way - servatrice can parse py script sends now register.py should be working now Parse xml hack correctly Log registration enabled settings on server start Insert gender correctly on register Show tcpserver error message on failed gameserver listen Fail startup if db configured and can't be opened. TIL qt5 comes without mysql by default in homebrew... --- common/pb/CMakeLists.txt | 1 + common/pb/response.proto | 9 + common/pb/response_register.proto | 9 + common/pb/session_commands.proto | 21 +++ common/server.cpp | 39 +++++ common/server.h | 18 ++ common/server_database_interface.h | 4 + common/server_protocolhandler.cpp | 49 ++++++ common/server_protocolhandler.h | 2 + servatrice/scripts/.gitignore | 1 + servatrice/scripts/mk_pypb.sh | 10 ++ servatrice/scripts/register.py | 73 ++++++++ servatrice/servatrice.ini.example | 7 + servatrice/src/servatrice.cpp | 24 ++- .../src/servatrice_database_interface.cpp | 162 +++++++++++++----- .../src/servatrice_database_interface.h | 13 +- 16 files changed, 394 insertions(+), 48 deletions(-) create mode 100644 common/pb/response_register.proto create mode 100644 servatrice/scripts/.gitignore create mode 100755 servatrice/scripts/mk_pypb.sh create mode 100755 servatrice/scripts/register.py diff --git a/common/pb/CMakeLists.txt b/common/pb/CMakeLists.txt index adf1e1b8..e698c95f 100644 --- a/common/pb/CMakeLists.txt +++ b/common/pb/CMakeLists.txt @@ -123,6 +123,7 @@ SET(PROTO_FILES response_join_room.proto response_list_users.proto response_login.proto + response_register.proto response_replay_download.proto response_replay_list.proto response.proto diff --git a/common/pb/response.proto b/common/pb/response.proto index 0f13c415..302b1801 100644 --- a/common/pb/response.proto +++ b/common/pb/response.proto @@ -24,6 +24,14 @@ message Response { RespAccessDenied = 20; RespUsernameInvalid = 21; RespRegistrationRequired = 22; + RespRegistrationAccepted = 23; // Server agrees to process client's registration request + RespUserAlreadyExists = 24; // Client attempted to register a name which is already registered + RespEmailRequiredToRegister = 25; // Server requires email to register accounts but client did not provide one + RespServerDoesNotUseAuth = 26; // Client attempted to register but server does not use authentication + RespTooManyRequests = 27; // Server refused to complete command because client has sent too many too quickly + RespAccountNotActivated = 28; // Client attempted to log into a registered username but the account hasn't been activated + RespRegistrationDisabled = 29; // Server does not allow clients to register + RespRegistrationFailed = 30; // Server accepted a reg request but failed to perform the registration } enum ResponseType { JOIN_ROOM = 1000; @@ -35,6 +43,7 @@ message Response { DECK_LIST = 1006; DECK_DOWNLOAD = 1007; DECK_UPLOAD = 1008; + REGISTER = 1009; REPLAY_LIST = 1100; REPLAY_DOWNLOAD = 1101; } diff --git a/common/pb/response_register.proto b/common/pb/response_register.proto new file mode 100644 index 00000000..9c6998ef --- /dev/null +++ b/common/pb/response_register.proto @@ -0,0 +1,9 @@ +import "response.proto"; + +message Response_Register { + extend Response { + optional Response_Register ext = 1009; + } + optional string denied_reason_str = 1; + optional uint64 denied_end_time = 2; +} \ No newline at end of file diff --git a/common/pb/session_commands.proto b/common/pb/session_commands.proto index bbb5e81d..be5cc1db 100644 --- a/common/pb/session_commands.proto +++ b/common/pb/session_commands.proto @@ -1,3 +1,5 @@ +import "serverinfo_user.proto"; + message SessionCommand { enum SessionCommandType { PING = 1000; @@ -16,6 +18,7 @@ message SessionCommand { DECK_UPLOAD = 1013; LIST_ROOMS = 1014; JOIN_ROOM = 1015; + REGISTER = 1016; REPLAY_LIST = 1100; REPLAY_DOWNLOAD = 1101; REPLAY_MODIFY_MATCH = 1102; @@ -94,3 +97,21 @@ message Command_JoinRoom { } optional uint32 room_id = 1; } + +// User wants to register a new account +message Command_Register { + extend SessionCommand { + optional Command_Register ext = 1016; + } + // User name client wants to register + required string user_name = 1; + // Hashed password to be inserted into database + required string password = 2; + // Email address of the client for user validation + optional string email = 3; + // Gender of the user + optional ServerInfo_User.Gender gender = 4; + // Country code of the user. 2 letter ISO format + optional string country = 5; + optional string real_name = 6; +} diff --git a/common/server.cpp b/common/server.cpp index 4b65739b..7fd083e2 100644 --- a/common/server.cpp +++ b/common/server.cpp @@ -175,6 +175,45 @@ AuthenticationResult Server::loginUser(Server_ProtocolHandler *session, QString return authState; } +RegistrationResult Server::registerUserAccount(const QString &ipAddress, const Command_Register &cmd, QString &banReason, int &banSecondsRemaining) +{ + // TODO + + if (!registrationEnabled) + return RegistrationDisabled; + + QString emailAddress = QString::fromStdString(cmd.email()); + if (requireEmailForRegistration && emailAddress.isEmpty()) + return EmailRequired; + + Server_DatabaseInterface *databaseInterface = getDatabaseInterface(); + + // TODO: Move this method outside of the db interface + QString userName = QString::fromStdString(cmd.user_name()); + if (!databaseInterface->usernameIsValid(userName)) + return InvalidUsername; + + if (databaseInterface->checkUserIsBanned(ipAddress, userName, banReason, banSecondsRemaining)) + return ClientIsBanned; + + if (tooManyRegistrationAttempts(ipAddress)) + return TooManyRequests; + + QString realName = QString::fromStdString(cmd.real_name()); + ServerInfo_User_Gender gender = cmd.gender(); + QString country = QString::fromStdString(cmd.country()); + QString passwordSha512 = QString::fromStdString(cmd.password()); + bool regSucceeded = databaseInterface->registerUser(userName, realName, gender, passwordSha512, emailAddress, country, false); + + return regSucceeded ? Accepted : Failed; +} + +bool Server::tooManyRegistrationAttempts(const QString &ipAddress) +{ + // TODO: implement + return false; +} + void Server::addPersistentPlayer(const QString &userName, int roomId, int gameId, int playerId) { QWriteLocker locker(&persistentPlayersLock); diff --git a/common/server.h b/common/server.h index 896f6ef8..ca039350 100644 --- a/common/server.h +++ b/common/server.h @@ -7,6 +7,7 @@ #include #include #include +#include "pb/commands.pb.h" #include "pb/serverinfo_user.pb.h" #include "server_player_reference.h" @@ -28,6 +29,7 @@ class CommandContainer; class Command_JoinGame; enum AuthenticationResult { NotLoggedIn = 0, PasswordRight = 1, UnknownUser = 2, WouldOverwriteOldSession = 3, UserIsBanned = 4, UsernameInvalid = 5, RegistrationRequired = 6 }; +enum RegistrationResult { Accepted = 0, UserAlreadyExists = 1, EmailRequired = 2, UnauthenticatedServer = 3, TooManyRequests = 4, InvalidUsername = 5, ClientIsBanned = 6, RegistrationDisabled = 7, Failed = 8}; class Server : public QObject { @@ -44,6 +46,19 @@ public: ~Server(); void setThreaded(bool _threaded) { threaded = _threaded; } AuthenticationResult loginUser(Server_ProtocolHandler *session, QString &name, const QString &password, QString &reason, int &secondsLeft); + + /** + * Registers a user account. + * @param ipAddress The address of the connection from the user + * @param userName The username to attempt to register + * @param emailAddress The email address to associate with the new account (and to use for activation) + * @param banReason If the client is banned, the reason for the ban will be included in this string. + * @param banSecondsRemaining If the client is banned, the time left will be included in this. 0 if the ban is permanent. + * @return RegistrationResult member indicating whether it succeeded or failed. + */ + RegistrationResult registerUserAccount(const QString &ipAddress, const Command_Register &cmd, QString &banReason, int &banSecondsRemaining); + + bool tooManyRegistrationAttempts(const QString &ipAddress); const QMap &getRooms() { return rooms; } Server_AbstractUserInterface *findUser(const QString &userName) const; @@ -115,6 +130,9 @@ protected: int getUsersCount() const; int getGamesCount() const; void addRoom(Server_Room *newRoom); + + bool registrationEnabled; + bool requireEmailForRegistration; }; #endif diff --git a/common/server_database_interface.h b/common/server_database_interface.h index 518bd5e6..2568f195 100644 --- a/common/server_database_interface.h +++ b/common/server_database_interface.h @@ -13,6 +13,7 @@ public: : QObject(parent) { } virtual AuthenticationResult checkUserPassword(Server_ProtocolHandler *handler, const QString &user, const QString &password, QString &reasonStr, int &secondsLeft) = 0; + virtual bool checkUserIsBanned(const QString &ipAddress, const QString &userName, QString &banReason, int &banSecondsRemaining) { return false; } virtual bool userExists(const QString & /* user */) { return false; } virtual QMap getBuddyList(const QString & /* name */) { return QMap(); } virtual QMap getIgnoreList(const QString & /* name */) { return QMap(); } @@ -23,6 +24,7 @@ public: virtual DeckList *getDeckFromDatabase(int /* deckId */, int /* userId */) { return 0; } virtual qint64 startSession(const QString & /* userName */, const QString & /* address */) { return 0; } + virtual bool usernameIsValid(const QString &userName) { return true; }; public slots: virtual void endSession(qint64 /* sessionId */ ) { } public: @@ -35,9 +37,11 @@ public: virtual bool userSessionExists(const QString & /* userName */) { return false; } virtual bool getRequireRegistration() { return false; } + virtual bool registerUser(const QString &userName, const QString &realName, ServerInfo_User_Gender const &gender, const QString &passwordSha512, const QString &emailAddress, const QString &country, bool active = false) { return false; } enum LogMessage_TargetType { MessageTargetRoom, MessageTargetGame, MessageTargetChat, MessageTargetIslRoom }; virtual void logMessage(const int /* senderId */, const QString & /* senderName */, const QString & /* senderIp */, const QString & /* logMessage */, LogMessage_TargetType /* targetType */, const int /* targetId */, const QString & /* targetName */) { }; + bool checkUserIsBanned(Server_ProtocolHandler *session, QString &banReason, int &banSecondsRemaining); }; #endif diff --git a/common/server_protocolhandler.cpp b/common/server_protocolhandler.cpp index 46d79139..cc5dd7f9 100644 --- a/common/server_protocolhandler.cpp +++ b/common/server_protocolhandler.cpp @@ -9,6 +9,7 @@ #include "pb/commands.pb.h" #include "pb/response.pb.h" #include "pb/response_login.pb.h" +#include "pb/response_register.pb.h" #include "pb/response_list_users.pb.h" #include "pb/response_get_games_of_user.pb.h" #include "pb/response_get_user_info.pb.h" @@ -134,12 +135,17 @@ Response::ResponseCode Server_ProtocolHandler::processSessionCommandContainer(co SessionCommand debugSc(sc); debugSc.MutableExtension(Command_Login::ext)->clear_password(); logDebugMessage(QString::fromStdString(debugSc.ShortDebugString())); + } else if (num == SessionCommand::REGISTER) { + SessionCommand logSc(sc); + logSc.MutableExtension(Command_Register::ext)->clear_password(); + logDebugMessage(QString::fromStdString(logSc.ShortDebugString())); } else logDebugMessage(QString::fromStdString(sc.ShortDebugString())); } switch ((SessionCommand::SessionCommandType) num) { case SessionCommand::PING: resp = cmdPing(sc.GetExtension(Command_Ping::ext), rc); break; case SessionCommand::LOGIN: resp = cmdLogin(sc.GetExtension(Command_Login::ext), rc); break; + case SessionCommand::REGISTER: resp = cmdRegisterAccount(sc.GetExtension(Command_Register::ext), rc); break; case SessionCommand::MESSAGE: resp = cmdMessage(sc.GetExtension(Command_Message::ext), rc); break; case SessionCommand::GET_GAMES_OF_USER: resp = cmdGetGamesOfUser(sc.GetExtension(Command_GetGamesOfUser::ext), rc); break; case SessionCommand::GET_USER_INFO: resp = cmdGetUserInfo(sc.GetExtension(Command_GetUserInfo::ext), rc); break; @@ -413,6 +419,49 @@ Response::ResponseCode Server_ProtocolHandler::cmdLogin(const Command_Login &cmd return Response::RespOk; } +Response::ResponseCode Server_ProtocolHandler::cmdRegisterAccount(const Command_Register &cmd, ResponseContainer &rc) +{ + qDebug() << "Got register command: " << QString::fromStdString(cmd.user_name()); + + QString banReason; + int banSecondsRemaining; + RegistrationResult result = + server->registerUserAccount( + this->getAddress(), + cmd, + banReason, + banSecondsRemaining); + qDebug() << "Register command result:" << result; + + switch (result) { + case RegistrationDisabled: + return Response::RespRegistrationDisabled; + case Accepted: + return Response::RespRegistrationAccepted; + case UserAlreadyExists: + return Response::RespUserAlreadyExists; + case EmailRequired: + return Response::RespEmailRequiredToRegister; + case UnauthenticatedServer: + return Response::RespServerDoesNotUseAuth; + case TooManyRequests: + return Response::RespTooManyRequests; + case InvalidUsername: + return Response::RespUsernameInvalid; + case Failed: + return Response::RespRegistrationFailed; + case ClientIsBanned: + Response_Register *re = new Response_Register; + re->set_denied_reason_str(banReason.toStdString()); + if (banSecondsRemaining != 0) + re->set_denied_end_time(QDateTime::currentDateTime().addSecs(banSecondsRemaining).toTime_t()); + rc.setResponseExtension(re); + return Response::RespUserIsBanned; + } + + return Response::RespInvalidCommand; +} + Response::ResponseCode Server_ProtocolHandler::cmdMessage(const Command_Message &cmd, ResponseContainer &rc) { if (authState == NotLoggedIn) diff --git a/common/server_protocolhandler.h b/common/server_protocolhandler.h index 5389b724..7ea34cf8 100644 --- a/common/server_protocolhandler.h +++ b/common/server_protocolhandler.h @@ -28,6 +28,7 @@ class AdminCommand; class Command_Ping; class Command_Login; +class Command_Register; class Command_Message; class Command_ListUsers; class Command_GetGamesOfUser; @@ -59,6 +60,7 @@ private: Response::ResponseCode cmdPing(const Command_Ping &cmd, ResponseContainer &rc); Response::ResponseCode cmdLogin(const Command_Login &cmd, ResponseContainer &rc); + Response::ResponseCode cmdRegisterAccount(const Command_Register &cmd, ResponseContainer &rc); Response::ResponseCode cmdMessage(const Command_Message &cmd, ResponseContainer &rc); Response::ResponseCode cmdGetGamesOfUser(const Command_GetGamesOfUser &cmd, ResponseContainer &rc); Response::ResponseCode cmdGetUserInfo(const Command_GetUserInfo &cmd, ResponseContainer &rc); diff --git a/servatrice/scripts/.gitignore b/servatrice/scripts/.gitignore new file mode 100644 index 00000000..eb47a082 --- /dev/null +++ b/servatrice/scripts/.gitignore @@ -0,0 +1 @@ +pypb/ diff --git a/servatrice/scripts/mk_pypb.sh b/servatrice/scripts/mk_pypb.sh new file mode 100755 index 00000000..e11d237d --- /dev/null +++ b/servatrice/scripts/mk_pypb.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +SRC_DIR=../../common/pb/ +DST_DIR=./pypb + +rm -rf "$DST_DIR" +mkdir -p "$DST_DIR" +protoc -I=$SRC_DIR --python_out=$DST_DIR $SRC_DIR/*.proto +touch "$DST_DIR/__init__.py" + diff --git a/servatrice/scripts/register.py b/servatrice/scripts/register.py new file mode 100755 index 00000000..427c9e72 --- /dev/null +++ b/servatrice/scripts/register.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python + +import socket, sys, struct, time + +from pypb.server_message_pb2 import ServerMessage +from pypb.session_commands_pb2 import Command_Register as Reg +from pypb.commands_pb2 import CommandContainer as Cmd +from pypb.event_server_identification_pb2 import Event_ServerIdentification as ServerId +from pypb.response_pb2 import Response + +HOST = "localhost" +PORT = 4747 + +CMD_ID = 1 + +def build_reg(): + global CMD_ID + cmd = Cmd() + sc = cmd.session_command.add() + + reg = sc.Extensions[Reg.ext] + reg.user_name = "testUser" + reg.email = "test@example.com" + reg.password = "password" + + cmd.cmd_id = CMD_ID + CMD_ID += 1 + return cmd + +def send(msg): + packed = struct.pack('>I', len(msg)) + sock.sendall(packed) + sock.sendall(msg) + +def print_resp(resp): + print "<<<" + print repr(resp) + m = ServerMessage() + m.ParseFromString(bytes(resp)) + print m + +def recv(sock): + print "< header" + header = sock.recv(4) + msg_size = struct.unpack('>I', header)[0] + print "< ", msg_size + raw_msg = sock.recv(msg_size) + print_resp(raw_msg) + +if __name__ == "__main__": + address = (HOST, PORT) + sock = socket.socket() + + print "Connecting to server ", address + sock.connect(address) + + # hack for old xml clients - server expects this and discards first message + print ">>> xml hack" + xmlClientHack = Cmd().SerializeToString() + send(xmlClientHack) + print sock.recv(60) + + recv(sock) + + print ">>> register" + r = build_reg() + print r + msg = r.SerializeToString() + send(msg) + recv(sock) + + print "Done" + diff --git a/servatrice/servatrice.ini.example b/servatrice/servatrice.ini.example index bd1cad55..f75a8958 100644 --- a/servatrice/servatrice.ini.example +++ b/servatrice/servatrice.ini.example @@ -54,6 +54,13 @@ password=123456 ; Accept only registered users? default is 0 (accept unregistered users) regonly=0 +[registration] + +; Servatrice can process registration requests to add new users on the fly. +; Enable this feature? Default false. +;enabled=false +; Require users to provide an email address in order to register. Default true. +;requireemail=true [database] diff --git a/servatrice/src/servatrice.cpp b/servatrice/src/servatrice.cpp index 66a9c977..e4430912 100644 --- a/servatrice/src/servatrice.cpp +++ b/servatrice/src/servatrice.cpp @@ -160,6 +160,13 @@ bool Servatrice::initServer() authenticationMethod = AuthenticationNone; } + registrationEnabled = settingsCache->value("registration/enabled", false).toBool(); + requireEmailForRegistration = settingsCache->value("registration/requireemail", true).toBool(); + + qDebug() << "Registration enabled: " << registrationEnabled; + if (registrationEnabled) + qDebug() << "Require email address to register: " << requireEmailForRegistration; + QString dbTypeStr = settingsCache->value("database/type").toString(); if (dbTypeStr == "mysql") databaseType = DatabaseMySql; @@ -172,12 +179,17 @@ bool Servatrice::initServer() if (databaseType != DatabaseNone) { settingsCache->beginGroup("database"); dbPrefix = settingsCache->value("prefix").toString(); - servatriceDatabaseInterface->initDatabase("QMYSQL", - settingsCache->value("hostname").toString(), - settingsCache->value("database").toString(), - settingsCache->value("user").toString(), - settingsCache->value("password").toString()); + bool dbOpened = + servatriceDatabaseInterface->initDatabase("QMYSQL", + settingsCache->value("hostname").toString(), + settingsCache->value("database").toString(), + settingsCache->value("user").toString(), + settingsCache->value("password").toString()); settingsCache->endGroup(); + if (!dbOpened) { + qDebug() << "Failed to open database"; + return false; + } updateServerList(); @@ -342,7 +354,7 @@ bool Servatrice::initServer() if (gameServer->listen(QHostAddress::Any, gamePort)) qDebug() << "Server listening."; else { - qDebug() << "gameServer->listen(): Error."; + qDebug() << "gameServer->listen(): Error:" << gameServer->errorString(); return false; } return true; diff --git a/servatrice/src/servatrice_database_interface.cpp b/servatrice/src/servatrice_database_interface.cpp index 77f31156..aabf2a2a 100644 --- a/servatrice/src/servatrice_database_interface.cpp +++ b/servatrice/src/servatrice_database_interface.cpp @@ -9,6 +9,7 @@ #include #include #include +#include Servatrice_DatabaseInterface::Servatrice_DatabaseInterface(int _instanceId, Servatrice *_server) : instanceId(_instanceId), @@ -34,7 +35,9 @@ void Servatrice_DatabaseInterface::initDatabase(const QSqlDatabase &_sqlDatabase } } -void Servatrice_DatabaseInterface::initDatabase(const QString &type, const QString &hostName, const QString &databaseName, const QString &userName, const QString &password) +bool Servatrice_DatabaseInterface::initDatabase(const QString &type, const QString &hostName, + const QString &databaseName, const QString &userName, + const QString &password) { sqlDatabase = QSqlDatabase::addDatabase(type, "main"); sqlDatabase.setHostName(hostName); @@ -42,7 +45,7 @@ void Servatrice_DatabaseInterface::initDatabase(const QString &type, const QStri sqlDatabase.setUserName(userName); sqlDatabase.setPassword(password); - openDatabase(); + return openDatabase(); } bool Servatrice_DatabaseInterface::openDatabase() @@ -102,11 +105,52 @@ bool Servatrice_DatabaseInterface::usernameIsValid(const QString &user) return re.exactMatch(user); } +// TODO move this to Server bool Servatrice_DatabaseInterface::getRequireRegistration() { return settingsCache->value("authentication/regonly", 0).toBool(); } +bool Servatrice_DatabaseInterface::registerUser(const QString &userName, const QString &realName, ServerInfo_User_Gender const &gender, const QString &passwordSha512, const QString &emailAddress, const QString &country, bool active) +{ + if (!checkSql()) + return false; + + QSqlQuery *query = prepareQuery("insert into {prefix}_users " + "(name, realname, gender, password_sha512, email, country, registrationDate, active) " + "values " + "(:userName, :realName, :gender, :password_sha512, :email, :country, UTC_TIMESTAMP(), :active)"); + query->bindValue(":userName", userName); + query->bindValue(":realName", realName); + query->bindValue(":gender", getGenderChar(gender)); + query->bindValue(":password_sha512", passwordSha512); + query->bindValue(":email", emailAddress); + query->bindValue(":country", country); + query->bindValue(":active", active ? 1 : 0); + + if (!execSqlQuery(query)) { + qDebug() << "Failed to insert user: " << query->lastError() << " sql: " << query->lastQuery(); + // TODO handle duplicate insert error + return false; + } + + return true; +} + +QChar Servatrice_DatabaseInterface::getGenderChar(ServerInfo_User_Gender const &gender) +{ + switch (gender) { + case ServerInfo_User_Gender_GenderUnknown: + return QChar('u'); + case ServerInfo_User_Gender_Male: + return QChar('m'); + case ServerInfo_User_Gender_Female: + return QChar('f'); + default: + return QChar('u'); + } +} + AuthenticationResult Servatrice_DatabaseInterface::checkUserPassword(Server_ProtocolHandler *handler, const QString &user, const QString &password, QString &reasonStr, int &banSecondsLeft) { switch (server->getAuthenticationMethod()) { @@ -125,43 +169,8 @@ AuthenticationResult Servatrice_DatabaseInterface::checkUserPassword(Server_Prot if (!usernameIsValid(user)) return UsernameInvalid; - QSqlQuery *ipBanQuery = prepareQuery("select time_to_sec(timediff(now(), date_add(b.time_from, interval b.minutes minute))), b.minutes <=> 0, b.visible_reason from {prefix}_bans b where b.time_from = (select max(c.time_from) from {prefix}_bans c where c.ip_address = :address) and b.ip_address = :address2"); - ipBanQuery->bindValue(":address", static_cast(handler)->getPeerAddress().toString()); - ipBanQuery->bindValue(":address2", static_cast(handler)->getPeerAddress().toString()); - if (!execSqlQuery(ipBanQuery)) { - qDebug("Login denied: SQL error"); - return NotLoggedIn; - } - - if (ipBanQuery->next()) { - const int secondsLeft = -ipBanQuery->value(0).toInt(); - const bool permanentBan = ipBanQuery->value(1).toInt(); - if ((secondsLeft > 0) || permanentBan) { - reasonStr = ipBanQuery->value(2).toString(); - banSecondsLeft = permanentBan ? 0 : secondsLeft; - qDebug("Login denied: banned by address"); - return UserIsBanned; - } - } - - QSqlQuery *nameBanQuery = prepareQuery("select time_to_sec(timediff(now(), date_add(b.time_from, interval b.minutes minute))), b.minutes <=> 0, b.visible_reason from {prefix}_bans b where b.time_from = (select max(c.time_from) from {prefix}_bans c where c.user_name = :name2) and b.user_name = :name1"); - nameBanQuery->bindValue(":name1", user); - nameBanQuery->bindValue(":name2", user); - if (!execSqlQuery(nameBanQuery)) { - qDebug("Login denied: SQL error"); - return NotLoggedIn; - } - - if (nameBanQuery->next()) { - const int secondsLeft = -nameBanQuery->value(0).toInt(); - const bool permanentBan = nameBanQuery->value(1).toInt(); - if ((secondsLeft > 0) || permanentBan) { - reasonStr = nameBanQuery->value(2).toString(); - banSecondsLeft = permanentBan ? 0 : secondsLeft; - qDebug("Login denied: banned by name"); - return UserIsBanned; - } - } + if (checkUserIsBanned(handler->getAddress(), user, reasonStr, banSecondsLeft)) + return UserIsBanned; QSqlQuery *passwordQuery = prepareQuery("select password_sha512 from {prefix}_users where name = :name and active = 1"); passwordQuery->bindValue(":name", user); @@ -188,6 +197,79 @@ AuthenticationResult Servatrice_DatabaseInterface::checkUserPassword(Server_Prot return UnknownUser; } +bool Servatrice_DatabaseInterface::checkUserIsBanned(const QString &ipAddress, const QString &userName, QString &banReason, int &banSecondsRemaining) +{ + if (server->getAuthenticationMethod() != Servatrice::AuthenticationSql) + return false; + + if (!checkSql()) { + qDebug("Failed to check if user is banned. Database invalid."); + return false; + } + + return + checkUserIsIpBanned(ipAddress, banReason, banSecondsRemaining) + || checkUserIsNameBanned(userName, banReason, banSecondsRemaining); + +} + +bool Servatrice_DatabaseInterface::checkUserIsNameBanned(const QString &userName, QString &banReason, int &banSecondsRemaining) +{ + QSqlQuery *nameBanQuery = prepareQuery("select time_to_sec(timediff(now(), date_add(b.time_from, interval b.minutes minute))), b.minutes <=> 0, b.visible_reason from {prefix}_bans b where b.time_from = (select max(c.time_from) from {prefix}_bans c where c.user_name = :name2) and b.user_name = :name1"); + nameBanQuery->bindValue(":name1", userName); + nameBanQuery->bindValue(":name2", userName); + if (!execSqlQuery(nameBanQuery)) { + qDebug() << "Name ban check failed: SQL error" << nameBanQuery->lastError(); + return false; + } + + if (nameBanQuery->next()) { + const int secondsLeft = -nameBanQuery->value(0).toInt(); + const bool permanentBan = nameBanQuery->value(1).toInt(); + if ((secondsLeft > 0) || permanentBan) { + banReason = nameBanQuery->value(2).toString(); + banSecondsRemaining = permanentBan ? 0 : secondsLeft; + qDebug() << "Username" << userName << "is banned by name"; + return true; + } + } + return false; +} + +bool Servatrice_DatabaseInterface::checkUserIsIpBanned(const QString &ipAddress, QString &banReason, int &banSecondsRemaining) +{ + QSqlQuery *ipBanQuery = prepareQuery( + "select" + " time_to_sec(timediff(now(), date_add(b.time_from, interval b.minutes minute)))," + " b.minutes <=> 0," + " b.visible_reason" + " from {prefix}_bans b" + " where" + " b.time_from = (select max(c.time_from)" + " from {prefix}_bans c" + " where c.ip_address = :address)" + " and b.ip_address = :address2"); + + ipBanQuery->bindValue(":address", ipAddress); + ipBanQuery->bindValue(":address2", ipAddress); + if (!execSqlQuery(ipBanQuery)) { + qDebug() << "IP ban check failed: SQL error." << ipBanQuery->lastError(); + return false; + } + + if (ipBanQuery->next()) { + const int secondsLeft = -ipBanQuery->value(0).toInt(); + const bool permanentBan = ipBanQuery->value(1).toInt(); + if ((secondsLeft > 0) || permanentBan) { + banReason = ipBanQuery->value(2).toString(); + banSecondsRemaining = permanentBan ? 0 : secondsLeft; + qDebug() << "User is banned by address" << ipAddress; + return true; + } + } + return false; +} + bool Servatrice_DatabaseInterface::userExists(const QString &user) { if (server->getAuthenticationMethod() == Servatrice::AuthenticationSql) { @@ -579,4 +661,4 @@ void Servatrice_DatabaseInterface::logMessage(const int senderId, const QString query->bindValue(":target_id", (targetType == MessageTargetChat && targetId < 1) ? QVariant() : targetId); query->bindValue(":target_name", targetName); execSqlQuery(query); -} +} \ No newline at end of file diff --git a/servatrice/src/servatrice_database_interface.h b/servatrice/src/servatrice_database_interface.h index 8850d7be..d352ba7a 100644 --- a/servatrice/src/servatrice_database_interface.h +++ b/servatrice/src/servatrice_database_interface.h @@ -4,6 +4,7 @@ #include #include #include +#include #include "server.h" #include "server_database_interface.h" @@ -18,15 +19,22 @@ private: QHash preparedStatements; Servatrice *server; ServerInfo_User evalUserQueryResult(const QSqlQuery *query, bool complete, bool withId = false); - bool usernameIsValid(const QString &user); + /** Must be called after checkSql and server is known to be in auth mode. */ + bool checkUserIsIpBanned(const QString &ipAddress, QString &banReason, int &banSecondsRemaining); + /** Must be called after checkSql and server is known to be in auth mode. */ + bool checkUserIsNameBanned(QString const &userName, QString &banReason, int &banSecondsRemaining); + QChar getGenderChar(ServerInfo_User_Gender const &gender); protected: + bool usernameIsValid(const QString &user); AuthenticationResult checkUserPassword(Server_ProtocolHandler *handler, const QString &user, const QString &password, QString &reasonStr, int &secondsLeft); + bool checkUserIsBanned(const QString &ipAddress, const QString &userName, QString &banReason, int &banSecondsRemaining); public slots: void initDatabase(const QSqlDatabase &_sqlDatabase); public: Servatrice_DatabaseInterface(int _instanceId, Servatrice *_server); ~Servatrice_DatabaseInterface(); - void initDatabase(const QString &type, const QString &hostName, const QString &databaseName, const QString &userName, const QString &password); + bool initDatabase(const QString &type, const QString &hostName, const QString &databaseName, + const QString &userName, const QString &password); bool openDatabase(); bool checkSql(); QSqlQuery * prepareQuery(const QString &queryText); @@ -55,6 +63,7 @@ public: bool userSessionExists(const QString &userName); bool getRequireRegistration(); + bool registerUser(const QString &userName, const QString &realName, ServerInfo_User_Gender const &gender, const QString &passwordSha512, const QString &emailAddress, const QString &country, bool active = false); void logMessage(const int senderId, const QString &senderName, const QString &senderIp, const QString &logMessage, LogMessage_TargetType targetType, const int targetId, const QString &targetName); };