提交 314db4db 编写于 作者: L Longda Feng

code format the observer

上级 ab89b1d5
...@@ -19,14 +19,12 @@ See the Mulan PSL v2 for more details. */ ...@@ -19,14 +19,12 @@ See the Mulan PSL v2 for more details. */
namespace common { namespace common {
template <typename Key, typename Value, template <typename Key, typename Value, typename Hash = std::hash<Key>, typename Pred = std::equal_to<Key>>
typename Hash = std::hash<Key>,
typename Pred = std::equal_to<Key>>
class LruCache { class LruCache {
class ListNode { class ListNode {
public: public:
Key key_; Key key_;
Value value_; Value value_;
ListNode *prev_ = nullptr; ListNode *prev_ = nullptr;
...@@ -39,9 +37,10 @@ class LruCache { ...@@ -39,9 +37,10 @@ class LruCache {
class PListNodeHasher { class PListNodeHasher {
public: public:
size_t operator() (ListNode *node) const { size_t operator()(ListNode *node) const
{
if (node == nullptr) { if (node == nullptr) {
return 0; return 0;
} }
return hasher_(node->key_); return hasher_(node->key_);
} }
...@@ -52,26 +51,28 @@ class LruCache { ...@@ -52,26 +51,28 @@ class LruCache {
class PListNodePredicator { class PListNodePredicator {
public: public:
bool operator() (ListNode * const node1, ListNode * const node2) const { bool operator()(ListNode *const node1, ListNode *const node2) const
{
if (node1 == node2) { if (node1 == node2) {
return true; return true;
} }
if (node1 == nullptr || node2 == nullptr) { if (node1 == nullptr || node2 == nullptr) {
return false; return false;
} }
return pred_(node1->key_, node2->key_); return pred_(node1->key_, node2->key_);
} }
private: private:
Pred pred_; Pred pred_;
}; };
public: public:
LruCache(size_t reserve = 0) LruCache(size_t reserve = 0)
{ {
if (reserve > 0) { if (reserve > 0) {
searcher_.reserve(reserve); searcher_.reserve(reserve);
} }
} }
...@@ -88,7 +89,7 @@ public: ...@@ -88,7 +89,7 @@ public:
searcher_.clear(); searcher_.clear();
lru_front_ = nullptr; lru_front_ = nullptr;
lru_tail_ = nullptr; lru_tail_ = nullptr;
} }
size_t count() const size_t count() const
...@@ -112,10 +113,10 @@ public: ...@@ -112,10 +113,10 @@ public:
{ {
auto iter = searcher_.find((ListNode *)&key); auto iter = searcher_.find((ListNode *)&key);
if (iter != searcher_.end()) { if (iter != searcher_.end()) {
ListNode * ln = *iter; ListNode *ln = *iter;
ln->value_ = value; ln->value_ = value;
lru_touch(ln); lru_touch(ln);
return ; return;
} }
ListNode *ln = new ListNode(key, value); ListNode *ln = new ListNode(key, value);
...@@ -136,7 +137,7 @@ public: ...@@ -136,7 +137,7 @@ public:
value = nullptr; value = nullptr;
} }
void foreach(std::function<bool(const Key &, const Value &)> func) void foreach (std::function<bool(const Key &, const Value &)> func)
{ {
for (ListNode *node = lru_front_; node != nullptr; node = node->next_) { for (ListNode *node = lru_front_; node != nullptr; node = node->next_) {
bool ret = func(node->key_, node->value_); bool ret = func(node->key_, node->value_);
...@@ -165,7 +166,7 @@ private: ...@@ -165,7 +166,7 @@ private:
} }
node->prev_->next_ = node->next_; node->prev_->next_ = node->next_;
if (node->next_ != nullptr) { if (node->next_ != nullptr) {
node->next_->prev_ = node->prev_; node->next_->prev_ = node->prev_;
} else { } else {
...@@ -220,9 +221,9 @@ private: ...@@ -220,9 +221,9 @@ private:
private: private:
using SearchType = std::unordered_set<ListNode *, PListNodeHasher, PListNodePredicator>; using SearchType = std::unordered_set<ListNode *, PListNodeHasher, PListNodePredicator>;
SearchType searcher_; SearchType searcher_;
ListNode * lru_front_ = nullptr; ListNode *lru_front_ = nullptr;
ListNode * lru_tail_ = nullptr; ListNode *lru_tail_ = nullptr;
}; };
} // namespace common } // namespace common
...@@ -22,16 +22,18 @@ class Stmt; ...@@ -22,16 +22,18 @@ class Stmt;
class OptimizeEvent : public common::StageEvent { class OptimizeEvent : public common::StageEvent {
public: public:
OptimizeEvent(SQLStageEvent *sql_event, common::StageEvent *parent_event) OptimizeEvent(SQLStageEvent *sql_event, common::StageEvent *parent_event)
: sql_event_(sql_event), parent_event_(parent_event) : sql_event_(sql_event), parent_event_(parent_event)
{} {}
virtual ~OptimizeEvent() noexcept = default; virtual ~OptimizeEvent() noexcept = default;
SQLStageEvent *sql_event() const { SQLStageEvent *sql_event() const
{
return sql_event_; return sql_event_;
} }
common::StageEvent *parent_event() const { common::StageEvent *parent_event() const
{
return parent_event_; return parent_event_;
} }
...@@ -39,4 +41,3 @@ private: ...@@ -39,4 +41,3 @@ private:
SQLStageEvent *sql_event_ = nullptr; SQLStageEvent *sql_event_ = nullptr;
common::StageEvent *parent_event_ = nullptr; common::StageEvent *parent_event_ = nullptr;
}; };
...@@ -16,8 +16,7 @@ See the Mulan PSL v2 for more details. */ ...@@ -16,8 +16,7 @@ See the Mulan PSL v2 for more details. */
#include "net/communicator.h" #include "net/communicator.h"
SessionEvent::SessionEvent(Communicator *comm) : communicator_(comm) SessionEvent::SessionEvent(Communicator *comm) : communicator_(comm)
{ {}
}
SessionEvent::~SessionEvent() SessionEvent::~SessionEvent()
{ {
......
...@@ -32,17 +32,29 @@ public: ...@@ -32,17 +32,29 @@ public:
Communicator *get_communicator() const; Communicator *get_communicator() const;
Session *session() const; Session *session() const;
void set_query(const std::string &query) { query_ = query; } void set_query(const std::string &query)
void set_sql_result(SqlResult *result) { sql_result_ = result; } {
const std::string &query() const { return query_; } query_ = query;
SqlResult *sql_result() const { return sql_result_; } }
void set_sql_result(SqlResult *result)
{
sql_result_ = result;
}
const std::string &query() const
{
return query_;
}
SqlResult *sql_result() const
{
return sql_result_;
}
const char *get_response() const; const char *get_response() const;
void set_response(const char *response); void set_response(const char *response);
void set_response(const char *response, int len); void set_response(const char *response, int len);
void set_response(std::string &&response); void set_response(std::string &&response);
int get_response_len() const; int get_response_len() const;
const char *get_request_buf(); // TODO remove me const char *get_request_buf(); // TODO remove me
private: private:
Communicator *communicator_ = nullptr; Communicator *communicator_ = nullptr;
......
...@@ -23,8 +23,7 @@ class SessionEvent; ...@@ -23,8 +23,7 @@ class SessionEvent;
class Stmt; class Stmt;
class Command; class Command;
class SQLStageEvent : public common::StageEvent class SQLStageEvent : public common::StageEvent {
{
public: public:
SQLStageEvent(SessionEvent *event, const std::string &sql); SQLStageEvent(SessionEvent *event, const std::string &sql);
virtual ~SQLStageEvent() noexcept; virtual ~SQLStageEvent() noexcept;
...@@ -34,16 +33,43 @@ public: ...@@ -34,16 +33,43 @@ public:
return session_event_; return session_event_;
} }
const std::string &sql() const { return sql_; } const std::string &sql() const
const std::unique_ptr<Command> &command() const { return command_; } {
Stmt *stmt() const { return stmt_; } return sql_;
std::unique_ptr<PhysicalOperator> &physical_operator() { return operator_; } }
const std::unique_ptr<PhysicalOperator> &physical_operator() const { return operator_; } const std::unique_ptr<Command> &command() const
{
return command_;
}
Stmt *stmt() const
{
return stmt_;
}
std::unique_ptr<PhysicalOperator> &physical_operator()
{
return operator_;
}
const std::unique_ptr<PhysicalOperator> &physical_operator() const
{
return operator_;
}
void set_sql(const char *sql) { sql_ = sql; } void set_sql(const char *sql)
void set_command(std::unique_ptr<Command> cmd) { command_ = std::move(cmd); } {
void set_stmt(Stmt *stmt) { stmt_ = stmt; } sql_ = sql;
void set_operator(std::unique_ptr<PhysicalOperator> oper) { operator_ = std::move(oper); } }
void set_command(std::unique_ptr<Command> cmd)
{
command_ = std::move(cmd);
}
void set_stmt(Stmt *stmt)
{
stmt_ = stmt;
}
void set_operator(std::unique_ptr<PhysicalOperator> oper)
{
operator_ = std::move(oper);
}
private: private:
SessionEvent *session_event_ = nullptr; SessionEvent *session_event_ = nullptr;
...@@ -52,4 +78,3 @@ private: ...@@ -52,4 +78,3 @@ private:
Stmt *stmt_ = nullptr; Stmt *stmt_ = nullptr;
std::unique_ptr<PhysicalOperator> operator_; std::unique_ptr<PhysicalOperator> operator_;
}; };
...@@ -20,7 +20,6 @@ See the Mulan PSL v2 for more details. */ ...@@ -20,7 +20,6 @@ See the Mulan PSL v2 for more details. */
#include "common/io/io.h" #include "common/io/io.h"
#include "session/session.h" #include "session/session.h"
RC Communicator::init(int fd, Session *session, const std::string &addr) RC Communicator::init(int fd, Session *session, const std::string &addr)
{ {
fd_ = fd; fd_ = fd;
...@@ -88,7 +87,6 @@ RC PlainCommunicator::read_event(SessionEvent *&event) ...@@ -88,7 +87,6 @@ RC PlainCommunicator::read_event(SessionEvent *&event)
data_len += read_len; data_len += read_len;
} }
if (data_len > max_packet_size) { if (data_len > max_packet_size) {
LOG_WARN("The length of sql exceeds the limitation %d", max_packet_size); LOG_WARN("The length of sql exceeds the limitation %d", max_packet_size);
return RC::IOERR; return RC::IOERR;
...@@ -107,7 +105,6 @@ RC PlainCommunicator::read_event(SessionEvent *&event) ...@@ -107,7 +105,6 @@ RC PlainCommunicator::read_event(SessionEvent *&event)
return rc; return rc;
} }
RC PlainCommunicator::write_state(SessionEvent *event, bool &need_disconnect) RC PlainCommunicator::write_state(SessionEvent *event, bool &need_disconnect)
{ {
SqlResult *sql_result = event->sql_result(); SqlResult *sql_result = event->sql_result();
...@@ -137,12 +134,12 @@ RC PlainCommunicator::write_state(SessionEvent *event, bool &need_disconnect) ...@@ -137,12 +134,12 @@ RC PlainCommunicator::write_state(SessionEvent *event, bool &need_disconnect)
RC PlainCommunicator::write_result(SessionEvent *event, bool &need_disconnect) RC PlainCommunicator::write_result(SessionEvent *event, bool &need_disconnect)
{ {
need_disconnect = true; need_disconnect = true;
const char message_terminate = '\0'; const char message_terminate = '\0';
SqlResult *sql_result = event->sql_result(); SqlResult *sql_result = event->sql_result();
if (nullptr == sql_result) { if (nullptr == sql_result) {
const char *response = event->get_response(); const char *response = event->get_response();
int len = event->get_response_len(); int len = event->get_response_len();
...@@ -249,7 +246,7 @@ RC PlainCommunicator::write_result(SessionEvent *event, bool &need_disconnect) ...@@ -249,7 +246,7 @@ RC PlainCommunicator::write_result(SessionEvent *event, bool &need_disconnect)
if (rc == RC::RECORD_EOF) { if (rc == RC::RECORD_EOF) {
rc = RC::SUCCESS; rc = RC::SUCCESS;
} }
if (cell_num == 0) { if (cell_num == 0) {
// 除了select之外,其它的消息通常不会通过operator来返回结果,表头和行数据都是空的 // 除了select之外,其它的消息通常不会通过operator来返回结果,表头和行数据都是空的
// 这里针对这种情况做特殊处理,当表头和行数据都是空的时候,就返回处理的结果 // 这里针对这种情况做特殊处理,当表头和行数据都是空的时候,就返回处理的结果
...@@ -271,7 +268,6 @@ RC PlainCommunicator::write_result(SessionEvent *event, bool &need_disconnect) ...@@ -271,7 +268,6 @@ RC PlainCommunicator::write_result(SessionEvent *event, bool &need_disconnect)
///////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////
Communicator *CommunicatorFactory::create(CommunicateProtocol protocol) Communicator *CommunicatorFactory::create(CommunicateProtocol protocol)
{ {
switch (protocol) { switch (protocol) {
......
...@@ -28,10 +28,10 @@ class Session; ...@@ -28,10 +28,10 @@ class Session;
* 在listener接收到一个新的连接(参考 server.cpp::accept), 就创建一个Communicator对象。 * 在listener接收到一个新的连接(参考 server.cpp::accept), 就创建一个Communicator对象。
* 并调用init进行初始化。 * 并调用init进行初始化。
* 在server中监听到某个连接有新的消息,就通过Communicator::read_event接收消息。 * 在server中监听到某个连接有新的消息,就通过Communicator::read_event接收消息。
*/ */
class Communicator { class Communicator {
public: public:
virtual ~Communicator(); virtual ~Communicator();
/** /**
...@@ -57,18 +57,27 @@ public: ...@@ -57,18 +57,27 @@ public:
/** /**
* 关联的会话信息 * 关联的会话信息
*/ */
Session *session() const { return session_; } Session *session() const
{
return session_;
}
/** /**
* libevent使用的数据,参考server.cpp * libevent使用的数据,参考server.cpp
*/ */
struct event &read_event() { return read_event_; } struct event &read_event()
{
return read_event_;
}
/** /**
* 对端地址 * 对端地址
* 如果是unix socket,可能没有意义 * 如果是unix socket,可能没有意义
*/ */
const char *addr() const { return addr_.c_str(); } const char *addr() const
{
return addr_.c_str();
}
protected: protected:
Session *session_ = nullptr; Session *session_ = nullptr;
...@@ -82,26 +91,23 @@ protected: ...@@ -82,26 +91,23 @@ protected:
* 使用简单的文本通讯协议,每个消息使用'\0'结尾 * 使用简单的文本通讯协议,每个消息使用'\0'结尾
*/ */
class PlainCommunicator : public Communicator { class PlainCommunicator : public Communicator {
public: public:
RC read_event(SessionEvent *&event) override; RC read_event(SessionEvent *&event) override;
RC write_result(SessionEvent *event, bool &need_disconnect) override; RC write_result(SessionEvent *event, bool &need_disconnect) override;
private: private:
RC write_state(SessionEvent *event, bool &need_disconnect); RC write_state(SessionEvent *event, bool &need_disconnect);
}; };
/** /**
* 当前支持的通讯协议 * 当前支持的通讯协议
*/ */
enum class CommunicateProtocol enum class CommunicateProtocol {
{
PLAIN, //! 以'\0'结尾的协议 PLAIN, //! 以'\0'结尾的协议
MYSQL, //! mysql通讯协议。具体实现参考 MysqlCommunicator MYSQL, //! mysql通讯协议。具体实现参考 MysqlCommunicator
}; };
class CommunicatorFactory class CommunicatorFactory {
{ public:
public:
Communicator *create(CommunicateProtocol protocol); Communicator *create(CommunicateProtocol protocol);
}; };
...@@ -23,31 +23,30 @@ See the Mulan PSL v2 for more details. */ ...@@ -23,31 +23,30 @@ See the Mulan PSL v2 for more details. */
// https://dev.mysql.com/doc/dev/mysql-server/latest/group__group__cs__capabilities__flags.html // https://dev.mysql.com/doc/dev/mysql-server/latest/group__group__cs__capabilities__flags.html
// the flags below are negotiate by handshake packet // the flags below are negotiate by handshake packet
const uint32_t CLIENT_PROTOCOL_41 = 512; const uint32_t CLIENT_PROTOCOL_41 = 512;
//const uint32_t CLIENT_INTERACTIVE = 1024; // This is an interactive client // const uint32_t CLIENT_INTERACTIVE = 1024; // This is an interactive client
const uint32_t CLIENT_TRANSACTIONS = 8192; // Client knows about transactions. const uint32_t CLIENT_TRANSACTIONS = 8192; // Client knows about transactions.
const uint32_t CLIENT_SESSION_TRACK = (1UL << 23); // Capable of handling server state change information const uint32_t CLIENT_SESSION_TRACK = (1UL << 23); // Capable of handling server state change information
const uint32_t CLIENT_DEPRECATE_EOF = (1UL << 24); // Client no longer needs EOF_Packet and will use OK_Packet instead const uint32_t CLIENT_DEPRECATE_EOF = (1UL << 24); // Client no longer needs EOF_Packet and will use OK_Packet instead
const uint32_t CLIENT_OPTIONAL_RESULTSET_METADATA = (1UL << 25); // The client can handle optional metadata information in the resultset. const uint32_t CLIENT_OPTIONAL_RESULTSET_METADATA =
(1UL << 25); // The client can handle optional metadata information in the resultset.
// Support optional extension for query parameters into the COM_QUERY and COM_STMT_EXECUTE packets. // Support optional extension for query parameters into the COM_QUERY and COM_STMT_EXECUTE packets.
//const uint32_t CLIENT_QUERY_ATTRIBUTES = (1UL << 27); // const uint32_t CLIENT_QUERY_ATTRIBUTES = (1UL << 27);
// https://dev.mysql.com/doc/dev/mysql-server/latest/group__group__cs__column__definition__flags.html // https://dev.mysql.com/doc/dev/mysql-server/latest/group__group__cs__column__definition__flags.html
// Column Definition Flags // Column Definition Flags
//const uint32_t NOT_NULL_FLAG = 1; // const uint32_t NOT_NULL_FLAG = 1;
//const uint32_t PRI_KEY_FLAG = 2; // const uint32_t PRI_KEY_FLAG = 2;
//const uint32_t UNIQUE_KEY_FLAG = 4; // const uint32_t UNIQUE_KEY_FLAG = 4;
//const uint32_t MULTIPLE_KEY_FLAG = 8; // const uint32_t MULTIPLE_KEY_FLAG = 8;
//const uint32_t NUM_FLAG = 32768; // Field is num (for clients) // const uint32_t NUM_FLAG = 32768; // Field is num (for clients)
//const uint32_t PART_KEY_FLAG = 16384; // Intern; Part of some key. // const uint32_t PART_KEY_FLAG = 16384; // Intern; Part of some key.
enum ResultSetMetaData enum ResultSetMetaData {
{
RESULTSET_METADATA_NONE = 0, RESULTSET_METADATA_NONE = 0,
RESULTSET_METADATA_FULL = 1, RESULTSET_METADATA_FULL = 1,
}; };
/** /**
Column types for MySQL Column types for MySQL
*/ */
...@@ -185,13 +184,12 @@ int store_lenenc_string(char *buf, const char *s) ...@@ -185,13 +184,12 @@ int store_lenenc_string(char *buf, const char *s)
* https://mariadb.com/kb/en/0-packet/ * https://mariadb.com/kb/en/0-packet/
*/ */
struct PacketHeader { struct PacketHeader {
int32_t payload_length:24; //! 当前packet的除掉头的长度 int32_t payload_length : 24; //! 当前packet的除掉头的长度
int8_t sequence_id = 0; //! 当前packet在当前处理过程中是第几个包 int8_t sequence_id = 0; //! 当前packet在当前处理过程中是第几个包
}; };
class BasePacket class BasePacket {
{ public:
public:
PacketHeader packet_header; PacketHeader packet_header;
BasePacket(int8_t sequence = 0) BasePacket(int8_t sequence = 0)
...@@ -209,19 +207,19 @@ public: ...@@ -209,19 +207,19 @@ public:
* 这个包会交互capability与用户名密码 * 这个包会交互capability与用户名密码
* https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html * https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html
*/ */
struct HandshakeV10 : public BasePacket struct HandshakeV10 : public BasePacket {
{ int8_t protocol = 10;
int8_t protocol = 10; char server_version[7] = "5.7.25";
char server_version[7] = "5.7.25"; int32_t thread_id = 21501807; // conn id
int32_t thread_id = 21501807; // conn id char auth_plugin_data_part_1[9] =
char auth_plugin_data_part_1[9] = "12345678"; // first 8 bytes of the plugin provided data (scramble) // and the filler "12345678"; // first 8 bytes of the plugin provided data (scramble) // and the filler
int16_t capability_flags_1 = 0xF7DF; // The lower 2 bytes of the Capabilities Flags int16_t capability_flags_1 = 0xF7DF; // The lower 2 bytes of the Capabilities Flags
int8_t character_set = 83; int8_t character_set = 83;
int16_t status_flags = 0; int16_t status_flags = 0;
int16_t capability_flags_2 = 0x0000; int16_t capability_flags_2 = 0x0000;
int8_t auth_plugin_data_len = 0; int8_t auth_plugin_data_len = 0;
char reserved[10] = {0}; char reserved[10] = {0};
char auth_plugin_data_part_2[13] = "bbbbbbbbbbbb"; char auth_plugin_data_part_2[13] = "bbbbbbbbbbbb";
HandshakeV10(int8_t sequence = 0) : BasePacket(sequence) HandshakeV10(int8_t sequence = 0) : BasePacket(sequence)
{} {}
...@@ -256,19 +254,18 @@ struct HandshakeV10 : public BasePacket ...@@ -256,19 +254,18 @@ struct HandshakeV10 : public BasePacket
store_int3(buf, payload_length); store_int3(buf, payload_length);
net_packet.resize(pos); net_packet.resize(pos);
LOG_TRACE("encode handshake packet with payload length=%d", payload_length); LOG_TRACE("encode handshake packet with payload length=%d", payload_length);
return RC::SUCCESS; return RC::SUCCESS;
} }
}; };
struct OkPacket : public BasePacket struct OkPacket : public BasePacket {
{ int8_t header = 0; // 0x00 for ok and 0xFE for EOF
int8_t header = 0; // 0x00 for ok and 0xFE for EOF int32_t affected_rows = 0;
int32_t affected_rows = 0; int32_t last_insert_id = 0;
int32_t last_insert_id = 0; int16_t status_flags = 0x22;
int16_t status_flags = 0x22; int16_t warnings = 0;
int16_t warnings = 0; std::string info; // human readable status information
std::string info; // human readable status information
OkPacket(int8_t sequence = 0) : BasePacket(sequence) OkPacket(int8_t sequence = 0) : BasePacket(sequence)
{} {}
...@@ -310,11 +307,10 @@ struct OkPacket : public BasePacket ...@@ -310,11 +307,10 @@ struct OkPacket : public BasePacket
} }
}; };
struct EofPacket : public BasePacket struct EofPacket : public BasePacket {
{ int8_t header = 0xFE;
int8_t header = 0xFE; int16_t warnings = 0;
int16_t warnings = 0; int16_t status_flags = 0x22;
int16_t status_flags = 0x22;
EofPacket(int8_t sequence = 0) : BasePacket(sequence) EofPacket(int8_t sequence = 0) : BasePacket(sequence)
{} {}
...@@ -350,13 +346,12 @@ struct EofPacket : public BasePacket ...@@ -350,13 +346,12 @@ struct EofPacket : public BasePacket
} }
}; };
struct ErrPacket : public BasePacket struct ErrPacket : public BasePacket {
{ int8_t header = 0xFF;
int8_t header = 0xFF; int16_t error_code = 0;
int16_t error_code = 0; char sql_state_marker[1] = {'#'};
char sql_state_marker[1] = {'#'}; std::string sql_state{"HY000"};
std::string sql_state{"HY000"}; std::string error_message;
std::string error_message;
ErrPacket(int8_t sequence = 0) : BasePacket(sequence) ErrPacket(int8_t sequence = 0) : BasePacket(sequence)
{} {}
...@@ -395,11 +390,10 @@ struct ErrPacket : public BasePacket ...@@ -395,11 +390,10 @@ struct ErrPacket : public BasePacket
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_command_phase.html // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_command_phase.html
// https://mariadb.com/kb/en/2-text-protocol/ // https://mariadb.com/kb/en/2-text-protocol/
struct QueryPacket struct QueryPacket {
{
PacketHeader packet_header; PacketHeader packet_header;
int8_t command; // 0x03: COM_QUERY int8_t command; // 0x03: COM_QUERY
std::string query; // the text of the SQL query to execute std::string query; // the text of the SQL query to execute
}; };
/** /**
...@@ -474,26 +468,31 @@ RC MysqlCommunicator::read_event(SessionEvent *&event) ...@@ -474,26 +468,31 @@ RC MysqlCommunicator::read_event(SessionEvent *&event)
PacketHeader packet_header; PacketHeader packet_header;
int ret = common::readn(fd_, &packet_header, sizeof(packet_header)); int ret = common::readn(fd_, &packet_header, sizeof(packet_header));
if (ret != 0) { if (ret != 0) {
LOG_WARN("failed to read packet header. length=%d, addr=%s. error=%s", sizeof(packet_header), addr_.c_str(), strerror(errno)); LOG_WARN("failed to read packet header. length=%d, addr=%s. error=%s",
sizeof(packet_header),
addr_.c_str(),
strerror(errno));
return RC::IOERR; return RC::IOERR;
} }
LOG_TRACE("read packet header. length=%d, sequence_id=%d", sizeof(packet_header), packet_header.sequence_id); LOG_TRACE("read packet header. length=%d, sequence_id=%d", sizeof(packet_header), packet_header.sequence_id);
sequence_id_ = packet_header.sequence_id + 1; sequence_id_ = packet_header.sequence_id + 1;
std::vector<char> buf(packet_header.payload_length); std::vector<char> buf(packet_header.payload_length);
ret = common::readn(fd_, buf.data(), packet_header.payload_length); ret = common::readn(fd_, buf.data(), packet_header.payload_length);
if (ret != 0) { if (ret != 0) {
LOG_WARN("failed to read packet payload. length=%d, addr=%s, error=%s", LOG_WARN("failed to read packet payload. length=%d, addr=%s, error=%s",
packet_header.payload_length, addr_.c_str(), strerror(errno)); packet_header.payload_length,
addr_.c_str(),
strerror(errno));
return RC::IOERR; return RC::IOERR;
} }
LOG_TRACE("read packet payload length=%d", packet_header.payload_length); LOG_TRACE("read packet payload length=%d", packet_header.payload_length);
event = nullptr; event = nullptr;
if (!authed_) { if (!authed_) {
uint32_t client_flag = *(uint32_t*)buf.data(); // TODO should use decode (little endian as default) uint32_t client_flag = *(uint32_t *)buf.data(); // TODO should use decode (little endian as default)
LOG_INFO("client handshake response with capabilities flag=%d", client_flag); LOG_INFO("client handshake response with capabilities flag=%d", client_flag);
client_capabilities_flag_ = client_flag; client_capabilities_flag_ = client_flag;
// send ok packet and return // send ok packet and return
...@@ -510,7 +509,7 @@ RC MysqlCommunicator::read_event(SessionEvent *&event) ...@@ -510,7 +509,7 @@ RC MysqlCommunicator::read_event(SessionEvent *&event)
int8_t command_type = buf[0]; int8_t command_type = buf[0];
LOG_TRACE("recv command from client =%d", command_type); LOG_TRACE("recv command from client =%d", command_type);
if (command_type == 0x03) { // COM_QUERY if (command_type == 0x03) { // COM_QUERY
QueryPacket query_packet; QueryPacket query_packet;
rc = decode_query_packet(buf, query_packet); rc = decode_query_packet(buf, query_packet);
...@@ -541,7 +540,7 @@ RC MysqlCommunicator::read_event(SessionEvent *&event) ...@@ -541,7 +540,7 @@ RC MysqlCommunicator::read_event(SessionEvent *&event)
RC MysqlCommunicator::write_state(SessionEvent *event, bool &need_disconnect) RC MysqlCommunicator::write_state(SessionEvent *event, bool &need_disconnect)
{ {
SqlResult *sql_result = event->sql_result(); SqlResult *sql_result = event->sql_result();
const int buf_size = 2048; const int buf_size = 2048;
char *buf = new char[buf_size]; char *buf = new char[buf_size];
const std::string &state_string = sql_result->state_string(); const std::string &state_string = sql_result->state_string();
...@@ -580,14 +579,14 @@ RC MysqlCommunicator::write_state(SessionEvent *event, bool &need_disconnect) ...@@ -580,14 +579,14 @@ RC MysqlCommunicator::write_state(SessionEvent *event, bool &need_disconnect)
RC MysqlCommunicator::write_result(SessionEvent *event, bool &need_disconnect) RC MysqlCommunicator::write_result(SessionEvent *event, bool &need_disconnect)
{ {
RC rc = RC::SUCCESS; RC rc = RC::SUCCESS;
need_disconnect = true; need_disconnect = true;
SqlResult *sql_result = event->sql_result(); SqlResult *sql_result = event->sql_result();
if (nullptr == sql_result) { if (nullptr == sql_result) {
const char *response = event->get_response(); const char *response = event->get_response();
int len = event->get_response_len(); int len = event->get_response_len();
OkPacket ok_packet;// TODO if error occurs, we should send an error packet to client OkPacket ok_packet; // TODO if error occurs, we should send an error packet to client
ok_packet.info.assign(response, len); ok_packet.info.assign(response, len);
rc = send_packet(ok_packet); rc = send_packet(ok_packet);
if (rc != RC::SUCCESS) { if (rc != RC::SUCCESS) {
...@@ -598,7 +597,7 @@ RC MysqlCommunicator::write_result(SessionEvent *event, bool &need_disconnect) ...@@ -598,7 +597,7 @@ RC MysqlCommunicator::write_result(SessionEvent *event, bool &need_disconnect)
need_disconnect = false; need_disconnect = false;
} else { } else {
if (RC::SUCCESS != sql_result->return_code() || !sql_result->has_operator()) { if (RC::SUCCESS != sql_result->return_code() || !sql_result->has_operator()) {
return write_state(event, need_disconnect); return write_state(event, need_disconnect);
} }
// send result set // send result set
...@@ -624,7 +623,7 @@ RC MysqlCommunicator::write_result(SessionEvent *event, bool &need_disconnect) ...@@ -624,7 +623,7 @@ RC MysqlCommunicator::write_result(SessionEvent *event, bool &need_disconnect)
rc = send_result_rows(sql_result, cell_num == 0, need_disconnect); rc = send_result_rows(sql_result, cell_num == 0, need_disconnect);
} }
return rc; return rc;
} }
...@@ -689,7 +688,7 @@ RC MysqlCommunicator::send_column_definition(SqlResult *sql_result, bool &need_d ...@@ -689,7 +688,7 @@ RC MysqlCommunicator::send_column_definition(SqlResult *sql_result, bool &need_d
net_packet.resize(pos); net_packet.resize(pos);
int ret = common::writen(fd_, net_packet.data(), net_packet.size()); int ret = common::writen(fd_, net_packet.data(), net_packet.size());
if (ret != 0){ if (ret != 0) {
LOG_WARN("failed to send column count to client. addr=%s, error=%s", addr(), strerror(errno)); LOG_WARN("failed to send column count to client. addr=%s, error=%s", addr(), strerror(errno));
need_disconnect = true; need_disconnect = true;
return RC::IOERR; return RC::IOERR;
...@@ -705,12 +704,12 @@ RC MysqlCommunicator::send_column_definition(SqlResult *sql_result, bool &need_d ...@@ -705,12 +704,12 @@ RC MysqlCommunicator::send_column_definition(SqlResult *sql_result, bool &need_d
pos += 1; pos += 1;
const TupleCellSpec &spec = tuple_schema.cell_at(i); const TupleCellSpec &spec = tuple_schema.cell_at(i);
const char *catalog = "def"; // The catalog used. Currently always "def" const char *catalog = "def"; // The catalog used. Currently always "def"
const char *schema = "sys"; // schema name const char *schema = "sys"; // schema name
const char *table = spec.table_name(); const char *table = spec.table_name();
const char *org_table = spec.table_name(); const char *org_table = spec.table_name();
const char *name = spec.alias(); const char *name = spec.alias();
//const char *org_name = spec.field_name(); // const char *org_name = spec.field_name();
const char *org_name = spec.alias(); const char *org_name = spec.alias();
int fixed_len_fields = 0x0c; int fixed_len_fields = 0x0c;
int character_set = 33; int character_set = 33;
...@@ -736,7 +735,7 @@ RC MysqlCommunicator::send_column_definition(SqlResult *sql_result, bool &need_d ...@@ -736,7 +735,7 @@ RC MysqlCommunicator::send_column_definition(SqlResult *sql_result, bool &need_d
pos += 2; pos += 2;
store_int1(buf + pos, decimals); store_int1(buf + pos, decimals);
pos += 1; pos += 1;
store_int2(buf + pos, 0); // 按照mariadb的文档描述,最后还有一个unused字段int<2>,不过mysql的文档没有给出这样的描述 store_int2(buf + pos, 0); // 按照mariadb的文档描述,最后还有一个unused字段int<2>,不过mysql的文档没有给出这样的描述
pos += 2; pos += 2;
payload_length = pos - 4; payload_length = pos - 4;
...@@ -779,7 +778,7 @@ RC MysqlCommunicator::send_result_rows(SqlResult *sql_result, bool no_column_def ...@@ -779,7 +778,7 @@ RC MysqlCommunicator::send_result_rows(SqlResult *sql_result, bool no_column_def
{ {
RC rc = RC::SUCCESS; RC rc = RC::SUCCESS;
std::vector<char> packet; std::vector<char> packet;
packet.resize(4 * 1024 * 1024); // TODO warning: length cannot be fix packet.resize(4 * 1024 * 1024); // TODO warning: length cannot be fix
int affected_rows = 0; int affected_rows = 0;
Tuple *tuple = nullptr; Tuple *tuple = nullptr;
...@@ -787,7 +786,7 @@ RC MysqlCommunicator::send_result_rows(SqlResult *sql_result, bool no_column_def ...@@ -787,7 +786,7 @@ RC MysqlCommunicator::send_result_rows(SqlResult *sql_result, bool no_column_def
assert(tuple != nullptr); assert(tuple != nullptr);
affected_rows++; affected_rows++;
const int cell_num = tuple->cell_num(); const int cell_num = tuple->cell_num();
if (cell_num == 0) { if (cell_num == 0) {
continue; continue;
...@@ -807,12 +806,12 @@ RC MysqlCommunicator::send_result_rows(SqlResult *sql_result, bool no_column_def ...@@ -807,12 +806,12 @@ RC MysqlCommunicator::send_result_rows(SqlResult *sql_result, bool no_column_def
rc = tuple->cell_at(i, tuple_cell); rc = tuple->cell_at(i, tuple_cell);
if (rc != RC::SUCCESS) { if (rc != RC::SUCCESS) {
sql_result->set_return_code(rc); sql_result->set_return_code(rc);
break; // TODO send error packet break; // TODO send error packet
} }
std::stringstream ss; std::stringstream ss;
tuple_cell.to_string(ss); tuple_cell.to_string(ss);
pos += store_lenenc_string(buf + pos, ss.str().c_str()); pos += store_lenenc_string(buf + pos, ss.str().c_str());
} }
int payload_length = pos - 4; int payload_length = pos - 4;
......
...@@ -27,7 +27,6 @@ class BasePacket; ...@@ -27,7 +27,6 @@ class BasePacket;
*/ */
class MysqlCommunicator : public Communicator { class MysqlCommunicator : public Communicator {
public: public:
/** /**
* 连接刚开始建立时,进行一些初始化 * 连接刚开始建立时,进行一些初始化
* 参考MySQL或MariaDB的手册,服务端要首先向客户端发送一个握手包,等客户端回复后, * 参考MySQL或MariaDB的手册,服务端要首先向客户端发送一个握手包,等客户端回复后,
......
...@@ -202,8 +202,8 @@ void Server::accept(int fd, short ev, void *arg) ...@@ -202,8 +202,8 @@ void Server::accept(int fd, short ev, void *arg)
ret = event_base_set(instance->event_base_, &communicator->read_event()); ret = event_base_set(instance->event_base_, &communicator->read_event());
if (ret < 0) { if (ret < 0) {
LOG_ERROR("Failed to do event_base_set for read event of %s into libevent, %s", LOG_ERROR(
communicator->addr(), strerror(errno)); "Failed to do event_base_set for read event of %s into libevent, %s", communicator->addr(), strerror(errno));
delete communicator; delete communicator;
return; return;
} }
......
...@@ -124,7 +124,7 @@ void SessionStage::callback_event(StageEvent *event, CallbackContext *context) ...@@ -124,7 +124,7 @@ void SessionStage::callback_event(StageEvent *event, CallbackContext *context)
RC rc = communicator->write_result(sev, need_disconnect); RC rc = communicator->write_result(sev, need_disconnect);
LOG_INFO("write result return %s", strrc(rc)); LOG_INFO("write result return %s", strrc(rc));
if (need_disconnect) { if (need_disconnect) {
Server::close_connection(communicator); Server::close_connection(communicator);
} }
LOG_TRACE("Exit\n"); LOG_TRACE("Exit\n");
......
...@@ -39,7 +39,7 @@ protected: ...@@ -39,7 +39,7 @@ protected:
RC handle_request(common::StageEvent *event); RC handle_request(common::StageEvent *event);
RC handle_request_with_physical_operator(SQLStageEvent *sql_event); RC handle_request_with_physical_operator(SQLStageEvent *sql_event);
RC do_help(SQLStageEvent *session_event); RC do_help(SQLStageEvent *session_event);
RC do_create_table(SQLStageEvent *sql_event); RC do_create_table(SQLStageEvent *sql_event);
RC do_create_index(SQLStageEvent *sql_event); RC do_create_index(SQLStageEvent *sql_event);
......
...@@ -24,18 +24,38 @@ class SqlResult { ...@@ -24,18 +24,38 @@ class SqlResult {
public: public:
SqlResult() = default; SqlResult() = default;
~SqlResult() ~SqlResult()
{}
void set_tuple_schema(const TupleSchema &schema);
void set_return_code(RC rc)
{
return_code_ = rc;
}
void set_state_string(const std::string &state_string)
{ {
state_string_ = state_string;
} }
void set_tuple_schema(const TupleSchema &schema); void set_operator(std::unique_ptr<PhysicalOperator> oper)
void set_return_code(RC rc) { return_code_ = rc; } {
void set_state_string(const std::string &state_string) { state_string_ = state_string; } operator_ = std::move(oper);
}
void set_operator(std::unique_ptr<PhysicalOperator> oper) { operator_ = std::move(oper); } bool has_operator() const
bool has_operator() const { return operator_ != nullptr; } {
const TupleSchema &tuple_schema() const { return tuple_schema_; } return operator_ != nullptr;
RC return_code() const { return return_code_; } }
const std::string &state_string() const { return state_string_; } const TupleSchema &tuple_schema() const
{
return tuple_schema_;
}
RC return_code() const
{
return return_code_;
}
const std::string &state_string() const
{
return state_string_;
}
RC open(); RC open();
RC close(); RC close();
......
...@@ -15,13 +15,12 @@ See the Mulan PSL v2 for more details. */ ...@@ -15,13 +15,12 @@ See the Mulan PSL v2 for more details. */
#include "sql/expr/expression.h" #include "sql/expr/expression.h"
#include "sql/expr/tuple.h" #include "sql/expr/tuple.h"
RC FieldExpr::get_value(const Tuple &tuple, TupleCell &cell) const RC FieldExpr::get_value(const Tuple &tuple, TupleCell &cell) const
{ {
return tuple.find_cell(TupleCellSpec(table_name(), field_name()), cell); return tuple.find_cell(TupleCellSpec(table_name(), field_name()), cell);
} }
RC ValueExpr::get_value(const Tuple &tuple, TupleCell & cell) const RC ValueExpr::get_value(const Tuple &tuple, TupleCell &cell) const
{ {
cell = tuple_cell_; cell = tuple_cell_;
return RC::SUCCESS; return RC::SUCCESS;
...@@ -33,8 +32,7 @@ CastExpr::CastExpr(std::unique_ptr<Expression> child, AttrType cast_type) ...@@ -33,8 +32,7 @@ CastExpr::CastExpr(std::unique_ptr<Expression> child, AttrType cast_type)
{} {}
CastExpr::~CastExpr() CastExpr::~CastExpr()
{ {}
}
RC CastExpr::get_value(const Tuple &tuple, TupleCell &cell) const RC CastExpr::get_value(const Tuple &tuple, TupleCell &cell) const
{ {
...@@ -67,8 +65,7 @@ ComparisonExpr::ComparisonExpr(CompOp comp, std::unique_ptr<Expression> left, st ...@@ -67,8 +65,7 @@ ComparisonExpr::ComparisonExpr(CompOp comp, std::unique_ptr<Expression> left, st
{} {}
ComparisonExpr::~ComparisonExpr() ComparisonExpr::~ComparisonExpr()
{ {}
}
RC ComparisonExpr::compare_tuple_cell(const TupleCell &left, const TupleCell &right, bool &value) const RC ComparisonExpr::compare_tuple_cell(const TupleCell &left, const TupleCell &right, bool &value) const
{ {
...@@ -99,7 +96,7 @@ RC ComparisonExpr::compare_tuple_cell(const TupleCell &left, const TupleCell &ri ...@@ -99,7 +96,7 @@ RC ComparisonExpr::compare_tuple_cell(const TupleCell &left, const TupleCell &ri
rc = RC::GENERIC_ERROR; rc = RC::GENERIC_ERROR;
} break; } break;
} }
return rc; return rc;
} }
...@@ -128,7 +125,7 @@ RC ComparisonExpr::get_value(const Tuple &tuple, TupleCell &cell) const ...@@ -128,7 +125,7 @@ RC ComparisonExpr::get_value(const Tuple &tuple, TupleCell &cell) const
{ {
TupleCell left_cell; TupleCell left_cell;
TupleCell right_cell; TupleCell right_cell;
RC rc = left_->get_value(tuple, left_cell); RC rc = left_->get_value(tuple, left_cell);
if (rc != RC::SUCCESS) { if (rc != RC::SUCCESS) {
LOG_WARN("failed to get value of left expression. rc=%s", strrc(rc)); LOG_WARN("failed to get value of left expression. rc=%s", strrc(rc));
...@@ -151,8 +148,7 @@ RC ComparisonExpr::get_value(const Tuple &tuple, TupleCell &cell) const ...@@ -151,8 +148,7 @@ RC ComparisonExpr::get_value(const Tuple &tuple, TupleCell &cell) const
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
ConjunctionExpr::ConjunctionExpr(Type type, std::vector<std::unique_ptr<Expression>> &children) ConjunctionExpr::ConjunctionExpr(Type type, std::vector<std::unique_ptr<Expression>> &children)
: conjunction_type_(type), children_(std::move(children)) : conjunction_type_(type), children_(std::move(children))
{ {}
}
RC ConjunctionExpr::get_value(const Tuple &tuple, TupleCell &cell) const RC ConjunctionExpr::get_value(const Tuple &tuple, TupleCell &cell) const
{ {
...@@ -161,7 +157,7 @@ RC ConjunctionExpr::get_value(const Tuple &tuple, TupleCell &cell) const ...@@ -161,7 +157,7 @@ RC ConjunctionExpr::get_value(const Tuple &tuple, TupleCell &cell) const
cell.set_boolean(true); cell.set_boolean(true);
return rc; return rc;
} }
TupleCell tmp_cell; TupleCell tmp_cell;
for (const std::unique_ptr<Expression> &expr : children_) { for (const std::unique_ptr<Expression> &expr : children_) {
rc = expr->get_value(tuple, tmp_cell); rc = expr->get_value(tuple, tmp_cell);
...@@ -170,8 +166,7 @@ RC ConjunctionExpr::get_value(const Tuple &tuple, TupleCell &cell) const ...@@ -170,8 +166,7 @@ RC ConjunctionExpr::get_value(const Tuple &tuple, TupleCell &cell) const
return rc; return rc;
} }
bool value = tmp_cell.get_boolean(); bool value = tmp_cell.get_boolean();
if ((conjunction_type_ == Type::AND && !value) if ((conjunction_type_ == Type::AND && !value) || (conjunction_type_ == Type::OR && value)) {
|| (conjunction_type_ == Type::OR && value)) {
cell.set_boolean(value); cell.set_boolean(value);
return rc; return rc;
} }
......
...@@ -41,9 +41,8 @@ enum class ExprType { ...@@ -41,9 +41,8 @@ enum class ExprType {
* 才能计算出来真实的值。但是有些表达式可能就表示某一个固定的 * 才能计算出来真实的值。但是有些表达式可能就表示某一个固定的
* 值,比如ValueExpr。 * 值,比如ValueExpr。
*/ */
class Expression class Expression {
{ public:
public:
Expression() = default; Expression() = default;
virtual ~Expression() = default; virtual ~Expression() = default;
...@@ -61,11 +60,10 @@ public: ...@@ -61,11 +60,10 @@ public:
/** /**
* 表达式值的类型 * 表达式值的类型
*/ */
virtual AttrType value_type() const = 0; virtual AttrType value_type() const = 0;
}; };
class FieldExpr : public Expression class FieldExpr : public Expression {
{
public: public:
FieldExpr() = default; FieldExpr() = default;
FieldExpr(const Table *table, const FieldMeta *field) : field_(table, field) FieldExpr(const Table *table, const FieldMeta *field) : field_(table, field)
...@@ -105,12 +103,12 @@ public: ...@@ -105,12 +103,12 @@ public:
} }
RC get_value(const Tuple &tuple, TupleCell &cell) const override; RC get_value(const Tuple &tuple, TupleCell &cell) const override;
private: private:
Field field_; Field field_;
}; };
class ValueExpr : public Expression class ValueExpr : public Expression {
{
public: public:
ValueExpr() = default; ValueExpr() = default;
ValueExpr(const Value &value) ValueExpr(const Value &value)
...@@ -138,7 +136,7 @@ public: ...@@ -138,7 +136,7 @@ public:
virtual ~ValueExpr() = default; virtual ~ValueExpr() = default;
RC get_value(const Tuple &tuple, TupleCell & cell) const override; RC get_value(const Tuple &tuple, TupleCell &cell) const override;
ExprType type() const override ExprType type() const override
{ {
...@@ -150,11 +148,13 @@ public: ...@@ -150,11 +148,13 @@ public:
return tuple_cell_.attr_type(); return tuple_cell_.attr_type();
} }
void get_tuple_cell(TupleCell &cell) const { void get_tuple_cell(TupleCell &cell) const
{
cell = tuple_cell_; cell = tuple_cell_;
} }
const TupleCell &get_tuple_cell() const { const TupleCell &get_tuple_cell() const
{
return tuple_cell_; return tuple_cell_;
} }
...@@ -162,38 +162,58 @@ private: ...@@ -162,38 +162,58 @@ private:
TupleCell tuple_cell_; TupleCell tuple_cell_;
}; };
class CastExpr : public Expression class CastExpr : public Expression {
{ public:
public:
CastExpr(std::unique_ptr<Expression> child, AttrType cast_type); CastExpr(std::unique_ptr<Expression> child, AttrType cast_type);
virtual ~CastExpr(); virtual ~CastExpr();
ExprType type() const override { return ExprType::CAST; } ExprType type() const override
{
return ExprType::CAST;
}
RC get_value(const Tuple &tuple, TupleCell &cell) const override; RC get_value(const Tuple &tuple, TupleCell &cell) const override;
AttrType value_type() const override AttrType value_type() const override
{ {
return cast_type_; return cast_type_;
} }
std::unique_ptr<Expression> &child() { return child_; } std::unique_ptr<Expression> &child()
{
return child_;
}
private: private:
std::unique_ptr<Expression> child_; std::unique_ptr<Expression> child_;
AttrType cast_type_; AttrType cast_type_;
}; };
class ComparisonExpr : public Expression class ComparisonExpr : public Expression {
{
public: public:
ComparisonExpr(CompOp comp, std::unique_ptr<Expression> left, std::unique_ptr<Expression> right); ComparisonExpr(CompOp comp, std::unique_ptr<Expression> left, std::unique_ptr<Expression> right);
virtual ~ComparisonExpr(); virtual ~ComparisonExpr();
ExprType type() const override { return ExprType::COMPARISON; } ExprType type() const override
{
return ExprType::COMPARISON;
}
RC get_value(const Tuple &tuple, TupleCell &cell) const override; RC get_value(const Tuple &tuple, TupleCell &cell) const override;
AttrType value_type() const override { return BOOLEANS; } AttrType value_type() const override
{
return BOOLEANS;
}
CompOp comp() const { return comp_; } CompOp comp() const
std::unique_ptr<Expression> &left() { return left_; } {
std::unique_ptr<Expression> &right() { return right_; } return comp_;
}
std::unique_ptr<Expression> &left()
{
return left_;
}
std::unique_ptr<Expression> &right()
{
return right_;
}
/** /**
* 尝试在没有tuple的情况下获取当前表达式的值 * 尝试在没有tuple的情况下获取当前表达式的值
...@@ -217,27 +237,38 @@ private: ...@@ -217,27 +237,38 @@ private:
* 多个表达式使用同一种关系(AND或OR)来联结 * 多个表达式使用同一种关系(AND或OR)来联结
* 当前miniob仅有AND操作 * 当前miniob仅有AND操作
*/ */
class ConjunctionExpr : public Expression class ConjunctionExpr : public Expression {
{
public: public:
enum class Type enum class Type {
{
AND, AND,
OR, OR,
}; };
public: public:
ConjunctionExpr(Type type, std::vector<std::unique_ptr<Expression>> &children); ConjunctionExpr(Type type, std::vector<std::unique_ptr<Expression>> &children);
virtual ~ConjunctionExpr() = default; virtual ~ConjunctionExpr() = default;
ExprType type() const override { return ExprType::CONJUNCTION; } ExprType type() const override
AttrType value_type() const override { return BOOLEANS; } {
return ExprType::CONJUNCTION;
}
AttrType value_type() const override
{
return BOOLEANS;
}
RC get_value(const Tuple &tuple, TupleCell &cell) const override; RC get_value(const Tuple &tuple, TupleCell &cell) const override;
Type conjunction_type() const { return conjunction_type_; } Type conjunction_type() const
{
return conjunction_type_;
}
std::vector<std::unique_ptr<Expression>> &children()
{
return children_;
}
std::vector<std::unique_ptr<Expression>> &children() { return children_; }
private: private:
Type conjunction_type_; Type conjunction_type_;
std::vector<std::unique_ptr<Expression>> children_; std::vector<std::unique_ptr<Expression>> children_;
}; };
...@@ -25,32 +25,44 @@ See the Mulan PSL v2 for more details. */ ...@@ -25,32 +25,44 @@ See the Mulan PSL v2 for more details. */
class Table; class Table;
class TupleSchema class TupleSchema {
{ public:
public: void append_cell(const TupleCellSpec &cell)
void append_cell(const TupleCellSpec &cell) { cells_.push_back(cell); } {
void append_cell(const char *table, const char *field) { append_cell(TupleCellSpec(table, field)); } cells_.push_back(cell);
void append_cell(const char *alias) { append_cell(TupleCellSpec(alias)); } }
int cell_num() const { return static_cast<int>(cells_.size()); } void append_cell(const char *table, const char *field)
const TupleCellSpec &cell_at(int i) const { return cells_[i]; } {
append_cell(TupleCellSpec(table, field));
}
void append_cell(const char *alias)
{
append_cell(TupleCellSpec(alias));
}
int cell_num() const
{
return static_cast<int>(cells_.size());
}
const TupleCellSpec &cell_at(int i) const
{
return cells_[i];
}
private: private:
std::vector<TupleCellSpec> cells_; std::vector<TupleCellSpec> cells_;
}; };
class Tuple class Tuple {
{
public: public:
Tuple() = default; Tuple() = default;
virtual ~Tuple() = default; virtual ~Tuple() = default;
virtual int cell_num() const = 0; virtual int cell_num() const = 0;
virtual RC cell_at(int index, TupleCell &cell) const = 0; virtual RC cell_at(int index, TupleCell &cell) const = 0;
virtual RC find_cell(const TupleCellSpec &spec, TupleCell &cell) const = 0; virtual RC find_cell(const TupleCellSpec &spec, TupleCell &cell) const = 0;
}; };
class RowTuple : public Tuple class RowTuple : public Tuple {
{
public: public:
RowTuple() = default; RowTuple() = default;
virtual ~RowTuple() virtual ~RowTuple()
...@@ -60,7 +72,7 @@ public: ...@@ -60,7 +72,7 @@ public:
} }
speces_.clear(); speces_.clear();
} }
void set_record(Record *record) void set_record(Record *record)
{ {
this->record_ = record; this->record_ = record;
...@@ -103,16 +115,16 @@ public: ...@@ -103,16 +115,16 @@ public:
} }
for (size_t i = 0; i < speces_.size(); ++i) { for (size_t i = 0; i < speces_.size(); ++i) {
const FieldExpr * field_expr = speces_[i]; const FieldExpr *field_expr = speces_[i];
const Field &field = field_expr->field(); const Field &field = field_expr->field();
if (0 == strcmp(field_name, field.field_name())) { if (0 == strcmp(field_name, field.field_name())) {
return cell_at(i, cell); return cell_at(i, cell);
} }
} }
return RC::NOTFOUND; return RC::NOTFOUND;
} }
#if 0 #if 0
RC cell_spec_at(int index, const TupleCellSpec *&spec) const override RC cell_spec_at(int index, const TupleCellSpec *&spec) const override
{ {
if (index < 0 || index >= static_cast<int>(speces_.size())) { if (index < 0 || index >= static_cast<int>(speces_.size())) {
...@@ -122,7 +134,7 @@ public: ...@@ -122,7 +134,7 @@ public:
spec = speces_[index]; spec = speces_[index];
return RC::SUCCESS; return RC::SUCCESS;
} }
#endif #endif
Record &record() Record &record()
{ {
...@@ -133,14 +145,14 @@ public: ...@@ -133,14 +145,14 @@ public:
{ {
return *record_; return *record_;
} }
private: private:
Record *record_ = nullptr; Record *record_ = nullptr;
const Table *table_ = nullptr; const Table *table_ = nullptr;
std::vector<FieldExpr *> speces_; std::vector<FieldExpr *> speces_;
}; };
class ProjectTuple : public Tuple class ProjectTuple : public Tuple {
{
public: public:
ProjectTuple() = default; ProjectTuple() = default;
virtual ~ProjectTuple() virtual ~ProjectTuple()
...@@ -183,7 +195,7 @@ public: ...@@ -183,7 +195,7 @@ public:
return tuple_->find_cell(spec, cell); return tuple_->find_cell(spec, cell);
} }
#if 0 #if 0
RC cell_spec_at(int index, const TupleCellSpec *&spec) const override RC cell_spec_at(int index, const TupleCellSpec *&spec) const override
{ {
if (index < 0 || index >= static_cast<int>(speces_.size())) { if (index < 0 || index >= static_cast<int>(speces_.size())) {
...@@ -192,15 +204,14 @@ public: ...@@ -192,15 +204,14 @@ public:
spec = speces_[index]; spec = speces_[index];
return RC::SUCCESS; return RC::SUCCESS;
} }
#endif #endif
private: private:
std::vector<TupleCellSpec *> speces_; std::vector<TupleCellSpec *> speces_;
Tuple *tuple_ = nullptr; Tuple *tuple_ = nullptr;
}; };
class ValueListTuple : public Tuple class ValueListTuple : public Tuple {
{ public:
public:
ValueListTuple() = default; ValueListTuple() = default;
virtual ~ValueListTuple() = default; virtual ~ValueListTuple() = default;
...@@ -214,7 +225,7 @@ public: ...@@ -214,7 +225,7 @@ public:
return static_cast<int>(cells_.size()); return static_cast<int>(cells_.size());
} }
virtual RC cell_at(int index, TupleCell &cell) const override virtual RC cell_at(int index, TupleCell &cell) const override
{ {
if (index < 0 || index >= cell_num()) { if (index < 0 || index >= cell_num()) {
return RC::NOTFOUND; return RC::NOTFOUND;
...@@ -224,7 +235,7 @@ public: ...@@ -224,7 +235,7 @@ public:
return RC::SUCCESS; return RC::SUCCESS;
} }
virtual RC find_cell(const TupleCellSpec &spec, TupleCell &cell) const override virtual RC find_cell(const TupleCellSpec &spec, TupleCell &cell) const override
{ {
return RC::INTERNAL; return RC::INTERNAL;
} }
...@@ -237,15 +248,20 @@ private: ...@@ -237,15 +248,20 @@ private:
* 将两个tuple合并为一个tuple * 将两个tuple合并为一个tuple
* 在join算子中使用 * 在join算子中使用
*/ */
class JoinedTuple : public Tuple class JoinedTuple : public Tuple {
{
public: public:
JoinedTuple() = default; JoinedTuple() = default;
virtual ~JoinedTuple() = default; virtual ~JoinedTuple() = default;
void set_left(Tuple *left) { left_ = left; } void set_left(Tuple *left)
void set_right(Tuple *right) { right_ = right; } {
left_ = left;
}
void set_right(Tuple *right)
{
right_ = right;
}
int cell_num() const override int cell_num() const override
{ {
return left_->cell_num() + right_->cell_num(); return left_->cell_num() + right_->cell_num();
...@@ -254,7 +270,7 @@ public: ...@@ -254,7 +270,7 @@ public:
RC cell_at(int index, TupleCell &cell) const override RC cell_at(int index, TupleCell &cell) const override
{ {
const int left_cell_num = left_->cell_num(); const int left_cell_num = left_->cell_num();
if (index >0 && index < left_cell_num) { if (index > 0 && index < left_cell_num) {
return left_->cell_at(index, cell); return left_->cell_at(index, cell);
} }
...@@ -274,8 +290,8 @@ public: ...@@ -274,8 +290,8 @@ public:
return right_->find_cell(spec, cell); return right_->find_cell(spec, cell);
} }
private: private:
Tuple * left_ = nullptr; Tuple *left_ = nullptr;
Tuple * right_ = nullptr; Tuple *right_ = nullptr;
}; };
...@@ -136,21 +136,21 @@ const char *TupleCell::data() const ...@@ -136,21 +136,21 @@ const char *TupleCell::data() const
void TupleCell::to_string(std::ostream &os) const void TupleCell::to_string(std::ostream &os) const
{ {
switch (attr_type_) { switch (attr_type_) {
case INTS: { case INTS: {
os << num_value_.int_value_; os << num_value_.int_value_;
} break; } break;
case FLOATS: { case FLOATS: {
os << double2string(num_value_.float_value_); os << double2string(num_value_.float_value_);
} break; } break;
case BOOLEANS: { case BOOLEANS: {
os << num_value_.bool_value_; os << num_value_.bool_value_;
} break; } break;
case CHARS: { case CHARS: {
os << str_value_; os << str_value_;
} break; } break;
default: { default: {
LOG_WARN("unsupported attr type: %d", attr_type_); LOG_WARN("unsupported attr type: %d", attr_type_);
} break; } break;
} }
} }
...@@ -159,14 +159,16 @@ int TupleCell::compare(const TupleCell &other) const ...@@ -159,14 +159,16 @@ int TupleCell::compare(const TupleCell &other) const
if (this->attr_type_ == other.attr_type_) { if (this->attr_type_ == other.attr_type_) {
switch (this->attr_type_) { switch (this->attr_type_) {
case INTS: { case INTS: {
return compare_int((void *)&this->num_value_.int_value_, (void *)&other.num_value_.int_value_); return compare_int((void *)&this->num_value_.int_value_, (void *)&other.num_value_.int_value_);
} break; } break;
case FLOATS: { case FLOATS: {
return compare_float((void *)&this->num_value_.float_value_, (void *)&other.num_value_.float_value_); return compare_float((void *)&this->num_value_.float_value_, (void *)&other.num_value_.float_value_);
} break; } break;
case CHARS: { case CHARS: {
return compare_string((void *)this->str_value_.c_str(), this->str_value_.length(), return compare_string((void *)this->str_value_.c_str(),
(void *)other.str_value_.c_str(), other.str_value_.length()); this->str_value_.length(),
(void *)other.str_value_.c_str(),
other.str_value_.length());
} break; } break;
case BOOLEANS: { case BOOLEANS: {
return compare_int((void *)&this->num_value_.bool_value_, (void *)&other.num_value_.bool_value_); return compare_int((void *)&this->num_value_.bool_value_, (void *)&other.num_value_.bool_value_);
...@@ -183,10 +185,9 @@ int TupleCell::compare(const TupleCell &other) const ...@@ -183,10 +185,9 @@ int TupleCell::compare(const TupleCell &other) const
return compare_float((void *)&this->num_value_.float_value_, (void *)&other_data); return compare_float((void *)&this->num_value_.float_value_, (void *)&other_data);
} }
LOG_WARN("not supported"); LOG_WARN("not supported");
return -1; // TODO return rc? return -1; // TODO return rc?
} }
int TupleCell::get_int() const int TupleCell::get_int() const
{ {
switch (attr_type_) { switch (attr_type_) {
......
...@@ -18,15 +18,23 @@ See the Mulan PSL v2 for more details. */ ...@@ -18,15 +18,23 @@ See the Mulan PSL v2 for more details. */
#include "storage/common/table.h" #include "storage/common/table.h"
#include "storage/common/field_meta.h" #include "storage/common/field_meta.h"
class TupleCellSpec class TupleCellSpec {
{ public:
public:
TupleCellSpec(const char *table_name, const char *field_name, const char *alias = nullptr); TupleCellSpec(const char *table_name, const char *field_name, const char *alias = nullptr);
TupleCellSpec(const char *alias); TupleCellSpec(const char *alias);
const char *table_name() const { return table_name_.c_str(); } const char *table_name() const
const char *field_name() const { return field_name_.c_str(); } {
const char *alias() const { return alias_.c_str(); } return table_name_.c_str();
}
const char *field_name() const
{
return field_name_.c_str();
}
const char *alias() const
{
return alias_.c_str();
}
private: private:
std::string table_name_; std::string table_name_;
...@@ -38,16 +46,13 @@ private: ...@@ -38,16 +46,13 @@ private:
* 表示tuple中某个元素的值 * 表示tuple中某个元素的值
* @note 可以与value做合并 * @note 可以与value做合并
*/ */
class TupleCell class TupleCell {
{ public:
public:
TupleCell() = default; TupleCell() = default;
TupleCell(FieldMeta *meta, char *data, int length = 4) TupleCell(FieldMeta *meta, char *data, int length = 4) : TupleCell(meta->type(), data)
: TupleCell(meta->type(), data)
{} {}
TupleCell(AttrType attr_type, char *data, int length = 4) TupleCell(AttrType attr_type, char *data, int length = 4) : attr_type_(attr_type)
: attr_type_(attr_type)
{ {
this->set_data(data, length); this->set_data(data, length);
} }
...@@ -55,9 +60,15 @@ public: ...@@ -55,9 +60,15 @@ public:
TupleCell(const TupleCell &other) = default; TupleCell(const TupleCell &other) = default;
TupleCell &operator=(const TupleCell &other) = default; TupleCell &operator=(const TupleCell &other) = default;
void set_type(AttrType type) { this->attr_type_ = type; } void set_type(AttrType type)
{
this->attr_type_ = type;
}
void set_data(char *data, int length); void set_data(char *data, int length);
void set_data(const char *data, int length) { this->set_data(const_cast<char *>(data), length); } void set_data(const char *data, int length)
{
this->set_data(const_cast<char *>(data), length);
}
void set_int(int val); void set_int(int val);
void set_float(float val); void set_float(float val);
void set_boolean(bool val); void set_boolean(bool val);
...@@ -69,7 +80,10 @@ public: ...@@ -69,7 +80,10 @@ public:
int compare(const TupleCell &other) const; int compare(const TupleCell &other) const;
const char *data() const; const char *data() const;
int length() const { return length_; } int length() const
{
return length_;
}
AttrType attr_type() const AttrType attr_type() const
{ {
...@@ -85,15 +99,15 @@ public: ...@@ -85,15 +99,15 @@ public:
float get_float() const; float get_float() const;
std::string get_string() const; std::string get_string() const;
bool get_boolean() const; bool get_boolean() const;
private: private:
AttrType attr_type_ = UNDEFINED; AttrType attr_type_ = UNDEFINED;
int length_; int length_;
union { union {
int int_value_; int int_value_;
float float_value_; float float_value_;
bool bool_value_; bool bool_value_;
} num_value_; } num_value_;
std::string str_value_; std::string str_value_;
}; };
...@@ -14,7 +14,5 @@ See the Mulan PSL v2 for more details. */ ...@@ -14,7 +14,5 @@ See the Mulan PSL v2 for more details. */
#include "sql/operator/delete_logical_operator.h" #include "sql/operator/delete_logical_operator.h"
DeleteLogicalOperator::DeleteLogicalOperator(Table *table) DeleteLogicalOperator::DeleteLogicalOperator(Table *table) : table_(table)
: table_(table) {}
{
}
...@@ -16,15 +16,20 @@ See the Mulan PSL v2 for more details. */ ...@@ -16,15 +16,20 @@ See the Mulan PSL v2 for more details. */
#include "sql/operator/logical_operator.h" #include "sql/operator/logical_operator.h"
class DeleteLogicalOperator : public LogicalOperator class DeleteLogicalOperator : public LogicalOperator {
{
public: public:
DeleteLogicalOperator(Table *table); DeleteLogicalOperator(Table *table);
virtual ~DeleteLogicalOperator() = default; virtual ~DeleteLogicalOperator() = default;
LogicalOperatorType type() const override { return LogicalOperatorType::DELETE; } LogicalOperatorType type() const override
Table *table() const { return table_; } {
return LogicalOperatorType::DELETE;
}
Table *table() const
{
return table_;
}
private: private:
Table *table_ = nullptr; Table *table_ = nullptr;
}; };
...@@ -41,7 +41,7 @@ RC DeletePhysicalOperator::next() ...@@ -41,7 +41,7 @@ RC DeletePhysicalOperator::next()
if (children_.empty()) { if (children_.empty()) {
return RC::RECORD_EOF; return RC::RECORD_EOF;
} }
PhysicalOperator *child = children_[0].get(); PhysicalOperator *child = children_[0].get();
while (RC::SUCCESS == (rc = child->next())) { while (RC::SUCCESS == (rc = child->next())) {
Tuple *tuple = child->current_tuple(); Tuple *tuple = child->current_tuple();
......
...@@ -20,24 +20,27 @@ See the Mulan PSL v2 for more details. */ ...@@ -20,24 +20,27 @@ See the Mulan PSL v2 for more details. */
class Trx; class Trx;
class DeleteStmt; class DeleteStmt;
class DeletePhysicalOperator : public PhysicalOperator class DeletePhysicalOperator : public PhysicalOperator {
{
public: public:
DeletePhysicalOperator(Table *table, Trx *trx) DeletePhysicalOperator(Table *table, Trx *trx) : table_(table), trx_(trx)
: table_(table), trx_(trx)
{} {}
virtual ~DeletePhysicalOperator() = default; virtual ~DeletePhysicalOperator() = default;
PhysicalOperatorType type() const override { return PhysicalOperatorType::DELETE; } PhysicalOperatorType type() const override
{
return PhysicalOperatorType::DELETE;
}
RC open() override; RC open() override;
RC next() override; RC next() override;
RC close() override; RC close() override;
Tuple * current_tuple() override { Tuple *current_tuple() override
{
return nullptr; return nullptr;
} }
private: private:
Table *table_ = nullptr; Table *table_ = nullptr;
Trx *trx_ = nullptr; Trx *trx_ = nullptr;
......
...@@ -16,13 +16,15 @@ See the Mulan PSL v2 for more details. */ ...@@ -16,13 +16,15 @@ See the Mulan PSL v2 for more details. */
#include "sql/operator/logical_operator.h" #include "sql/operator/logical_operator.h"
class ExplainLogicalOperator : public LogicalOperator class ExplainLogicalOperator : public LogicalOperator {
{
public: public:
ExplainLogicalOperator() = default; ExplainLogicalOperator() = default;
virtual ~ExplainLogicalOperator() = default; virtual ~ExplainLogicalOperator() = default;
LogicalOperatorType type() const override { return LogicalOperatorType::EXPLAIN; } LogicalOperatorType type() const override
{
return LogicalOperatorType::EXPLAIN;
}
private: private:
}; };
...@@ -37,19 +37,19 @@ RC ExplainPhysicalOperator::next() ...@@ -37,19 +37,19 @@ RC ExplainPhysicalOperator::next()
if (!physical_plan_.empty()) { if (!physical_plan_.empty()) {
return RC::RECORD_EOF; return RC::RECORD_EOF;
} }
stringstream ss; stringstream ss;
ss << "OPERATOR(NAME)\n"; ss << "OPERATOR(NAME)\n";
int level = 0; int level = 0;
std::vector<bool> ends; std::vector<bool> ends;
ends.push_back(true); ends.push_back(true);
const auto children_size = static_cast<int>(children_.size()); const auto children_size = static_cast<int>(children_.size());
for (int i = 0; i < children_size - 1; i++) { for (int i = 0; i < children_size - 1; i++) {
to_string(ss, children_[i].get(), level, false/*last_child*/, ends); to_string(ss, children_[i].get(), level, false /*last_child*/, ends);
} }
if (children_size > 0) { if (children_size > 0) {
to_string(ss, children_[children_size - 1].get(), level, true/*last_child*/, ends); to_string(ss, children_[children_size - 1].get(), level, true /*last_child*/, ends);
} }
physical_plan_ = ss.str(); physical_plan_ = ss.str();
...@@ -75,9 +75,10 @@ Tuple *ExplainPhysicalOperator::current_tuple() ...@@ -75,9 +75,10 @@ Tuple *ExplainPhysicalOperator::current_tuple()
* @param last_child 当前算子是否是当前兄弟节点中最后一个节点 * @param last_child 当前算子是否是当前兄弟节点中最后一个节点
* @param ends 表示当前某个层级上的算子,是否已经没有其它的节点,以判断使用什么打印符号 * @param ends 表示当前某个层级上的算子,是否已经没有其它的节点,以判断使用什么打印符号
*/ */
void ExplainPhysicalOperator::to_string(std::ostream &os, PhysicalOperator *oper, int level, bool last_child, std::vector<bool> &ends) void ExplainPhysicalOperator::to_string(
std::ostream &os, PhysicalOperator *oper, int level, bool last_child, std::vector<bool> &ends)
{ {
for (int i = 0; i < level-1; i++) { for (int i = 0; i < level - 1; i++) {
if (ends[i]) { if (ends[i]) {
os << " "; os << " ";
} else { } else {
...@@ -108,9 +109,9 @@ void ExplainPhysicalOperator::to_string(std::ostream &os, PhysicalOperator *oper ...@@ -108,9 +109,9 @@ void ExplainPhysicalOperator::to_string(std::ostream &os, PhysicalOperator *oper
std::vector<std::unique_ptr<PhysicalOperator>> &children = oper->children(); std::vector<std::unique_ptr<PhysicalOperator>> &children = oper->children();
const auto size = static_cast<int>(children.size()); const auto size = static_cast<int>(children.size());
for (auto i = 0; i < size - 1; i++) { for (auto i = 0; i < size - 1; i++) {
to_string(os, children[i].get(), level + 1, false/*last_child*/, ends); to_string(os, children[i].get(), level + 1, false /*last_child*/, ends);
} }
if (size > 0) { if (size > 0) {
to_string(os, children[size - 1].get(), level + 1, true/*last_child*/, ends); to_string(os, children[size - 1].get(), level + 1, true /*last_child*/, ends);
} }
} }
...@@ -16,14 +16,16 @@ See the Mulan PSL v2 for more details. */ ...@@ -16,14 +16,16 @@ See the Mulan PSL v2 for more details. */
#include "sql/operator/physical_operator.h" #include "sql/operator/physical_operator.h"
class ExplainPhysicalOperator : public PhysicalOperator class ExplainPhysicalOperator : public PhysicalOperator {
{
public: public:
ExplainPhysicalOperator() = default; ExplainPhysicalOperator() = default;
virtual ~ExplainPhysicalOperator() = default; virtual ~ExplainPhysicalOperator() = default;
PhysicalOperatorType type() const override { return PhysicalOperatorType::EXPLAIN; } PhysicalOperatorType type() const override
{
return PhysicalOperatorType::EXPLAIN;
}
RC open() override; RC open() override;
RC next() override; RC next() override;
RC close() override; RC close() override;
......
...@@ -15,11 +15,9 @@ See the Mulan PSL v2 for more details. */ ...@@ -15,11 +15,9 @@ See the Mulan PSL v2 for more details. */
#include "sql/operator/index_scan_physical_operator.h" #include "sql/operator/index_scan_physical_operator.h"
#include "storage/index/index.h" #include "storage/index/index.h"
IndexScanPhysicalOperator::IndexScanPhysicalOperator(const Table *table, Index *index, IndexScanPhysicalOperator::IndexScanPhysicalOperator(const Table *table, Index *index, const TupleCell *left_cell,
const TupleCell *left_cell, bool left_inclusive, bool left_inclusive, const TupleCell *right_cell, bool right_inclusive)
const TupleCell *right_cell, bool right_inclusive) : table_(table), index_(index), left_inclusive_(left_inclusive), right_inclusive_(right_inclusive)
: table_(table), index_(index),
left_inclusive_(left_inclusive), right_inclusive_(right_inclusive)
{ {
if (left_cell) { if (left_cell) {
left_cell_ = *left_cell; left_cell_ = *left_cell;
...@@ -35,9 +33,12 @@ RC IndexScanPhysicalOperator::open() ...@@ -35,9 +33,12 @@ RC IndexScanPhysicalOperator::open()
return RC::INTERNAL; return RC::INTERNAL;
} }
IndexScanner *index_scanner = index_->create_scanner(left_cell_.data(),
IndexScanner *index_scanner = index_->create_scanner(left_cell_.data(), left_cell_.length(), left_inclusive_, left_cell_.length(),
right_cell_.data(), right_cell_.length(), right_inclusive_); left_inclusive_,
right_cell_.data(),
right_cell_.length(),
right_inclusive_);
if (nullptr == index_scanner) { if (nullptr == index_scanner) {
LOG_WARN("failed to create index scanner"); LOG_WARN("failed to create index scanner");
return RC::INTERNAL; return RC::INTERNAL;
...@@ -52,7 +53,7 @@ RC IndexScanPhysicalOperator::open() ...@@ -52,7 +53,7 @@ RC IndexScanPhysicalOperator::open()
index_scanner_ = index_scanner; index_scanner_ = index_scanner;
tuple_.set_schema(table_, table_->table_meta().field_metas()); tuple_.set_schema(table_, table_->table_meta().field_metas());
return RC::SUCCESS; return RC::SUCCESS;
} }
...@@ -89,7 +90,7 @@ RC IndexScanPhysicalOperator::close() ...@@ -89,7 +90,7 @@ RC IndexScanPhysicalOperator::close()
return RC::SUCCESS; return RC::SUCCESS;
} }
Tuple * IndexScanPhysicalOperator::current_tuple() Tuple *IndexScanPhysicalOperator::current_tuple()
{ {
tuple_.set_record(&current_record_); tuple_.set_record(&current_record_);
return &tuple_; return &tuple_;
......
...@@ -17,31 +17,32 @@ See the Mulan PSL v2 for more details. */ ...@@ -17,31 +17,32 @@ See the Mulan PSL v2 for more details. */
#include "sql/operator/physical_operator.h" #include "sql/operator/physical_operator.h"
#include "sql/expr/tuple.h" #include "sql/expr/tuple.h"
class IndexScanPhysicalOperator : public PhysicalOperator class IndexScanPhysicalOperator : public PhysicalOperator {
{ public:
public: IndexScanPhysicalOperator(const Table *table, Index *index, const TupleCell *left_cell, bool left_inclusive,
IndexScanPhysicalOperator(const Table *table, Index *index, const TupleCell *right_cell, bool right_inclusive);
const TupleCell *left_cell, bool left_inclusive,
const TupleCell *right_cell, bool right_inclusive);
virtual ~IndexScanPhysicalOperator() = default; virtual ~IndexScanPhysicalOperator() = default;
PhysicalOperatorType type() const override { return PhysicalOperatorType::INDEX_SCAN; } PhysicalOperatorType type() const override
{
return PhysicalOperatorType::INDEX_SCAN;
}
std::string param() const override; std::string param() const override;
RC open() override; RC open() override;
RC next() override; RC next() override;
RC close() override; RC close() override;
Tuple * current_tuple() override; Tuple *current_tuple() override;
void set_predicates(std::vector<std::unique_ptr<Expression>> &&exprs); void set_predicates(std::vector<std::unique_ptr<Expression>> &&exprs);
private: private:
// 与TableScanPhysicalOperator代码相同,可以优化 // 与TableScanPhysicalOperator代码相同,可以优化
RC filter(RowTuple &tuple, bool &result); RC filter(RowTuple &tuple, bool &result);
private: private:
const Table *table_ = nullptr; const Table *table_ = nullptr;
Index *index_ = nullptr; Index *index_ = nullptr;
......
...@@ -22,7 +22,7 @@ RC InsertPhysicalOperator::open() ...@@ -22,7 +22,7 @@ RC InsertPhysicalOperator::open()
Table *table = insert_stmt_->table(); Table *table = insert_stmt_->table();
const Value *values = insert_stmt_->values(); const Value *values = insert_stmt_->values();
int value_amount = insert_stmt_->value_amount(); int value_amount = insert_stmt_->value_amount();
return table->insert_record(nullptr, value_amount, values); // TODO trx return table->insert_record(nullptr, value_amount, values); // TODO trx
} }
RC InsertPhysicalOperator::next() RC InsertPhysicalOperator::next()
......
...@@ -20,17 +20,18 @@ See the Mulan PSL v2 for more details. */ ...@@ -20,17 +20,18 @@ See the Mulan PSL v2 for more details. */
class InsertStmt; class InsertStmt;
class InsertPhysicalOperator : public PhysicalOperator class InsertPhysicalOperator : public PhysicalOperator {
{
public: public:
InsertPhysicalOperator(InsertStmt *insert_stmt) InsertPhysicalOperator(InsertStmt *insert_stmt) : insert_stmt_(insert_stmt)
: insert_stmt_(insert_stmt)
{} {}
virtual ~InsertPhysicalOperator() = default; virtual ~InsertPhysicalOperator() = default;
PhysicalOperatorType type() const override { return PhysicalOperatorType::INSERT; } PhysicalOperatorType type() const override
{
return PhysicalOperatorType::INSERT;
}
RC open() override; RC open() override;
RC next() override; RC next() override;
RC close() override; RC close() override;
......
...@@ -16,12 +16,15 @@ See the Mulan PSL v2 for more details. */ ...@@ -16,12 +16,15 @@ See the Mulan PSL v2 for more details. */
#include "sql/operator/logical_operator.h" #include "sql/operator/logical_operator.h"
class JoinLogicalOperator : public LogicalOperator class JoinLogicalOperator : public LogicalOperator {
{
public: public:
JoinLogicalOperator() = default; JoinLogicalOperator() = default;
virtual ~JoinLogicalOperator() = default; virtual ~JoinLogicalOperator() = default;
LogicalOperatorType type() const override { return LogicalOperatorType::JOIN; } LogicalOperatorType type() const override
{
return LogicalOperatorType::JOIN;
}
private: private:
}; };
...@@ -15,8 +15,7 @@ See the Mulan PSL v2 for more details. */ ...@@ -15,8 +15,7 @@ See the Mulan PSL v2 for more details. */
#include "sql/operator/join_physical_operator.h" #include "sql/operator/join_physical_operator.h"
NestedLoopJoinPhysicalOperator::NestedLoopJoinPhysicalOperator() NestedLoopJoinPhysicalOperator::NestedLoopJoinPhysicalOperator()
{ {}
}
RC NestedLoopJoinPhysicalOperator::open() RC NestedLoopJoinPhysicalOperator::open()
{ {
...@@ -50,7 +49,7 @@ RC NestedLoopJoinPhysicalOperator::next() ...@@ -50,7 +49,7 @@ RC NestedLoopJoinPhysicalOperator::next()
return rc; return rc;
} }
} else { } else {
return rc; // got one tuple from right return rc; // got one tuple from right
} }
} }
...@@ -83,7 +82,7 @@ RC NestedLoopJoinPhysicalOperator::close() ...@@ -83,7 +82,7 @@ RC NestedLoopJoinPhysicalOperator::close()
return rc; return rc;
} }
Tuple * NestedLoopJoinPhysicalOperator::current_tuple() Tuple *NestedLoopJoinPhysicalOperator::current_tuple()
{ {
return &joined_tuple_; return &joined_tuple_;
} }
......
...@@ -22,14 +22,16 @@ See the Mulan PSL v2 for more details. */ ...@@ -22,14 +22,16 @@ See the Mulan PSL v2 for more details. */
* 最简单的两表(称为左表、右表)join算子 * 最简单的两表(称为左表、右表)join算子
* 依次遍历左表的每一行,然后关联右表的每一行 * 依次遍历左表的每一行,然后关联右表的每一行
*/ */
class NestedLoopJoinPhysicalOperator : public PhysicalOperator class NestedLoopJoinPhysicalOperator : public PhysicalOperator {
{
public: public:
NestedLoopJoinPhysicalOperator(); NestedLoopJoinPhysicalOperator();
virtual ~NestedLoopJoinPhysicalOperator() = default; virtual ~NestedLoopJoinPhysicalOperator() = default;
PhysicalOperatorType type() const override { return PhysicalOperatorType::NESTED_LOOP_JOIN; } PhysicalOperatorType type() const override
{
return PhysicalOperatorType::NESTED_LOOP_JOIN;
}
RC open() override; RC open() override;
RC next() override; RC next() override;
RC close() override; RC close() override;
...@@ -38,7 +40,7 @@ public: ...@@ -38,7 +40,7 @@ public:
private: private:
RC left_next(); //! 左表遍历下一条数据 RC left_next(); //! 左表遍历下一条数据
RC right_next(); //! 右表遍历下一条数据,如果上一轮结束了就重新开始新的一轮 RC right_next(); //! 右表遍历下一条数据,如果上一轮结束了就重新开始新的一轮
private: private:
//! 左表右表的真实对象是在PhysicalOperator::children_中,这里是为了写的时候更简单 //! 左表右表的真实对象是在PhysicalOperator::children_中,这里是为了写的时候更简单
PhysicalOperator *left_ = nullptr; PhysicalOperator *left_ = nullptr;
......
...@@ -15,8 +15,7 @@ See the Mulan PSL v2 for more details. */ ...@@ -15,8 +15,7 @@ See the Mulan PSL v2 for more details. */
#include "sql/operator/logical_operator.h" #include "sql/operator/logical_operator.h"
LogicalOperator::~LogicalOperator() LogicalOperator::~LogicalOperator()
{ {}
}
void LogicalOperator::add_child(std::unique_ptr<LogicalOperator> oper) void LogicalOperator::add_child(std::unique_ptr<LogicalOperator> oper)
{ {
......
...@@ -19,8 +19,7 @@ See the Mulan PSL v2 for more details. */ ...@@ -19,8 +19,7 @@ See the Mulan PSL v2 for more details. */
#include "sql/expr/expression.h" #include "sql/expr/expression.h"
enum class LogicalOperatorType enum class LogicalOperatorType {
{
TABLE_GET, TABLE_GET,
PREDICATE, PREDICATE,
PROJECTION, PROJECTION,
...@@ -33,18 +32,23 @@ enum class LogicalOperatorType ...@@ -33,18 +32,23 @@ enum class LogicalOperatorType
* 逻辑算子描述当前执行计划要做什么 * 逻辑算子描述当前执行计划要做什么
* 可以看OptimizeStage中相关的代码 * 可以看OptimizeStage中相关的代码
*/ */
class LogicalOperator class LogicalOperator {
{
public: public:
LogicalOperator() = default; LogicalOperator() = default;
virtual ~LogicalOperator(); virtual ~LogicalOperator();
virtual LogicalOperatorType type() const = 0; virtual LogicalOperatorType type() const = 0;
void add_child(std::unique_ptr<LogicalOperator> oper); void add_child(std::unique_ptr<LogicalOperator> oper);
std::vector<std::unique_ptr<LogicalOperator>> & children() { return children_; } std::vector<std::unique_ptr<LogicalOperator>> &children()
std::vector<std::unique_ptr<Expression>> &expressions() { return expressions_; } {
return children_;
}
std::vector<std::unique_ptr<Expression>> &expressions()
{
return expressions_;
}
protected: protected:
std::vector<std::unique_ptr<LogicalOperator>> children_; std::vector<std::unique_ptr<LogicalOperator>> children_;
std::vector<std::unique_ptr<Expression>> expressions_; std::vector<std::unique_ptr<Expression>> expressions_;
......
...@@ -37,8 +37,7 @@ std::string physical_operator_type_name(PhysicalOperatorType type) ...@@ -37,8 +37,7 @@ std::string physical_operator_type_name(PhysicalOperatorType type)
} }
PhysicalOperator::~PhysicalOperator() PhysicalOperator::~PhysicalOperator()
{ {}
}
std::string PhysicalOperator::name() const std::string PhysicalOperator::name() const
{ {
......
...@@ -24,8 +24,7 @@ See the Mulan PSL v2 for more details. */ ...@@ -24,8 +24,7 @@ See the Mulan PSL v2 for more details. */
class Record; class Record;
class TupleCellSpec; class TupleCellSpec;
enum class PhysicalOperatorType enum class PhysicalOperatorType {
{
TABLE_SCAN, TABLE_SCAN,
INDEX_SCAN, INDEX_SCAN,
NESTED_LOOP_JOIN, NESTED_LOOP_JOIN,
...@@ -40,8 +39,7 @@ enum class PhysicalOperatorType ...@@ -40,8 +39,7 @@ enum class PhysicalOperatorType
/** /**
* 与LogicalOperator对应,物理算子描述执行计划将如何执行 * 与LogicalOperator对应,物理算子描述执行计划将如何执行
*/ */
class PhysicalOperator class PhysicalOperator {
{
public: public:
PhysicalOperator() PhysicalOperator()
{} {}
...@@ -55,18 +53,22 @@ public: ...@@ -55,18 +53,22 @@ public:
virtual std::string param() const; virtual std::string param() const;
virtual PhysicalOperatorType type() const = 0; virtual PhysicalOperatorType type() const = 0;
virtual RC open() = 0; virtual RC open() = 0;
virtual RC next() = 0; virtual RC next() = 0;
virtual RC close() = 0; virtual RC close() = 0;
virtual Tuple * current_tuple() = 0; virtual Tuple *current_tuple() = 0;
void add_child(std::unique_ptr<PhysicalOperator> oper) { void add_child(std::unique_ptr<PhysicalOperator> oper)
{
children_.emplace_back(std::move(oper)); children_.emplace_back(std::move(oper));
} }
std::vector<std::unique_ptr<PhysicalOperator>> &children() { return children_; } std::vector<std::unique_ptr<PhysicalOperator>> &children()
{
return children_;
}
protected: protected:
std::vector<std::unique_ptr<PhysicalOperator>> children_; std::vector<std::unique_ptr<PhysicalOperator>> children_;
......
...@@ -17,11 +17,13 @@ See the Mulan PSL v2 for more details. */ ...@@ -17,11 +17,13 @@ See the Mulan PSL v2 for more details. */
#include "sql/operator/logical_operator.h" #include "sql/operator/logical_operator.h"
#include "sql/expr/expression.h" #include "sql/expr/expression.h"
class PredicateLogicalOperator : public LogicalOperator class PredicateLogicalOperator : public LogicalOperator {
{
public: public:
PredicateLogicalOperator(std::unique_ptr<Expression> expression); PredicateLogicalOperator(std::unique_ptr<Expression> expression);
virtual ~PredicateLogicalOperator() = default; virtual ~PredicateLogicalOperator() = default;
LogicalOperatorType type() const override { return LogicalOperatorType::PREDICATE; } LogicalOperatorType type() const override
{
return LogicalOperatorType::PREDICATE;
}
}; };
...@@ -18,8 +18,7 @@ See the Mulan PSL v2 for more details. */ ...@@ -18,8 +18,7 @@ See the Mulan PSL v2 for more details. */
#include "sql/stmt/filter_stmt.h" #include "sql/stmt/filter_stmt.h"
#include "storage/common/field.h" #include "storage/common/field.h"
PredicatePhysicalOperator::PredicatePhysicalOperator(std::unique_ptr<Expression> expr) PredicatePhysicalOperator::PredicatePhysicalOperator(std::unique_ptr<Expression> expr) : expression_(std::move(expr))
: expression_(std::move(expr))
{ {
ASSERT(expression_->value_type() == BOOLEANS, "predicate's expression should be BOOLEAN type"); ASSERT(expression_->value_type() == BOOLEANS, "predicate's expression should be BOOLEAN type");
} }
...@@ -38,7 +37,7 @@ RC PredicatePhysicalOperator::next() ...@@ -38,7 +37,7 @@ RC PredicatePhysicalOperator::next()
{ {
RC rc = RC::SUCCESS; RC rc = RC::SUCCESS;
PhysicalOperator *oper = children_.front().get(); PhysicalOperator *oper = children_.front().get();
while (RC::SUCCESS == (rc = oper->next())) { while (RC::SUCCESS == (rc = oper->next())) {
Tuple *tuple = oper->current_tuple(); Tuple *tuple = oper->current_tuple();
if (nullptr == tuple) { if (nullptr == tuple) {
...@@ -66,8 +65,7 @@ RC PredicatePhysicalOperator::close() ...@@ -66,8 +65,7 @@ RC PredicatePhysicalOperator::close()
return RC::SUCCESS; return RC::SUCCESS;
} }
Tuple * PredicatePhysicalOperator::current_tuple() Tuple *PredicatePhysicalOperator::current_tuple()
{ {
return children_[0]->current_tuple(); return children_[0]->current_tuple();
} }
...@@ -20,21 +20,23 @@ See the Mulan PSL v2 for more details. */ ...@@ -20,21 +20,23 @@ See the Mulan PSL v2 for more details. */
class FilterStmt; class FilterStmt;
class PredicatePhysicalOperator : public PhysicalOperator class PredicatePhysicalOperator : public PhysicalOperator {
{
public: public:
PredicatePhysicalOperator(std::unique_ptr<Expression> expr); PredicatePhysicalOperator(std::unique_ptr<Expression> expr);
virtual ~PredicatePhysicalOperator() = default; virtual ~PredicatePhysicalOperator() = default;
PhysicalOperatorType type() const override { return PhysicalOperatorType::PREDICATE; } PhysicalOperatorType type() const override
{
return PhysicalOperatorType::PREDICATE;
}
RC open() override; RC open() override;
RC next() override; RC next() override;
RC close() override; RC close() override;
Tuple * current_tuple() override; Tuple *current_tuple() override;
private: private:
std::unique_ptr<Expression> expression_; std::unique_ptr<Expression> expression_;
}; };
...@@ -14,7 +14,5 @@ See the Mulan PSL v2 for more details. */ ...@@ -14,7 +14,5 @@ See the Mulan PSL v2 for more details. */
#include "sql/operator/project_logical_operator.h" #include "sql/operator/project_logical_operator.h"
ProjectLogicalOperator::ProjectLogicalOperator(const std::vector<Field> &fields) ProjectLogicalOperator::ProjectLogicalOperator(const std::vector<Field> &fields) : fields_(fields)
: fields_(fields) {}
{
}
...@@ -24,16 +24,21 @@ See the Mulan PSL v2 for more details. */ ...@@ -24,16 +24,21 @@ See the Mulan PSL v2 for more details. */
/** /**
* project 表示投影运算 * project 表示投影运算
*/ */
class ProjectLogicalOperator : public LogicalOperator class ProjectLogicalOperator : public LogicalOperator {
{
public: public:
ProjectLogicalOperator(const std::vector<Field> &fields); ProjectLogicalOperator(const std::vector<Field> &fields);
virtual ~ProjectLogicalOperator() = default; virtual ~ProjectLogicalOperator() = default;
LogicalOperatorType type() const override { return LogicalOperatorType::PROJECTION; } LogicalOperatorType type() const override
{
return LogicalOperatorType::PROJECTION;
}
const std::vector<Field> &fields() const
{
return fields_;
}
const std::vector<Field> &fields() const { return fields_; }
private: private:
//! 投影映射的字段名称 //! 投影映射的字段名称
//! 并不是所有的select都会查看表字段,也可能是常量数字、字符串, //! 并不是所有的select都会查看表字段,也可能是常量数字、字符串,
......
...@@ -17,8 +17,7 @@ See the Mulan PSL v2 for more details. */ ...@@ -17,8 +17,7 @@ See the Mulan PSL v2 for more details. */
#include "sql/operator/physical_operator.h" #include "sql/operator/physical_operator.h"
#include "rc.h" #include "rc.h"
class ProjectPhysicalOperator : public PhysicalOperator class ProjectPhysicalOperator : public PhysicalOperator {
{
public: public:
ProjectPhysicalOperator() ProjectPhysicalOperator()
{} {}
...@@ -27,8 +26,11 @@ public: ...@@ -27,8 +26,11 @@ public:
void add_projection(const Table *table, const FieldMeta *field); void add_projection(const Table *table, const FieldMeta *field);
PhysicalOperatorType type() const override { return PhysicalOperatorType::PROJECT; } PhysicalOperatorType type() const override
{
return PhysicalOperatorType::PROJECT;
}
RC open() override; RC open() override;
RC next() override; RC next() override;
RC close() override; RC close() override;
...@@ -38,7 +40,8 @@ public: ...@@ -38,7 +40,8 @@ public:
return tuple_.cell_num(); return tuple_.cell_num();
} }
Tuple * current_tuple() override; Tuple *current_tuple() override;
private: private:
ProjectTuple tuple_; ProjectTuple tuple_;
}; };
...@@ -17,9 +17,8 @@ See the Mulan PSL v2 for more details. */ ...@@ -17,9 +17,8 @@ See the Mulan PSL v2 for more details. */
#include <vector> #include <vector>
#include "sql/operator/physical_operator.h" #include "sql/operator/physical_operator.h"
class StringListPhysicalOperator : public PhysicalOperator class StringListPhysicalOperator : public PhysicalOperator {
{ public:
public:
StringListPhysicalOperator() StringListPhysicalOperator()
{} {}
...@@ -37,18 +36,21 @@ public: ...@@ -37,18 +36,21 @@ public:
} }
template <typename T> template <typename T>
void append( const T &v) void append(const T &v)
{ {
strings_.emplace_back(1, v); strings_.emplace_back(1, v);
} }
PhysicalOperatorType type() const override { return PhysicalOperatorType::STRING_LIST; } PhysicalOperatorType type() const override
{
return PhysicalOperatorType::STRING_LIST;
}
RC open() override RC open() override
{ {
return RC::SUCCESS; return RC::SUCCESS;
} }
RC next() override RC next() override
{ {
if (!started_) { if (!started_) {
...@@ -56,16 +58,17 @@ public: ...@@ -56,16 +58,17 @@ public:
iterator_ = strings_.begin(); iterator_ = strings_.begin();
} else if (iterator_ != strings_.end()) { } else if (iterator_ != strings_.end()) {
++iterator_; ++iterator_;
} }
return iterator_ == strings_.end() ? RC::RECORD_EOF : RC::SUCCESS; return iterator_ == strings_.end() ? RC::RECORD_EOF : RC::SUCCESS;
} }
virtual RC close() override virtual RC close() override
{ {
iterator_ = strings_.end(); iterator_ = strings_.end();
return RC::SUCCESS; return RC::SUCCESS;
} }
virtual Tuple * current_tuple() override virtual Tuple *current_tuple() override
{ {
if (iterator_ == strings_.end()) { if (iterator_ == strings_.end()) {
return nullptr; return nullptr;
...@@ -82,6 +85,7 @@ public: ...@@ -82,6 +85,7 @@ public:
tuple_.set_cells(cells); tuple_.set_cells(cells);
return &tuple_; return &tuple_;
} }
private: private:
using StringList = std::vector<std::string>; using StringList = std::vector<std::string>;
using StringListList = std::vector<StringList>; using StringListList = std::vector<StringList>;
......
...@@ -20,19 +20,27 @@ See the Mulan PSL v2 for more details. */ ...@@ -20,19 +20,27 @@ See the Mulan PSL v2 for more details. */
* 表示从表中获取数据的算子 * 表示从表中获取数据的算子
* 比如使用全表扫描、通过索引获取数据等 * 比如使用全表扫描、通过索引获取数据等
*/ */
class TableGetLogicalOperator : public LogicalOperator class TableGetLogicalOperator : public LogicalOperator {
{
public: public:
TableGetLogicalOperator(Table *table, const std::vector<Field> &fields); TableGetLogicalOperator(Table *table, const std::vector<Field> &fields);
virtual ~TableGetLogicalOperator() = default; virtual ~TableGetLogicalOperator() = default;
LogicalOperatorType type() const override { return LogicalOperatorType::TABLE_GET; } LogicalOperatorType type() const override
{
return LogicalOperatorType::TABLE_GET;
}
Table *table() const { return table_; } Table *table() const
{
return table_;
}
void set_predicates(std::vector<std::unique_ptr<Expression>> &&exprs); void set_predicates(std::vector<std::unique_ptr<Expression>> &&exprs);
std::vector<std::unique_ptr<Expression>> &predicates() { return predicates_; } std::vector<std::unique_ptr<Expression>> &predicates()
{
return predicates_;
}
private: private:
Table *table_ = nullptr; Table *table_ = nullptr;
std::vector<Field> fields_; std::vector<Field> fields_;
......
...@@ -38,7 +38,7 @@ RC TableScanPhysicalOperator::next() ...@@ -38,7 +38,7 @@ RC TableScanPhysicalOperator::next()
if (rc != RC::SUCCESS) { if (rc != RC::SUCCESS) {
return rc; return rc;
} }
tuple_.set_record(&current_record_); tuple_.set_record(&current_record_);
rc = filter(tuple_, filter_result); rc = filter(tuple_, filter_result);
if (rc != RC::SUCCESS) { if (rc != RC::SUCCESS) {
...@@ -59,7 +59,7 @@ RC TableScanPhysicalOperator::close() ...@@ -59,7 +59,7 @@ RC TableScanPhysicalOperator::close()
return record_scanner_.close_scan(); return record_scanner_.close_scan();
} }
Tuple * TableScanPhysicalOperator::current_tuple() Tuple *TableScanPhysicalOperator::current_tuple()
{ {
tuple_.set_record(&current_record_); tuple_.set_record(&current_record_);
return &tuple_; return &tuple_;
...@@ -67,7 +67,7 @@ Tuple * TableScanPhysicalOperator::current_tuple() ...@@ -67,7 +67,7 @@ Tuple * TableScanPhysicalOperator::current_tuple()
std::string TableScanPhysicalOperator::param() const std::string TableScanPhysicalOperator::param() const
{ {
return table_->name(); return table_->name();
} }
void TableScanPhysicalOperator::set_predicates(std::vector<std::unique_ptr<Expression>> &&exprs) void TableScanPhysicalOperator::set_predicates(std::vector<std::unique_ptr<Expression>> &&exprs)
......
...@@ -20,30 +20,31 @@ See the Mulan PSL v2 for more details. */ ...@@ -20,30 +20,31 @@ See the Mulan PSL v2 for more details. */
class Table; class Table;
class TableScanPhysicalOperator : public PhysicalOperator class TableScanPhysicalOperator : public PhysicalOperator {
{
public: public:
TableScanPhysicalOperator(Table *table) TableScanPhysicalOperator(Table *table) : table_(table)
: table_(table)
{} {}
virtual ~TableScanPhysicalOperator() = default; virtual ~TableScanPhysicalOperator() = default;
std::string param() const override; std::string param() const override;
PhysicalOperatorType type() const override { return PhysicalOperatorType::TABLE_SCAN; } PhysicalOperatorType type() const override
{
return PhysicalOperatorType::TABLE_SCAN;
}
RC open() override; RC open() override;
RC next() override; RC next() override;
RC close() override; RC close() override;
Tuple * current_tuple() override; Tuple *current_tuple() override;
void set_predicates(std::vector<std::unique_ptr<Expression>> &&exprs); void set_predicates(std::vector<std::unique_ptr<Expression>> &&exprs);
private: private:
RC filter(RowTuple &tuple, bool &result); RC filter(RowTuple &tuple, bool &result);
private: private:
Table *table_ = nullptr; Table *table_ = nullptr;
RecordFileScanner record_scanner_; RecordFileScanner record_scanner_;
......
...@@ -19,12 +19,11 @@ See the Mulan PSL v2 for more details. */ ...@@ -19,12 +19,11 @@ See the Mulan PSL v2 for more details. */
class LogicalOperator; class LogicalOperator;
class ComparisonSimplificationRule : public ExpressionRewriteRule class ComparisonSimplificationRule : public ExpressionRewriteRule {
{ public:
public:
ComparisonSimplificationRule() = default; ComparisonSimplificationRule() = default;
virtual ~ComparisonSimplificationRule() = default; virtual ~ComparisonSimplificationRule() = default;
RC rewrite(std::unique_ptr<Expression> &expr, bool &change_made) override; RC rewrite(std::unique_ptr<Expression> &expr, bool &change_made) override;
private: private:
......
...@@ -37,7 +37,7 @@ RC ConjunctionSimplificationRule::rewrite(std::unique_ptr<Expression> &expr, boo ...@@ -37,7 +37,7 @@ RC ConjunctionSimplificationRule::rewrite(std::unique_ptr<Expression> &expr, boo
std::vector<std::unique_ptr<Expression>> &child_exprs = conjunction_expr->children(); std::vector<std::unique_ptr<Expression>> &child_exprs = conjunction_expr->children();
// 先看看有没有能够直接去掉的表达式。比如AND时恒为true的表达式可以删除 // 先看看有没有能够直接去掉的表达式。比如AND时恒为true的表达式可以删除
// 或者是否可以直接计算出当前表达式的值。比如AND时,如果有一个表达式为false,那么整个表达式就是false // 或者是否可以直接计算出当前表达式的值。比如AND时,如果有一个表达式为false,那么整个表达式就是false
for (auto iter = child_exprs.begin(); iter != child_exprs.end(); ) { for (auto iter = child_exprs.begin(); iter != child_exprs.end();) {
bool constant_value = false; bool constant_value = false;
rc = try_to_get_bool_constant(*iter, constant_value); rc = try_to_get_bool_constant(*iter, constant_value);
if (rc != RC::SUCCESS) { if (rc != RC::SUCCESS) {
...@@ -68,7 +68,6 @@ RC ConjunctionSimplificationRule::rewrite(std::unique_ptr<Expression> &expr, boo ...@@ -68,7 +68,6 @@ RC ConjunctionSimplificationRule::rewrite(std::unique_ptr<Expression> &expr, boo
child_exprs.erase(iter); child_exprs.erase(iter);
} }
} }
} }
if (child_exprs.size() == 1) { if (child_exprs.size() == 1) {
LOG_TRACE("conjunction expression has only 1 child"); LOG_TRACE("conjunction expression has only 1 child");
......
...@@ -23,12 +23,12 @@ class LogicalOperator; ...@@ -23,12 +23,12 @@ class LogicalOperator;
* 简化多个表达式联结的运算 * 简化多个表达式联结的运算
* 比如只有一个表达式,或者表达式可以直接出来 * 比如只有一个表达式,或者表达式可以直接出来
*/ */
class ConjunctionSimplificationRule : public ExpressionRewriteRule class ConjunctionSimplificationRule : public ExpressionRewriteRule {
{
public: public:
ConjunctionSimplificationRule() = default; ConjunctionSimplificationRule() = default;
virtual ~ConjunctionSimplificationRule() = default; virtual ~ConjunctionSimplificationRule() = default;
RC rewrite(std::unique_ptr<Expression> &expr, bool &change_made) override; RC rewrite(std::unique_ptr<Expression> &expr, bool &change_made) override;
private: private:
}; };
...@@ -61,9 +61,9 @@ RC ExpressionRewriter::rewrite(std::unique_ptr<LogicalOperator> &oper, bool &cha ...@@ -61,9 +61,9 @@ RC ExpressionRewriter::rewrite(std::unique_ptr<LogicalOperator> &oper, bool &cha
RC ExpressionRewriter::rewrite_expression(std::unique_ptr<Expression> &expr, bool &change_made) RC ExpressionRewriter::rewrite_expression(std::unique_ptr<Expression> &expr, bool &change_made)
{ {
RC rc = RC::SUCCESS; RC rc = RC::SUCCESS;
change_made = false; change_made = false;
for (std::unique_ptr<ExpressionRewriteRule> & rule: expr_rewrite_rules_) { for (std::unique_ptr<ExpressionRewriteRule> &rule : expr_rewrite_rules_) {
bool sub_change_made = false; bool sub_change_made = false;
rc = rule->rewrite(expr, sub_change_made); rc = rule->rewrite(expr, sub_change_made);
if (sub_change_made && !change_made) { if (sub_change_made && !change_made) {
......
...@@ -21,9 +21,8 @@ See the Mulan PSL v2 for more details. */ ...@@ -21,9 +21,8 @@ See the Mulan PSL v2 for more details. */
#include "sql/expr/expression.h" #include "sql/expr/expression.h"
#include "sql/optimizer/rewrite_rule.h" #include "sql/optimizer/rewrite_rule.h"
class ExpressionRewriter : public RewriteRule class ExpressionRewriter : public RewriteRule {
{ public:
public:
ExpressionRewriter(); ExpressionRewriter();
virtual ~ExpressionRewriter() = default; virtual ~ExpressionRewriter() = default;
...@@ -31,7 +30,6 @@ public: ...@@ -31,7 +30,6 @@ public:
private: private:
RC rewrite_expression(std::unique_ptr<Expression> &expr, bool &change_made); RC rewrite_expression(std::unique_ptr<Expression> &expr, bool &change_made);
private: private:
std::vector<std::unique_ptr<ExpressionRewriteRule>> expr_rewrite_rules_; std::vector<std::unique_ptr<ExpressionRewriteRule>> expr_rewrite_rules_;
......
...@@ -100,18 +100,17 @@ void OptimizeStage::cleanup() ...@@ -100,18 +100,17 @@ void OptimizeStage::cleanup()
void OptimizeStage::handle_event(StageEvent *event) void OptimizeStage::handle_event(StageEvent *event)
{ {
LOG_TRACE("Enter"); LOG_TRACE("Enter");
SQLStageEvent *sql_event = static_cast<SQLStageEvent*>(event); SQLStageEvent *sql_event = static_cast<SQLStageEvent *>(event);
RC rc = handle_request(sql_event); RC rc = handle_request(sql_event);
if (rc != RC::UNIMPLENMENT && rc != RC::SUCCESS) { if (rc != RC::UNIMPLENMENT && rc != RC::SUCCESS) {
SqlResult *sql_result = new SqlResult; SqlResult *sql_result = new SqlResult;
sql_result->set_return_code(rc); sql_result->set_return_code(rc);
sql_event->session_event()->set_sql_result(sql_result); sql_event->session_event()->set_sql_result(sql_result);
} else { } else {
execute_stage_->handle_event(event); execute_stage_->handle_event(event);
} }
LOG_TRACE("Exit"); LOG_TRACE("Exit");
} }
RC OptimizeStage::handle_request(SQLStageEvent *sql_event) RC OptimizeStage::handle_request(SQLStageEvent *sql_event)
{ {
...@@ -123,7 +122,7 @@ RC OptimizeStage::handle_request(SQLStageEvent *sql_event) ...@@ -123,7 +122,7 @@ RC OptimizeStage::handle_request(SQLStageEvent *sql_event)
} }
return rc; return rc;
} }
rc = rewrite(logical_operator); rc = rewrite(logical_operator);
if (rc != RC::SUCCESS) { if (rc != RC::SUCCESS) {
LOG_WARN("failed to rewrite plan. rc=%s", strrc(rc)); LOG_WARN("failed to rewrite plan. rc=%s", strrc(rc));
...@@ -154,8 +153,8 @@ RC OptimizeStage::optimize(std::unique_ptr<LogicalOperator> &oper) ...@@ -154,8 +153,8 @@ RC OptimizeStage::optimize(std::unique_ptr<LogicalOperator> &oper)
return RC::SUCCESS; return RC::SUCCESS;
} }
RC OptimizeStage::generate_physical_plan(std::unique_ptr<LogicalOperator> &logical_operator, RC OptimizeStage::generate_physical_plan(
std::unique_ptr<PhysicalOperator> &physical_operator) std::unique_ptr<LogicalOperator> &logical_operator, std::unique_ptr<PhysicalOperator> &physical_operator)
{ {
RC rc = RC::SUCCESS; RC rc = RC::SUCCESS;
rc = physical_plan_generator_.create(*logical_operator, physical_operator); rc = physical_plan_generator_.create(*logical_operator, physical_operator);
...@@ -212,7 +211,7 @@ RC OptimizeStage::create_logical_plan(Stmt *stmt, std::unique_ptr<LogicalOperato ...@@ -212,7 +211,7 @@ RC OptimizeStage::create_logical_plan(Stmt *stmt, std::unique_ptr<LogicalOperato
} }
return rc; return rc;
} }
RC OptimizeStage::create_logical_plan(SQLStageEvent *sql_event, std::unique_ptr<LogicalOperator> & logical_operator) RC OptimizeStage::create_logical_plan(SQLStageEvent *sql_event, std::unique_ptr<LogicalOperator> &logical_operator)
{ {
Stmt *stmt = sql_event->stmt(); Stmt *stmt = sql_event->stmt();
if (nullptr == stmt) { if (nullptr == stmt) {
...@@ -222,7 +221,8 @@ RC OptimizeStage::create_logical_plan(SQLStageEvent *sql_event, std::unique_ptr< ...@@ -222,7 +221,8 @@ RC OptimizeStage::create_logical_plan(SQLStageEvent *sql_event, std::unique_ptr<
return create_logical_plan(stmt, logical_operator); return create_logical_plan(stmt, logical_operator);
} }
RC OptimizeStage::create_select_logical_plan(SelectStmt *select_stmt, std::unique_ptr<LogicalOperator> & logical_operator) RC OptimizeStage::create_select_logical_plan(
SelectStmt *select_stmt, std::unique_ptr<LogicalOperator> &logical_operator)
{ {
std::unique_ptr<LogicalOperator> table_oper(nullptr); std::unique_ptr<LogicalOperator> table_oper(nullptr);
...@@ -235,12 +235,12 @@ RC OptimizeStage::create_select_logical_plan(SelectStmt *select_stmt, std::uniqu ...@@ -235,12 +235,12 @@ RC OptimizeStage::create_select_logical_plan(SelectStmt *select_stmt, std::uniqu
fields.push_back(field); fields.push_back(field);
} }
} }
std::unique_ptr<LogicalOperator> table_get_oper(new TableGetLogicalOperator(table, fields)); std::unique_ptr<LogicalOperator> table_get_oper(new TableGetLogicalOperator(table, fields));
if (table_oper == nullptr) { if (table_oper == nullptr) {
table_oper = std::move(table_get_oper); table_oper = std::move(table_get_oper);
} else { } else {
JoinLogicalOperator * join_oper = new JoinLogicalOperator; JoinLogicalOperator *join_oper = new JoinLogicalOperator;
join_oper->add_child(std::move(table_oper)); join_oper->add_child(std::move(table_oper));
join_oper->add_child(std::move(table_get_oper)); join_oper->add_child(std::move(table_get_oper));
table_oper = std::unique_ptr<LogicalOperator>(join_oper); table_oper = std::unique_ptr<LogicalOperator>(join_oper);
...@@ -253,7 +253,7 @@ RC OptimizeStage::create_select_logical_plan(SelectStmt *select_stmt, std::uniqu ...@@ -253,7 +253,7 @@ RC OptimizeStage::create_select_logical_plan(SelectStmt *select_stmt, std::uniqu
LOG_WARN("failed to create predicate logical plan. rc=%s", strrc(rc)); LOG_WARN("failed to create predicate logical plan. rc=%s", strrc(rc));
return rc; return rc;
} }
std::unique_ptr<LogicalOperator> project_oper(new ProjectLogicalOperator(all_fields)); std::unique_ptr<LogicalOperator> project_oper(new ProjectLogicalOperator(all_fields));
if (predicate_oper) { if (predicate_oper) {
predicate_oper->add_child(move(table_oper)); predicate_oper->add_child(move(table_oper));
...@@ -266,24 +266,23 @@ RC OptimizeStage::create_select_logical_plan(SelectStmt *select_stmt, std::uniqu ...@@ -266,24 +266,23 @@ RC OptimizeStage::create_select_logical_plan(SelectStmt *select_stmt, std::uniqu
return RC::SUCCESS; return RC::SUCCESS;
} }
RC OptimizeStage::create_predicate_logical_plan(FilterStmt *filter_stmt, std::unique_ptr<LogicalOperator> &logical_operator) RC OptimizeStage::create_predicate_logical_plan(
FilterStmt *filter_stmt, std::unique_ptr<LogicalOperator> &logical_operator)
{ {
std::vector<std::unique_ptr<Expression>> cmp_exprs; std::vector<std::unique_ptr<Expression>> cmp_exprs;
const std::vector<FilterUnit *> &filter_units = filter_stmt->filter_units(); const std::vector<FilterUnit *> &filter_units = filter_stmt->filter_units();
for (const FilterUnit * filter_unit : filter_units) { for (const FilterUnit *filter_unit : filter_units) {
const FilterObj &filter_obj_left = filter_unit->left(); const FilterObj &filter_obj_left = filter_unit->left();
const FilterObj &filter_obj_right = filter_unit->right(); const FilterObj &filter_obj_right = filter_unit->right();
std::unique_ptr<Expression> left( std::unique_ptr<Expression> left(filter_obj_left.is_attr
filter_obj_left.is_attr ? ? static_cast<Expression *>(new FieldExpr(filter_obj_left.field))
static_cast<Expression *>(new FieldExpr(filter_obj_left.field)) : : static_cast<Expression *>(new ValueExpr(filter_obj_left.value)));
static_cast<Expression *>(new ValueExpr(filter_obj_left.value)));
std::unique_ptr<Expression> right(filter_obj_right.is_attr
? static_cast<Expression *>(new FieldExpr(filter_obj_right.field))
: static_cast<Expression *>(new ValueExpr(filter_obj_right.value)));
std::unique_ptr<Expression> right(
filter_obj_right.is_attr ?
static_cast<Expression *>(new FieldExpr(filter_obj_right.field)) :
static_cast<Expression *>(new ValueExpr(filter_obj_right.value)));
ComparisonExpr *cmp_expr = new ComparisonExpr(filter_unit->comp(), std::move(left), std::move(right)); ComparisonExpr *cmp_expr = new ComparisonExpr(filter_unit->comp(), std::move(left), std::move(right));
cmp_exprs.emplace_back(cmp_expr); cmp_exprs.emplace_back(cmp_expr);
} }
...@@ -298,7 +297,8 @@ RC OptimizeStage::create_predicate_logical_plan(FilterStmt *filter_stmt, std::un ...@@ -298,7 +297,8 @@ RC OptimizeStage::create_predicate_logical_plan(FilterStmt *filter_stmt, std::un
return RC::SUCCESS; return RC::SUCCESS;
} }
RC OptimizeStage::create_delete_logical_plan(DeleteStmt *delete_stmt, std::unique_ptr<LogicalOperator> &logical_operator) RC OptimizeStage::create_delete_logical_plan(
DeleteStmt *delete_stmt, std::unique_ptr<LogicalOperator> &logical_operator)
{ {
Table *table = delete_stmt->table(); Table *table = delete_stmt->table();
FilterStmt *filter_stmt = delete_stmt->filter_stmt(); FilterStmt *filter_stmt = delete_stmt->filter_stmt();
...@@ -314,9 +314,9 @@ RC OptimizeStage::create_delete_logical_plan(DeleteStmt *delete_stmt, std::uniqu ...@@ -314,9 +314,9 @@ RC OptimizeStage::create_delete_logical_plan(DeleteStmt *delete_stmt, std::uniqu
if (rc != RC::SUCCESS) { if (rc != RC::SUCCESS) {
return rc; return rc;
} }
std::unique_ptr<LogicalOperator> delete_oper(new DeleteLogicalOperator(table)); std::unique_ptr<LogicalOperator> delete_oper(new DeleteLogicalOperator(table));
if (predicate_oper) { if (predicate_oper) {
predicate_oper->add_child(move(table_get_oper)); predicate_oper->add_child(move(table_get_oper));
delete_oper->add_child(move(predicate_oper)); delete_oper->add_child(move(predicate_oper));
...@@ -328,7 +328,8 @@ RC OptimizeStage::create_delete_logical_plan(DeleteStmt *delete_stmt, std::uniqu ...@@ -328,7 +328,8 @@ RC OptimizeStage::create_delete_logical_plan(DeleteStmt *delete_stmt, std::uniqu
return rc; return rc;
} }
RC OptimizeStage::create_explain_logical_plan(ExplainStmt *explain_stmt, std::unique_ptr<LogicalOperator> &logical_operator) RC OptimizeStage::create_explain_logical_plan(
ExplainStmt *explain_stmt, std::unique_ptr<LogicalOperator> &logical_operator)
{ {
Stmt *child_stmt = explain_stmt->child(); Stmt *child_stmt = explain_stmt->child();
std::unique_ptr<LogicalOperator> child_oper; std::unique_ptr<LogicalOperator> child_oper;
......
...@@ -48,20 +48,21 @@ protected: ...@@ -48,20 +48,21 @@ protected:
private: private:
RC handle_request(SQLStageEvent *event); RC handle_request(SQLStageEvent *event);
RC create_logical_plan(SQLStageEvent *sql_event, std::unique_ptr<LogicalOperator> & logical_operator); RC create_logical_plan(SQLStageEvent *sql_event, std::unique_ptr<LogicalOperator> &logical_operator);
RC create_logical_plan(Stmt *stmt, std::unique_ptr<LogicalOperator> &logical_operator); RC create_logical_plan(Stmt *stmt, std::unique_ptr<LogicalOperator> &logical_operator);
RC create_select_logical_plan(SelectStmt *select_stmt, std::unique_ptr<LogicalOperator> & logical_operator); RC create_select_logical_plan(SelectStmt *select_stmt, std::unique_ptr<LogicalOperator> &logical_operator);
RC create_predicate_logical_plan(FilterStmt *filter_stmt, std::unique_ptr<LogicalOperator> &logical_operator); RC create_predicate_logical_plan(FilterStmt *filter_stmt, std::unique_ptr<LogicalOperator> &logical_operator);
RC create_delete_logical_plan(DeleteStmt *delete_stmt, std::unique_ptr<LogicalOperator> &logical_operator); RC create_delete_logical_plan(DeleteStmt *delete_stmt, std::unique_ptr<LogicalOperator> &logical_operator);
RC create_explain_logical_plan(ExplainStmt *explain_stmt, std::unique_ptr<LogicalOperator> &logical_operator); RC create_explain_logical_plan(ExplainStmt *explain_stmt, std::unique_ptr<LogicalOperator> &logical_operator);
RC rewrite(std::unique_ptr<LogicalOperator> &logical_operator); RC rewrite(std::unique_ptr<LogicalOperator> &logical_operator);
RC optimize(std::unique_ptr<LogicalOperator> &logical_operator); RC optimize(std::unique_ptr<LogicalOperator> &logical_operator);
RC generate_physical_plan(std::unique_ptr<LogicalOperator> &logical_operator, std::unique_ptr<PhysicalOperator> &physical_operator); RC generate_physical_plan(
std::unique_ptr<LogicalOperator> &logical_operator, std::unique_ptr<PhysicalOperator> &physical_operator);
private: private:
Stage *execute_stage_ = nullptr; Stage *execute_stage_ = nullptr;
PhysicalPlanGenerator physical_plan_generator_; PhysicalPlanGenerator physical_plan_generator_;
Rewriter rewriter_; Rewriter rewriter_;
}; };
...@@ -59,7 +59,7 @@ RC PhysicalPlanGenerator::create(LogicalOperator &logical_operator, std::unique_ ...@@ -59,7 +59,7 @@ RC PhysicalPlanGenerator::create(LogicalOperator &logical_operator, std::unique_
case LogicalOperatorType::JOIN: { case LogicalOperatorType::JOIN: {
return create_plan(static_cast<JoinLogicalOperator &>(logical_operator), oper); return create_plan(static_cast<JoinLogicalOperator &>(logical_operator), oper);
} break; } break;
default: { default: {
return RC::INVALID_ARGUMENT; return RC::INVALID_ARGUMENT;
} }
...@@ -117,10 +117,8 @@ RC PhysicalPlanGenerator::create_plan(TableGetLogicalOperator &table_get_oper, s ...@@ -117,10 +117,8 @@ RC PhysicalPlanGenerator::create_plan(TableGetLogicalOperator &table_get_oper, s
ASSERT(value_expr != nullptr, "got an index but value expr is null ?"); ASSERT(value_expr != nullptr, "got an index but value expr is null ?");
const TupleCell &tuple_cell = value_expr->get_tuple_cell(); const TupleCell &tuple_cell = value_expr->get_tuple_cell();
IndexScanPhysicalOperator *index_scan_oper = IndexScanPhysicalOperator *index_scan_oper = new IndexScanPhysicalOperator(
new IndexScanPhysicalOperator(table, index, table, index, &tuple_cell, true /*left_inclusive*/, &tuple_cell, true /*right_inclusive*/);
&tuple_cell, true/*left_inclusive*/,
&tuple_cell, true /*right_inclusive*/);
index_scan_oper->set_predicates(std::move(predicates)); index_scan_oper->set_predicates(std::move(predicates));
oper = std::unique_ptr<PhysicalOperator>(index_scan_oper); oper = std::unique_ptr<PhysicalOperator>(index_scan_oper);
LOG_TRACE("use index scan"); LOG_TRACE("use index scan");
...@@ -130,7 +128,7 @@ RC PhysicalPlanGenerator::create_plan(TableGetLogicalOperator &table_get_oper, s ...@@ -130,7 +128,7 @@ RC PhysicalPlanGenerator::create_plan(TableGetLogicalOperator &table_get_oper, s
oper = std::unique_ptr<PhysicalOperator>(table_scan_oper); oper = std::unique_ptr<PhysicalOperator>(table_scan_oper);
LOG_TRACE("use table scan"); LOG_TRACE("use table scan");
} }
return RC::SUCCESS; return RC::SUCCESS;
} }
...@@ -147,7 +145,7 @@ RC PhysicalPlanGenerator::create_plan(PredicateLogicalOperator &pred_oper, std:: ...@@ -147,7 +145,7 @@ RC PhysicalPlanGenerator::create_plan(PredicateLogicalOperator &pred_oper, std::
LOG_WARN("failed to create child operator of predicate operator. rc=%s", strrc(rc)); LOG_WARN("failed to create child operator of predicate operator. rc=%s", strrc(rc));
return rc; return rc;
} }
std::vector<std::unique_ptr<Expression>> &expressions = pred_oper.expressions(); std::vector<std::unique_ptr<Expression>> &expressions = pred_oper.expressions();
ASSERT(expressions.size() == 1, "predicate logical operator's children should be 1"); ASSERT(expressions.size() == 1, "predicate logical operator's children should be 1");
...@@ -160,7 +158,7 @@ RC PhysicalPlanGenerator::create_plan(PredicateLogicalOperator &pred_oper, std:: ...@@ -160,7 +158,7 @@ RC PhysicalPlanGenerator::create_plan(PredicateLogicalOperator &pred_oper, std::
RC PhysicalPlanGenerator::create_plan(ProjectLogicalOperator &project_oper, std::unique_ptr<PhysicalOperator> &oper) RC PhysicalPlanGenerator::create_plan(ProjectLogicalOperator &project_oper, std::unique_ptr<PhysicalOperator> &oper)
{ {
std::vector<std::unique_ptr<LogicalOperator>> &child_opers = project_oper.children(); std::vector<std::unique_ptr<LogicalOperator>> &child_opers = project_oper.children();
std::unique_ptr<PhysicalOperator> child_phy_oper; std::unique_ptr<PhysicalOperator> child_phy_oper;
RC rc = RC::SUCCESS; RC rc = RC::SUCCESS;
...@@ -175,7 +173,7 @@ RC PhysicalPlanGenerator::create_plan(ProjectLogicalOperator &project_oper, std: ...@@ -175,7 +173,7 @@ RC PhysicalPlanGenerator::create_plan(ProjectLogicalOperator &project_oper, std:
ProjectPhysicalOperator *project_operator = new ProjectPhysicalOperator; ProjectPhysicalOperator *project_operator = new ProjectPhysicalOperator;
const std::vector<Field> &project_fields = project_oper.fields(); const std::vector<Field> &project_fields = project_oper.fields();
for (const Field & field : project_fields) { for (const Field &field : project_fields) {
project_operator->add_projection(field.table(), field.meta()); project_operator->add_projection(field.table(), field.meta());
} }
......
...@@ -27,16 +27,14 @@ class DeleteLogicalOperator; ...@@ -27,16 +27,14 @@ class DeleteLogicalOperator;
class ExplainLogicalOperator; class ExplainLogicalOperator;
class JoinLogicalOperator; class JoinLogicalOperator;
class PhysicalPlanGenerator class PhysicalPlanGenerator {
{ public:
public:
PhysicalPlanGenerator() = default; PhysicalPlanGenerator() = default;
virtual ~PhysicalPlanGenerator() = default; virtual ~PhysicalPlanGenerator() = default;
RC create(LogicalOperator &logical_operator, std::unique_ptr<PhysicalOperator> &oper); RC create(LogicalOperator &logical_operator, std::unique_ptr<PhysicalOperator> &oper);
private:
private:
RC create_plan(TableGetLogicalOperator &table_get_oper, std::unique_ptr<PhysicalOperator> &oper); RC create_plan(TableGetLogicalOperator &table_get_oper, std::unique_ptr<PhysicalOperator> &oper);
RC create_plan(PredicateLogicalOperator &pred_oper, std::unique_ptr<PhysicalOperator> &oper); RC create_plan(PredicateLogicalOperator &pred_oper, std::unique_ptr<PhysicalOperator> &oper);
RC create_plan(ProjectLogicalOperator &project_oper, std::unique_ptr<PhysicalOperator> &oper); RC create_plan(ProjectLogicalOperator &project_oper, std::unique_ptr<PhysicalOperator> &oper);
......
...@@ -52,7 +52,7 @@ RC PredicatePushdownRewriter::rewrite(std::unique_ptr<LogicalOperator> &oper, bo ...@@ -52,7 +52,7 @@ RC PredicatePushdownRewriter::rewrite(std::unique_ptr<LogicalOperator> &oper, bo
// 所有的表达式都下推到了下层算子 // 所有的表达式都下推到了下层算子
// 这个predicate operator其实就可以不要了。但是这里没办法删除,弄一个空的表达式吧 // 这个predicate operator其实就可以不要了。但是这里没办法删除,弄一个空的表达式吧
LOG_TRACE("all expressions of predicate operator were pushdown to table get operator, then make a fake one"); LOG_TRACE("all expressions of predicate operator were pushdown to table get operator, then make a fake one");
Value value; Value value;
value.type = BOOLEANS; value.type = BOOLEANS;
value.bool_value = true; value.bool_value = true;
...@@ -73,8 +73,7 @@ RC PredicatePushdownRewriter::rewrite(std::unique_ptr<LogicalOperator> &oper, bo ...@@ -73,8 +73,7 @@ RC PredicatePushdownRewriter::rewrite(std::unique_ptr<LogicalOperator> &oper, bo
* pushdown_exprs 只会增加,不要做清理操作 * pushdown_exprs 只会增加,不要做清理操作
*/ */
RC PredicatePushdownRewriter::get_exprs_can_pushdown( RC PredicatePushdownRewriter::get_exprs_can_pushdown(
std::unique_ptr<Expression> &expr, std::unique_ptr<Expression> &expr, std::vector<std::unique_ptr<Expression>> &pushdown_exprs)
std::vector<std::unique_ptr<Expression>> &pushdown_exprs)
{ {
RC rc = RC::SUCCESS; RC rc = RC::SUCCESS;
if (expr->type() == ExprType::CONJUNCTION) { if (expr->type() == ExprType::CONJUNCTION) {
...@@ -85,7 +84,7 @@ RC PredicatePushdownRewriter::get_exprs_can_pushdown( ...@@ -85,7 +84,7 @@ RC PredicatePushdownRewriter::get_exprs_can_pushdown(
} }
std::vector<std::unique_ptr<Expression>> &child_exprs = conjunction_expr->children(); std::vector<std::unique_ptr<Expression>> &child_exprs = conjunction_expr->children();
for (auto iter = child_exprs.begin(); iter != child_exprs.end(); ) { for (auto iter = child_exprs.begin(); iter != child_exprs.end();) {
// 对每个子表达式,判断是否可以下放到table get 算子 // 对每个子表达式,判断是否可以下放到table get 算子
// 如果可以的话,就从当前孩子节点中删除他 // 如果可以的话,就从当前孩子节点中删除他
rc = get_exprs_can_pushdown(*iter, pushdown_exprs); rc = get_exprs_can_pushdown(*iter, pushdown_exprs);
......
...@@ -21,8 +21,7 @@ See the Mulan PSL v2 for more details. */ ...@@ -21,8 +21,7 @@ See the Mulan PSL v2 for more details. */
* 将一些谓词表达式下推到表数据扫描中 * 将一些谓词表达式下推到表数据扫描中
* 这样可以提前过滤一些数据 * 这样可以提前过滤一些数据
*/ */
class PredicatePushdownRewriter : public RewriteRule class PredicatePushdownRewriter : public RewriteRule {
{
public: public:
PredicatePushdownRewriter() = default; PredicatePushdownRewriter() = default;
virtual ~PredicatePushdownRewriter() = default; virtual ~PredicatePushdownRewriter() = default;
...@@ -30,6 +29,6 @@ public: ...@@ -30,6 +29,6 @@ public:
RC rewrite(std::unique_ptr<LogicalOperator> &oper, bool &change_made) override; RC rewrite(std::unique_ptr<LogicalOperator> &oper, bool &change_made) override;
private: private:
RC get_exprs_can_pushdown(std::unique_ptr<Expression> &expr, RC get_exprs_can_pushdown(
std::vector<std::unique_ptr<Expression>> &pushdown_exprs); std::unique_ptr<Expression> &expr, std::vector<std::unique_ptr<Expression>> &pushdown_exprs);
}; };
...@@ -16,8 +16,7 @@ See the Mulan PSL v2 for more details. */ ...@@ -16,8 +16,7 @@ See the Mulan PSL v2 for more details. */
#include "sql/optimizer/rewrite_rule.h" #include "sql/optimizer/rewrite_rule.h"
class PredicateRewriteRule : public RewriteRule class PredicateRewriteRule : public RewriteRule {
{
public: public:
PredicateRewriteRule() = default; PredicateRewriteRule() = default;
virtual ~PredicateRewriteRule() = default; virtual ~PredicateRewriteRule() = default;
......
...@@ -21,18 +21,16 @@ See the Mulan PSL v2 for more details. */ ...@@ -21,18 +21,16 @@ See the Mulan PSL v2 for more details. */
class LogicalOperator; class LogicalOperator;
class Expression; class Expression;
class RewriteRule class RewriteRule {
{ public:
public:
virtual ~RewriteRule() = default; virtual ~RewriteRule() = default;
virtual RC rewrite(std::unique_ptr<LogicalOperator> &oper, bool &change_made) = 0; virtual RC rewrite(std::unique_ptr<LogicalOperator> &oper, bool &change_made) = 0;
}; };
class ExpressionRewriteRule class ExpressionRewriteRule {
{
public: public:
virtual ~ExpressionRewriteRule() = default; virtual ~ExpressionRewriteRule() = default;
virtual RC rewrite(std::unique_ptr<Expression> &expr, bool &change_made) = 0; virtual RC rewrite(std::unique_ptr<Expression> &expr, bool &change_made) = 0;
}; };
...@@ -30,7 +30,7 @@ RC Rewriter::rewrite(std::unique_ptr<LogicalOperator> &oper, bool &change_made) ...@@ -30,7 +30,7 @@ RC Rewriter::rewrite(std::unique_ptr<LogicalOperator> &oper, bool &change_made)
RC rc = RC::SUCCESS; RC rc = RC::SUCCESS;
change_made = false; change_made = false;
for (std::unique_ptr<RewriteRule> & rule : rewrite_rules_) { for (std::unique_ptr<RewriteRule> &rule : rewrite_rules_) {
bool sub_change_made = false; bool sub_change_made = false;
rc = rule->rewrite(oper, sub_change_made); rc = rule->rewrite(oper, sub_change_made);
if (rc != RC::SUCCESS) { if (rc != RC::SUCCESS) {
...@@ -46,7 +46,7 @@ RC Rewriter::rewrite(std::unique_ptr<LogicalOperator> &oper, bool &change_made) ...@@ -46,7 +46,7 @@ RC Rewriter::rewrite(std::unique_ptr<LogicalOperator> &oper, bool &change_made)
if (rc != RC::SUCCESS) { if (rc != RC::SUCCESS) {
return rc; return rc;
} }
std::vector<std::unique_ptr<LogicalOperator>> &child_opers = oper->children(); std::vector<std::unique_ptr<LogicalOperator>> &child_opers = oper->children();
for (auto &child_oper : child_opers) { for (auto &child_oper : child_opers) {
bool sub_change_made = false; bool sub_change_made = false;
......
...@@ -20,12 +20,11 @@ See the Mulan PSL v2 for more details. */ ...@@ -20,12 +20,11 @@ See the Mulan PSL v2 for more details. */
class LogicalOperator; class LogicalOperator;
class Rewriter class Rewriter {
{ public:
public:
Rewriter(); Rewriter();
virtual ~Rewriter() = default; virtual ~Rewriter() = default;
RC rewrite(std::unique_ptr<LogicalOperator> &oper, bool &change_made); RC rewrite(std::unique_ptr<LogicalOperator> &oper, bool &change_made);
private: private:
......
...@@ -9,7 +9,7 @@ MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ...@@ -9,7 +9,7 @@ MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details. */ See the Mulan PSL v2 for more details. */
// //
// Created by Meiyi // Created by Meiyi
// //
#include <mutex> #include <mutex>
...@@ -41,11 +41,16 @@ AttrType attr_type_from_string(const char *s) ...@@ -41,11 +41,16 @@ AttrType attr_type_from_string(const char *s)
const char *Value::data() const const char *Value::data() const
{ {
switch (type) { switch (type) {
case INTS: return (const char *)&int_value; case INTS:
case FLOATS: return (const char *)&float_value; return (const char *)&int_value;
case BOOLEANS: return (const char *)&bool_value; case FLOATS:
case CHARS: return (const char *)string_value.data(); return (const char *)&float_value;
case UNDEFINED: return nullptr; case BOOLEANS:
return (const char *)&bool_value;
case CHARS:
return (const char *)string_value.data();
case UNDEFINED:
return nullptr;
} }
return nullptr; return nullptr;
} }
...@@ -53,22 +58,24 @@ const char *Value::data() const ...@@ -53,22 +58,24 @@ const char *Value::data() const
int Value::length() int Value::length()
{ {
switch (type) { switch (type) {
case INTS: return sizeof(int_value); case INTS:
case FLOATS: return sizeof(float_value); return sizeof(int_value);
case BOOLEANS: return sizeof(bool_value); case FLOATS:
case CHARS: return string_value.size(); return sizeof(float_value);
case UNDEFINED: return 0; case BOOLEANS:
return sizeof(bool_value);
case CHARS:
return string_value.size();
case UNDEFINED:
return 0;
} }
return 0; return 0;
} }
Command::Command() Command::Command() : flag(SCF_ERROR)
: flag(SCF_ERROR) {}
{
}
Command::Command(enum SqlCommandFlag _flag) Command::Command(enum SqlCommandFlag _flag) : flag(_flag)
: flag(_flag)
{} {}
void ParsedSqlResult::add_command(std::unique_ptr<Command> command) void ParsedSqlResult::add_command(std::unique_ptr<Command> command)
......
...@@ -25,7 +25,7 @@ See the Mulan PSL v2 for more details. */ ...@@ -25,7 +25,7 @@ See the Mulan PSL v2 for more details. */
#define MAX_ERROR_MESSAGE 20 #define MAX_ERROR_MESSAGE 20
#define MAX_DATA 50 #define MAX_DATA 50
//属性结构体 // 属性结构体
struct RelAttr { struct RelAttr {
std::string relation_name; // relation name (may be NULL) 表名 std::string relation_name; // relation name (may be NULL) 表名
std::string attribute_name; // attribute name 属性名 std::string attribute_name; // attribute name 属性名
...@@ -41,9 +41,8 @@ enum CompOp { ...@@ -41,9 +41,8 @@ enum CompOp {
NO_OP NO_OP
}; };
//属性值类型 // 属性值类型
enum AttrType enum AttrType {
{
UNDEFINED, UNDEFINED,
CHARS, CHARS,
INTS, INTS,
...@@ -51,12 +50,12 @@ enum AttrType ...@@ -51,12 +50,12 @@ enum AttrType
BOOLEANS, BOOLEANS,
}; };
//属性值 // 属性值
struct Value { struct Value {
AttrType type; // type of value AttrType type; // type of value
int int_value; int int_value;
float float_value; float float_value;
bool bool_value; bool bool_value;
std::string string_value; std::string string_value;
const char *data() const; const char *data() const;
...@@ -77,42 +76,41 @@ struct Condition { ...@@ -77,42 +76,41 @@ struct Condition {
// struct of select // struct of select
struct Selects { struct Selects {
std::vector<RelAttr> attributes; // attributes in select clause std::vector<RelAttr> attributes; // attributes in select clause
std::vector<std::string> relations; std::vector<std::string> relations;
std::vector<Condition> conditions; std::vector<Condition> conditions;
}; };
// struct of insert // struct of insert
struct Inserts { struct Inserts {
std::string relation_name; // Relation to insert into std::string relation_name; // Relation to insert into
std::vector<Value> values; std::vector<Value> values;
}; };
// struct of delete // struct of delete
struct Deletes { struct Deletes {
std::string relation_name; // Relation to delete from std::string relation_name; // Relation to delete from
std::vector<Condition> conditions; std::vector<Condition> conditions;
}; };
// struct of update // struct of update
struct Updates { struct Updates {
std::string relation_name; // Relation to update std::string relation_name; // Relation to update
std::string attribute_name; // Attribute to update std::string attribute_name; // Attribute to update
Value value; // update value Value value; // update value
std::vector<Condition> conditions; std::vector<Condition> conditions;
}; };
struct AttrInfo struct AttrInfo {
{ AttrType type; // Type of attribute
AttrType type; // Type of attribute std::string name; // Attribute name
std::string name; // Attribute name size_t length; // Length of attribute
size_t length; // Length of attribute
}; };
// struct of craete_table // struct of craete_table
struct CreateTable { struct CreateTable {
std::string relation_name; // Relation name std::string relation_name; // Relation name
std::vector<AttrInfo> attr_infos; // attributes std::vector<AttrInfo> attr_infos; // attributes
}; };
// struct of drop_table // struct of drop_table
...@@ -129,8 +127,8 @@ struct CreateIndex { ...@@ -129,8 +127,8 @@ struct CreateIndex {
// struct of drop_index // struct of drop_index
struct DropIndex { struct DropIndex {
std::string index_name; // Index name std::string index_name; // Index name
std::string relation_name; //Relation name std::string relation_name; // Relation name
}; };
struct DescTable { struct DescTable {
...@@ -147,8 +145,7 @@ struct Explain { ...@@ -147,8 +145,7 @@ struct Explain {
std::unique_ptr<Command> cmd; std::unique_ptr<Command> cmd;
}; };
struct Error struct Error {
{
std::string error_msg; std::string error_msg;
int line; int line;
int column; int column;
...@@ -181,18 +178,18 @@ enum SqlCommandFlag { ...@@ -181,18 +178,18 @@ enum SqlCommandFlag {
class Command { class Command {
public: public:
enum SqlCommandFlag flag; enum SqlCommandFlag flag;
Error error; Error error;
Selects selection; Selects selection;
Inserts insertion; Inserts insertion;
Deletes deletion; Deletes deletion;
Updates update; Updates update;
CreateTable create_table; CreateTable create_table;
DropTable drop_table; DropTable drop_table;
CreateIndex create_index; CreateIndex create_index;
DropIndex drop_index; DropIndex drop_index;
DescTable desc_table; DescTable desc_table;
LoadData load_data; LoadData load_data;
Explain explain; Explain explain;
public: public:
Command(); Command();
...@@ -203,12 +200,14 @@ public: ...@@ -203,12 +200,14 @@ public:
* 表示语法解析后的数据 * 表示语法解析后的数据
* 叫ParsedSqlNode 可能会更清晰一点 * 叫ParsedSqlNode 可能会更清晰一点
*/ */
class ParsedSqlResult class ParsedSqlResult {
{
public: public:
void add_command(std::unique_ptr<Command> command); void add_command(std::unique_ptr<Command> command);
std::vector<std::unique_ptr<Command>> &commands() { return sql_commands_; } std::vector<std::unique_ptr<Command>> &commands()
{
return sql_commands_;
}
private: private:
std::vector<std::unique_ptr<Command>> sql_commands_; std::vector<std::unique_ptr<Command>> sql_commands_;
}; };
......
...@@ -137,7 +137,7 @@ RC ParseStage::handle_request(StageEvent *event) ...@@ -137,7 +137,7 @@ RC ParseStage::handle_request(StageEvent *event)
if (parsed_sql_result.commands().size() > 1) { if (parsed_sql_result.commands().size() > 1) {
LOG_WARN("got multi sql commands but only 1 will be handled"); LOG_WARN("got multi sql commands but only 1 will be handled");
} }
std::unique_ptr<Command> cmd = std::move(parsed_sql_result.commands().front()); std::unique_ptr<Command> cmd = std::move(parsed_sql_result.commands().front());
if (cmd->flag == SCF_ERROR) { if (cmd->flag == SCF_ERROR) {
// set error information to event // set error information to event
......
...@@ -98,7 +98,7 @@ void ResolveStage::handle_event(StageEvent *event) ...@@ -98,7 +98,7 @@ void ResolveStage::handle_event(StageEvent *event)
Db *db = session_event->session()->get_current_db(); Db *db = session_event->session()->get_current_db();
if (nullptr == db) { if (nullptr == db) {
LOG_ERROR("cannot current db"); LOG_ERROR("cannot current db");
return ; return;
} }
Command *cmd = sql_event->command().get(); Command *cmd = sql_event->command().get();
......
...@@ -18,8 +18,7 @@ See the Mulan PSL v2 for more details. */ ...@@ -18,8 +18,7 @@ See the Mulan PSL v2 for more details. */
#include "storage/common/db.h" #include "storage/common/db.h"
#include "storage/common/table.h" #include "storage/common/table.h"
DeleteStmt::DeleteStmt(Table *table, FilterStmt *filter_stmt) DeleteStmt::DeleteStmt(Table *table, FilterStmt *filter_stmt) : table_(table), filter_stmt_(filter_stmt)
: table_ (table), filter_stmt_(filter_stmt)
{} {}
DeleteStmt::~DeleteStmt() DeleteStmt::~DeleteStmt()
...@@ -27,15 +26,14 @@ DeleteStmt::~DeleteStmt() ...@@ -27,15 +26,14 @@ DeleteStmt::~DeleteStmt()
if (nullptr != filter_stmt_) { if (nullptr != filter_stmt_) {
delete filter_stmt_; delete filter_stmt_;
filter_stmt_ = nullptr; filter_stmt_ = nullptr;
} }
} }
RC DeleteStmt::create(Db *db, const Deletes &delete_sql, Stmt *&stmt) RC DeleteStmt::create(Db *db, const Deletes &delete_sql, Stmt *&stmt)
{ {
const char *table_name = delete_sql.relation_name.c_str(); const char *table_name = delete_sql.relation_name.c_str();
if (nullptr == db || nullptr == table_name) { if (nullptr == db || nullptr == table_name) {
LOG_WARN("invalid argument. db=%p, table_name=%p", LOG_WARN("invalid argument. db=%p, table_name=%p", db, table_name);
db, table_name);
return RC::INVALID_ARGUMENT; return RC::INVALID_ARGUMENT;
} }
...@@ -50,9 +48,8 @@ RC DeleteStmt::create(Db *db, const Deletes &delete_sql, Stmt *&stmt) ...@@ -50,9 +48,8 @@ RC DeleteStmt::create(Db *db, const Deletes &delete_sql, Stmt *&stmt)
table_map.insert(std::pair<std::string, Table *>(std::string(table_name), table)); table_map.insert(std::pair<std::string, Table *>(std::string(table_name), table));
FilterStmt *filter_stmt = nullptr; FilterStmt *filter_stmt = nullptr;
RC rc = FilterStmt::create(db, table, &table_map, RC rc = FilterStmt::create(
delete_sql.conditions.data(), static_cast<int>(delete_sql.conditions.size()), db, table, &table_map, delete_sql.conditions.data(), static_cast<int>(delete_sql.conditions.size()), filter_stmt);
filter_stmt);
if (rc != RC::SUCCESS) { if (rc != RC::SUCCESS) {
LOG_WARN("failed to create filter statement. rc=%d:%s", rc, strrc(rc)); LOG_WARN("failed to create filter statement. rc=%d:%s", rc, strrc(rc));
return rc; return rc;
......
...@@ -21,17 +21,25 @@ See the Mulan PSL v2 for more details. */ ...@@ -21,17 +21,25 @@ See the Mulan PSL v2 for more details. */
class Table; class Table;
class FilterStmt; class FilterStmt;
class DeleteStmt : public Stmt class DeleteStmt : public Stmt {
{
public: public:
DeleteStmt(Table *table, FilterStmt *filter_stmt); DeleteStmt(Table *table, FilterStmt *filter_stmt);
~DeleteStmt() override; ~DeleteStmt() override;
Table *table() const { return table_; } Table *table() const
FilterStmt *filter_stmt() const { return filter_stmt_; } {
return table_;
}
FilterStmt *filter_stmt() const
{
return filter_stmt_;
}
StmtType type() const override
{
return StmtType::DELETE;
}
StmtType type() const override { return StmtType::DELETE; }
public: public:
static RC create(Db *db, const Deletes &delete_sql, Stmt *&stmt); static RC create(Db *db, const Deletes &delete_sql, Stmt *&stmt);
...@@ -39,4 +47,3 @@ private: ...@@ -39,4 +47,3 @@ private:
Table *table_ = nullptr; Table *table_ = nullptr;
FilterStmt *filter_stmt_ = nullptr; FilterStmt *filter_stmt_ = nullptr;
}; };
...@@ -16,11 +16,10 @@ See the Mulan PSL v2 for more details. */ ...@@ -16,11 +16,10 @@ See the Mulan PSL v2 for more details. */
#include "sql/stmt/stmt.h" #include "sql/stmt/stmt.h"
#include "common/log/log.h" #include "common/log/log.h"
ExplainStmt::ExplainStmt(std::unique_ptr<Stmt> child_stmt) ExplainStmt::ExplainStmt(std::unique_ptr<Stmt> child_stmt) : child_stmt_(std::move(child_stmt))
: child_stmt_(std::move(child_stmt))
{} {}
RC ExplainStmt::create(Db *db, const Explain &explain, Stmt *& stmt) RC ExplainStmt::create(Db *db, const Explain &explain, Stmt *&stmt)
{ {
Stmt *child_stmt = nullptr; Stmt *child_stmt = nullptr;
RC rc = Stmt::create_stmt(db, *explain.cmd, child_stmt); RC rc = Stmt::create_stmt(db, *explain.cmd, child_stmt);
......
...@@ -17,18 +17,23 @@ See the Mulan PSL v2 for more details. */ ...@@ -17,18 +17,23 @@ See the Mulan PSL v2 for more details. */
#include <memory> #include <memory>
#include "sql/stmt/stmt.h" #include "sql/stmt/stmt.h"
class ExplainStmt : public Stmt class ExplainStmt : public Stmt {
{
public: public:
ExplainStmt(std::unique_ptr<Stmt> child_stmt); ExplainStmt(std::unique_ptr<Stmt> child_stmt);
virtual ~ExplainStmt() = default; virtual ~ExplainStmt() = default;
StmtType type() const override { return StmtType::EXPLAIN; } StmtType type() const override
{
return StmtType::EXPLAIN;
}
Stmt *child() const { return child_stmt_.get(); } Stmt *child() const
{
return child_stmt_.get();
}
static RC create(Db *db, const Explain &query, Stmt *&stmt);
static RC create(Db *db, const Explain &query, Stmt *& stmt);
private: private:
std::unique_ptr<Stmt> child_stmt_; std::unique_ptr<Stmt> child_stmt_;
}; };
...@@ -28,8 +28,7 @@ FilterStmt::~FilterStmt() ...@@ -28,8 +28,7 @@ FilterStmt::~FilterStmt()
} }
RC FilterStmt::create(Db *db, Table *default_table, std::unordered_map<std::string, Table *> *tables, RC FilterStmt::create(Db *db, Table *default_table, std::unordered_map<std::string, Table *> *tables,
const Condition *conditions, int condition_num, const Condition *conditions, int condition_num, FilterStmt *&stmt)
FilterStmt *&stmt)
{ {
RC rc = RC::SUCCESS; RC rc = RC::SUCCESS;
stmt = nullptr; stmt = nullptr;
...@@ -51,7 +50,7 @@ RC FilterStmt::create(Db *db, Table *default_table, std::unordered_map<std::stri ...@@ -51,7 +50,7 @@ RC FilterStmt::create(Db *db, Table *default_table, std::unordered_map<std::stri
} }
RC get_table_and_field(Db *db, Table *default_table, std::unordered_map<std::string, Table *> *tables, RC get_table_and_field(Db *db, Table *default_table, std::unordered_map<std::string, Table *> *tables,
const RelAttr &attr, Table *&table, const FieldMeta *&field) const RelAttr &attr, Table *&table, const FieldMeta *&field)
{ {
if (common::is_blank(attr.relation_name.c_str())) { if (common::is_blank(attr.relation_name.c_str())) {
table = default_table; table = default_table;
...@@ -79,22 +78,22 @@ RC get_table_and_field(Db *db, Table *default_table, std::unordered_map<std::str ...@@ -79,22 +78,22 @@ RC get_table_and_field(Db *db, Table *default_table, std::unordered_map<std::str
} }
RC FilterStmt::create_filter_unit(Db *db, Table *default_table, std::unordered_map<std::string, Table *> *tables, RC FilterStmt::create_filter_unit(Db *db, Table *default_table, std::unordered_map<std::string, Table *> *tables,
const Condition &condition, FilterUnit *&filter_unit) const Condition &condition, FilterUnit *&filter_unit)
{ {
RC rc = RC::SUCCESS; RC rc = RC::SUCCESS;
CompOp comp = condition.comp; CompOp comp = condition.comp;
if (comp < EQUAL_TO || comp >= NO_OP) { if (comp < EQUAL_TO || comp >= NO_OP) {
LOG_WARN("invalid compare operator : %d", comp); LOG_WARN("invalid compare operator : %d", comp);
return RC::INVALID_ARGUMENT; return RC::INVALID_ARGUMENT;
} }
filter_unit = new FilterUnit; filter_unit = new FilterUnit;
if (condition.left_is_attr) { if (condition.left_is_attr) {
Table *table = nullptr; Table *table = nullptr;
const FieldMeta *field = nullptr; const FieldMeta *field = nullptr;
rc = get_table_and_field(db, default_table, tables, condition.left_attr, table, field); rc = get_table_and_field(db, default_table, tables, condition.left_attr, table, field);
if (rc != RC::SUCCESS) { if (rc != RC::SUCCESS) {
LOG_WARN("cannot find attr"); LOG_WARN("cannot find attr");
return rc; return rc;
...@@ -111,7 +110,7 @@ RC FilterStmt::create_filter_unit(Db *db, Table *default_table, std::unordered_m ...@@ -111,7 +110,7 @@ RC FilterStmt::create_filter_unit(Db *db, Table *default_table, std::unordered_m
if (condition.right_is_attr) { if (condition.right_is_attr) {
Table *table = nullptr; Table *table = nullptr;
const FieldMeta *field = nullptr; const FieldMeta *field = nullptr;
rc = get_table_and_field(db, default_table, tables, condition.right_attr, table, field); rc = get_table_and_field(db, default_table, tables, condition.right_attr, table, field);
if (rc != RC::SUCCESS) { if (rc != RC::SUCCESS) {
LOG_WARN("cannot find attr"); LOG_WARN("cannot find attr");
return rc; return rc;
...@@ -125,7 +124,6 @@ RC FilterStmt::create_filter_unit(Db *db, Table *default_table, std::unordered_m ...@@ -125,7 +124,6 @@ RC FilterStmt::create_filter_unit(Db *db, Table *default_table, std::unordered_m
filter_unit->set_right(filter_obj); filter_unit->set_right(filter_obj);
} }
filter_unit->set_comp(comp); filter_unit->set_comp(comp);
// 检查两个类型是否能够比较 // 检查两个类型是否能够比较
......
...@@ -25,36 +25,37 @@ class Db; ...@@ -25,36 +25,37 @@ class Db;
class Table; class Table;
class FieldMeta; class FieldMeta;
struct FilterObj struct FilterObj {
{
bool is_attr; bool is_attr;
Field field; Field field;
Value value; Value value;
void init_attr(const Field &field) { void init_attr(const Field &field)
{
is_attr = true; is_attr = true;
this->field = field; this->field = field;
} }
void init_value(const Value &value) { void init_value(const Value &value)
{
is_attr = false; is_attr = false;
this->value = value; this->value = value;
} }
}; };
class FilterUnit class FilterUnit {
{
public: public:
FilterUnit() = default; FilterUnit() = default;
~FilterUnit() ~FilterUnit()
{}
void set_comp(CompOp comp)
{ {
}
void set_comp(CompOp comp) {
comp_ = comp; comp_ = comp;
} }
CompOp comp() const { CompOp comp() const
{
return comp_; return comp_;
} }
...@@ -82,10 +83,8 @@ private: ...@@ -82,10 +83,8 @@ private:
FilterObj right_; FilterObj right_;
}; };
class FilterStmt class FilterStmt {
{
public: public:
FilterStmt() = default; FilterStmt() = default;
virtual ~FilterStmt(); virtual ~FilterStmt();
...@@ -97,12 +96,11 @@ public: ...@@ -97,12 +96,11 @@ public:
public: public:
static RC create(Db *db, Table *default_table, std::unordered_map<std::string, Table *> *tables, static RC create(Db *db, Table *default_table, std::unordered_map<std::string, Table *> *tables,
const Condition *conditions, int condition_num, const Condition *conditions, int condition_num, FilterStmt *&stmt);
FilterStmt *&stmt);
static RC create_filter_unit(Db *db, Table *default_table, std::unordered_map<std::string, Table *> *tables, static RC create_filter_unit(Db *db, Table *default_table, std::unordered_map<std::string, Table *> *tables,
const Condition &condition, FilterUnit *&filter_unit); const Condition &condition, FilterUnit *&filter_unit);
private: private:
std::vector<FilterUnit *> filter_units_; // 默认当前都是AND关系 std::vector<FilterUnit *> filter_units_; // 默认当前都是AND关系
}; };
...@@ -18,15 +18,17 @@ See the Mulan PSL v2 for more details. */ ...@@ -18,15 +18,17 @@ See the Mulan PSL v2 for more details. */
#include "storage/common/table.h" #include "storage/common/table.h"
InsertStmt::InsertStmt(Table *table, const Value *values, int value_amount) InsertStmt::InsertStmt(Table *table, const Value *values, int value_amount)
: table_ (table), values_(values), value_amount_(value_amount) : table_(table), values_(values), value_amount_(value_amount)
{} {}
RC InsertStmt::create(Db *db, const Inserts &inserts, Stmt *&stmt) RC InsertStmt::create(Db *db, const Inserts &inserts, Stmt *&stmt)
{ {
const char *table_name = inserts.relation_name.c_str(); const char *table_name = inserts.relation_name.c_str();
if (nullptr == db || nullptr == table_name || inserts.values.empty()) { if (nullptr == db || nullptr == table_name || inserts.values.empty()) {
LOG_WARN("invalid argument. db=%p, table_name=%p, value_num=%d", LOG_WARN("invalid argument. db=%p, table_name=%p, value_num=%d",
db, table_name, static_cast<int>(inserts.values.size())); db,
table_name,
static_cast<int>(inserts.values.size()));
return RC::INVALID_ARGUMENT; return RC::INVALID_ARGUMENT;
} }
...@@ -53,9 +55,12 @@ RC InsertStmt::create(Db *db, const Inserts &inserts, Stmt *&stmt) ...@@ -53,9 +55,12 @@ RC InsertStmt::create(Db *db, const Inserts &inserts, Stmt *&stmt)
const FieldMeta *field_meta = table_meta.field(i + sys_field_num); const FieldMeta *field_meta = table_meta.field(i + sys_field_num);
const AttrType field_type = field_meta->type(); const AttrType field_type = field_meta->type();
const AttrType value_type = values[i].type; const AttrType value_type = values[i].type;
if (field_type != value_type) { // TODO try to convert the value type to field type if (field_type != value_type) { // TODO try to convert the value type to field type
LOG_WARN("field type mismatch. table=%s, field=%s, field type=%d, value_type=%d", LOG_WARN("field type mismatch. table=%s, field=%s, field type=%d, value_type=%d",
table_name, field_meta->name(), field_type, value_type); table_name,
field_meta->name(),
field_type,
value_type);
return RC::SCHEMA_FIELD_TYPE_MISMATCH; return RC::SCHEMA_FIELD_TYPE_MISMATCH;
} }
} }
......
...@@ -20,27 +20,35 @@ See the Mulan PSL v2 for more details. */ ...@@ -20,27 +20,35 @@ See the Mulan PSL v2 for more details. */
class Table; class Table;
class Db; class Db;
class InsertStmt : public Stmt class InsertStmt : public Stmt {
{
public: public:
InsertStmt() = default; InsertStmt() = default;
InsertStmt(Table *table, const Value *values, int value_amount); InsertStmt(Table *table, const Value *values, int value_amount);
StmtType type() const override { StmtType type() const override
{
return StmtType::INSERT; return StmtType::INSERT;
} }
public: public:
static RC create(Db *db, const Inserts &insert_sql, Stmt *&stmt); static RC create(Db *db, const Inserts &insert_sql, Stmt *&stmt);
public: public:
Table *table() const {return table_;} Table *table() const
const Value *values() const { return values_; } {
int value_amount() const { return value_amount_; } return table_;
}
const Value *values() const
{
return values_;
}
int value_amount() const
{
return value_amount_;
}
private: private:
Table *table_ = nullptr; Table *table_ = nullptr;
const Value *values_ = nullptr; const Value *values_ = nullptr;
int value_amount_ = 0; int value_amount_ = 0;
}; };
...@@ -60,15 +60,16 @@ RC SelectStmt::create(Db *db, const Selects &select_sql, Stmt *&stmt) ...@@ -60,15 +60,16 @@ RC SelectStmt::create(Db *db, const Selects &select_sql, Stmt *&stmt)
} }
tables.push_back(table); tables.push_back(table);
table_map.insert(std::pair<std::string, Table*>(table_name, table)); table_map.insert(std::pair<std::string, Table *>(table_name, table));
} }
// collect query fields in `select` statement // collect query fields in `select` statement
std::vector<Field> query_fields; std::vector<Field> query_fields;
for (int i = static_cast<int>(select_sql.attributes.size()) - 1; i >= 0; i--) { for (int i = static_cast<int>(select_sql.attributes.size()) - 1; i >= 0; i--) {
const RelAttr &relation_attr = select_sql.attributes[i]; const RelAttr &relation_attr = select_sql.attributes[i];
if (common::is_blank(relation_attr.relation_name.c_str()) && 0 == strcmp(relation_attr.attribute_name.c_str(), "*")) { if (common::is_blank(relation_attr.relation_name.c_str()) &&
0 == strcmp(relation_attr.attribute_name.c_str(), "*")) {
for (Table *table : tables) { for (Table *table : tables) {
wildcard_fields(table, query_fields); wildcard_fields(table, query_fields);
} }
...@@ -131,8 +132,12 @@ RC SelectStmt::create(Db *db, const Selects &select_sql, Stmt *&stmt) ...@@ -131,8 +132,12 @@ RC SelectStmt::create(Db *db, const Selects &select_sql, Stmt *&stmt)
// create filter statement in `where` statement // create filter statement in `where` statement
FilterStmt *filter_stmt = nullptr; FilterStmt *filter_stmt = nullptr;
RC rc = FilterStmt::create(db, default_table, &table_map, RC rc = FilterStmt::create(db,
select_sql.conditions.data(), static_cast<int>(select_sql.conditions.size()), filter_stmt); default_table,
&table_map,
select_sql.conditions.data(),
static_cast<int>(select_sql.conditions.size()),
filter_stmt);
if (rc != RC::SUCCESS) { if (rc != RC::SUCCESS) {
LOG_WARN("cannot construct filter stmt"); LOG_WARN("cannot construct filter stmt");
return rc; return rc;
......
...@@ -25,25 +25,35 @@ class FilterStmt; ...@@ -25,25 +25,35 @@ class FilterStmt;
class Db; class Db;
class Table; class Table;
class SelectStmt : public Stmt class SelectStmt : public Stmt {
{
public: public:
SelectStmt() = default; SelectStmt() = default;
~SelectStmt() override; ~SelectStmt() override;
StmtType type() const override { return StmtType::SELECT; } StmtType type() const override
{
return StmtType::SELECT;
}
public: public:
static RC create(Db *db, const Selects &select_sql, Stmt *&stmt); static RC create(Db *db, const Selects &select_sql, Stmt *&stmt);
public: public:
const std::vector<Table *> &tables() const { return tables_; } const std::vector<Table *> &tables() const
const std::vector<Field> &query_fields() const { return query_fields_; } {
FilterStmt *filter_stmt() const { return filter_stmt_; } return tables_;
}
const std::vector<Field> &query_fields() const
{
return query_fields_;
}
FilterStmt *filter_stmt() const
{
return filter_stmt_;
}
private: private:
std::vector<Field> query_fields_; std::vector<Field> query_fields_;
std::vector<Table *> tables_; std::vector<Table *> tables_;
FilterStmt *filter_stmt_ = nullptr; FilterStmt *filter_stmt_ = nullptr;
}; };
...@@ -39,10 +39,7 @@ RC Stmt::create_stmt(Db *db, const Command &cmd, Stmt *&stmt) ...@@ -39,10 +39,7 @@ RC Stmt::create_stmt(Db *db, const Command &cmd, Stmt *&stmt)
} }
default: { default: {
LOG_INFO("Command::type %d doesn't need to create statement.", cmd.flag); LOG_INFO("Command::type %d doesn't need to create statement.", cmd.flag);
} } break;
break;
} }
return RC::UNIMPLENMENT; return RC::UNIMPLENMENT;
} }
...@@ -19,8 +19,7 @@ See the Mulan PSL v2 for more details. */ ...@@ -19,8 +19,7 @@ See the Mulan PSL v2 for more details. */
class Db; class Db;
enum class StmtType enum class StmtType {
{
SELECT, SELECT,
INSERT, INSERT,
UPDATE, UPDATE,
...@@ -43,10 +42,8 @@ enum class StmtType ...@@ -43,10 +42,8 @@ enum class StmtType
PREDICATE, PREDICATE,
}; };
class Stmt class Stmt {
{
public: public:
Stmt() = default; Stmt() = default;
virtual ~Stmt() = default; virtual ~Stmt() = default;
...@@ -57,4 +54,3 @@ public: ...@@ -57,4 +54,3 @@ public:
private: private:
}; };
...@@ -15,7 +15,7 @@ See the Mulan PSL v2 for more details. */ ...@@ -15,7 +15,7 @@ See the Mulan PSL v2 for more details. */
#include "sql/stmt/update_stmt.h" #include "sql/stmt/update_stmt.h"
UpdateStmt::UpdateStmt(Table *table, Value *values, int value_amount) UpdateStmt::UpdateStmt(Table *table, Value *values, int value_amount)
: table_ (table), values_(values), value_amount_(value_amount) : table_(table), values_(values), value_amount_(value_amount)
{} {}
RC UpdateStmt::create(Db *db, const Updates &update, Stmt *&stmt) RC UpdateStmt::create(Db *db, const Updates &update, Stmt *&stmt)
......
...@@ -19,10 +19,8 @@ See the Mulan PSL v2 for more details. */ ...@@ -19,10 +19,8 @@ See the Mulan PSL v2 for more details. */
class Table; class Table;
class UpdateStmt : public Stmt class UpdateStmt : public Stmt {
{
public: public:
UpdateStmt() = default; UpdateStmt() = default;
UpdateStmt(Table *table, Value *values, int value_amount); UpdateStmt(Table *table, Value *values, int value_amount);
...@@ -30,13 +28,21 @@ public: ...@@ -30,13 +28,21 @@ public:
static RC create(Db *db, const Updates &update_sql, Stmt *&stmt); static RC create(Db *db, const Updates &update_sql, Stmt *&stmt);
public: public:
Table *table() const {return table_;} Table *table() const
Value *values() const { return values_; } {
int value_amount() const { return value_amount_; } return table_;
}
Value *values() const
{
return values_;
}
int value_amount() const
{
return value_amount_;
}
private: private:
Table *table_ = nullptr; Table *table_ = nullptr;
Value *values_ = nullptr; Value *values_ = nullptr;
int value_amount_ = 0; int value_amount_ = 0;
}; };
...@@ -194,8 +194,8 @@ RC CLogBuffer::append_log_record(CLogRecord *log_rec, int &start_off) ...@@ -194,8 +194,8 @@ RC CLogBuffer::append_log_record(CLogRecord *log_rec, int &start_off)
write_offset_ += CLOG_BLOCK_HDR_SIZE; write_offset_ += CLOG_BLOCK_HDR_SIZE;
return append_log_record(log_rec, start_off); return append_log_record(log_rec, start_off);
} else { } else {
if (logrec_left_len <= (CLOG_BLOCK_DATA_SIZE - log_block->log_block_hdr_.log_data_len_)) { //不需要再跨block存放 if (logrec_left_len <= (CLOG_BLOCK_DATA_SIZE - log_block->log_block_hdr_.log_data_len_)) { // 不需要再跨block存放
if (log_block->log_block_hdr_.log_data_len_ == 0) { //当前为新block if (log_block->log_block_hdr_.log_data_len_ == 0) { // 当前为新block
if (start_off == 0) { if (start_off == 0) {
log_block->log_block_hdr_.first_rec_offset_ = CLOG_BLOCK_HDR_SIZE; log_block->log_block_hdr_.first_rec_offset_ = CLOG_BLOCK_HDR_SIZE;
} else { } else {
...@@ -206,8 +206,8 @@ RC CLogBuffer::append_log_record(CLogRecord *log_rec, int &start_off) ...@@ -206,8 +206,8 @@ RC CLogBuffer::append_log_record(CLogRecord *log_rec, int &start_off)
write_offset_ += logrec_left_len; write_offset_ += logrec_left_len;
log_block->log_block_hdr_.log_data_len_ += logrec_left_len; log_block->log_block_hdr_.log_data_len_ += logrec_left_len;
start_off += logrec_left_len; start_off += logrec_left_len;
} else { //需要跨block } else { // 需要跨block
if (log_block->log_block_hdr_.log_data_len_ == 0) { //当前为新block if (log_block->log_block_hdr_.log_data_len_ == 0) { // 当前为新block
log_block->log_block_hdr_.first_rec_offset_ = CLOG_BLOCK_SIZE; log_block->log_block_hdr_.first_rec_offset_ = CLOG_BLOCK_SIZE;
} }
int32_t block_left_len = CLOG_BLOCK_DATA_SIZE - log_block->log_block_hdr_.log_data_len_; int32_t block_left_len = CLOG_BLOCK_DATA_SIZE - log_block->log_block_hdr_.log_data_len_;
...@@ -223,7 +223,7 @@ RC CLogBuffer::append_log_record(CLogRecord *log_rec, int &start_off) ...@@ -223,7 +223,7 @@ RC CLogBuffer::append_log_record(CLogRecord *log_rec, int &start_off)
RC CLogBuffer::flush_buffer(CLogFile *log_file) RC CLogBuffer::flush_buffer(CLogFile *log_file)
{ {
if (write_offset_ == CLOG_BUFFER_SIZE) { //如果是buffer满触发的下刷 if (write_offset_ == CLOG_BUFFER_SIZE) { // 如果是buffer满触发的下刷
CLogBlock *log_block = (CLogBlock *)buffer_; CLogBlock *log_block = (CLogBlock *)buffer_;
log_file->write(log_block->log_block_hdr_.log_block_no, CLOG_BUFFER_SIZE, buffer_); log_file->write(log_block->log_block_hdr_.log_block_no, CLOG_BUFFER_SIZE, buffer_);
write_block_offset_ = 0; write_block_offset_ = 0;
...@@ -324,7 +324,7 @@ RC CLogFile::recover(CLogMTRManager *mtr_mgr, CLogBuffer *log_buffer) ...@@ -324,7 +324,7 @@ RC CLogFile::recover(CLogMTRManager *mtr_mgr, CLogBuffer *log_buffer)
} }
} }
if (log_block->log_block_hdr_.log_data_len_ < CLOG_BLOCK_DATA_SIZE) { //最后一个block if (log_block->log_block_hdr_.log_data_len_ < CLOG_BLOCK_DATA_SIZE) { // 最后一个block
log_buffer->block_copy(0, log_block); log_buffer->block_copy(0, log_block);
log_buffer->set_write_block_offset(0); log_buffer->set_write_block_offset(0);
log_buffer->set_write_offset(log_block->log_block_hdr_.log_data_len_ + CLOG_BLOCK_HDR_SIZE); log_buffer->set_write_offset(log_block->log_block_hdr_.log_data_len_ + CLOG_BLOCK_HDR_SIZE);
...@@ -347,7 +347,7 @@ done: ...@@ -347,7 +347,7 @@ done:
RC CLogFile::block_recover(CLogBlock *block, int16_t &offset, CLogRecordBuf *logrec_buf, CLogRecord *&log_rec) RC CLogFile::block_recover(CLogBlock *block, int16_t &offset, CLogRecordBuf *logrec_buf, CLogRecord *&log_rec)
{ {
if (offset == CLOG_BLOCK_HDR_SIZE && if (offset == CLOG_BLOCK_HDR_SIZE &&
block->log_block_hdr_.first_rec_offset_ != CLOG_BLOCK_HDR_SIZE) { //跨block中的某部分(非第一部分) block->log_block_hdr_.first_rec_offset_ != CLOG_BLOCK_HDR_SIZE) { // 跨block中的某部分(非第一部分)
// 追加到logrec_buf // 追加到logrec_buf
memcpy(&logrec_buf->buffer_[logrec_buf->write_offset_], memcpy(&logrec_buf->buffer_[logrec_buf->write_offset_],
(char *)block + (int)offset, (char *)block + (int)offset,
...@@ -370,7 +370,7 @@ RC CLogFile::block_recover(CLogBlock *block, int16_t &offset, CLogRecordBuf *log ...@@ -370,7 +370,7 @@ RC CLogFile::block_recover(CLogBlock *block, int16_t &offset, CLogRecordBuf *log
if (logrec_hdr->logrec_len_ <= CLOG_BLOCK_SIZE - offset) { if (logrec_hdr->logrec_len_ <= CLOG_BLOCK_SIZE - offset) {
log_rec = new CLogRecord((char *)block + (int)offset); log_rec = new CLogRecord((char *)block + (int)offset);
offset += logrec_hdr->logrec_len_; offset += logrec_hdr->logrec_len_;
} else { //此时为跨block的第一部分 } else { // 此时为跨block的第一部分
// 开始写入logrec_buf // 开始写入logrec_buf
memcpy( memcpy(
&logrec_buf->buffer_[logrec_buf->write_offset_], (char *)block + (int)offset, CLOG_BLOCK_SIZE - offset); &logrec_buf->buffer_[logrec_buf->write_offset_], (char *)block + (int)offset, CLOG_BLOCK_SIZE - offset);
......
...@@ -27,7 +27,7 @@ See the Mulan PSL v2 for more details. */ ...@@ -27,7 +27,7 @@ See the Mulan PSL v2 for more details. */
#include "storage/persist/persist.h" #include "storage/persist/persist.h"
#include "rc.h" #include "rc.h"
//固定文件大小 TODO: 循环文件组 // 固定文件大小 TODO: 循环文件组
#define CLOG_FILE_SIZE 48 * 1024 * 1024 #define CLOG_FILE_SIZE 48 * 1024 * 1024
#define CLOG_BUFFER_SIZE 4 * 1024 * 1024 #define CLOG_BUFFER_SIZE 4 * 1024 * 1024
#define TABLE_NAME_MAX_LEN 20 // TODO: 表名不要超过20字节 #define TABLE_NAME_MAX_LEN 20 // TODO: 表名不要超过20字节
...@@ -241,7 +241,7 @@ public: ...@@ -241,7 +241,7 @@ public:
RC clog_gen_record(CLogType flag, int32_t trx_id, CLogRecord *&log_rec, const char *table_name = nullptr, RC clog_gen_record(CLogType flag, int32_t trx_id, CLogRecord *&log_rec, const char *table_name = nullptr,
int data_len = 0, Record *rec = nullptr); int data_len = 0, Record *rec = nullptr);
//追加写到log_buffer // 追加写到log_buffer
RC clog_append_record(CLogRecord *log_rec); RC clog_append_record(CLogRecord *log_rec);
// 通常不需要在外部调用 // 通常不需要在外部调用
RC clog_sync(); RC clog_sync();
......
...@@ -66,7 +66,10 @@ public: ...@@ -66,7 +66,10 @@ public:
return comp_op_; return comp_op_;
} }
AttrType attr_type() const { return attr_type_; } AttrType attr_type() const
{
return attr_type_;
}
private: private:
ConDesc left_; ConDesc left_;
......
...@@ -72,7 +72,8 @@ RC Db::create_table(const char *table_name, int attribute_count, const AttrInfo ...@@ -72,7 +72,8 @@ RC Db::create_table(const char *table_name, int attribute_count, const AttrInfo
// 文件路径可以移到Table模块 // 文件路径可以移到Table模块
std::string table_file_path = table_meta_file(path_.c_str(), table_name); std::string table_file_path = table_meta_file(path_.c_str(), table_name);
Table *table = new Table(); Table *table = new Table();
rc = table->create(table_file_path.c_str(), table_name, path_.c_str(), attribute_count, attributes, get_clog_manager()); rc = table->create(
table_file_path.c_str(), table_name, path_.c_str(), attribute_count, attributes, get_clog_manager());
if (rc != RC::SUCCESS) { if (rc != RC::SUCCESS) {
LOG_ERROR("Failed to create table %s.", table_name); LOG_ERROR("Failed to create table %s.", table_name);
delete table; delete table;
...@@ -164,16 +165,17 @@ RC Db::recover() ...@@ -164,16 +165,17 @@ RC Db::recover()
CLogMTRManager *mtr_manager = clog_manager_->get_mtr_manager(); CLogMTRManager *mtr_manager = clog_manager_->get_mtr_manager();
for (auto it = mtr_manager->log_redo_list.begin(); it != mtr_manager->log_redo_list.end(); it++) { for (auto it = mtr_manager->log_redo_list.begin(); it != mtr_manager->log_redo_list.end(); it++) {
CLogRecord *clog_record = *it; CLogRecord *clog_record = *it;
if (clog_record->get_log_type() != CLogType::REDO_INSERT && clog_record->get_log_type() != CLogType::REDO_DELETE) { if (clog_record->get_log_type() != CLogType::REDO_INSERT &&
clog_record->get_log_type() != CLogType::REDO_DELETE) {
delete clog_record; delete clog_record;
continue; continue;
} }
auto find_iter = mtr_manager->trx_commited.find(clog_record->get_trx_id()); auto find_iter = mtr_manager->trx_commited.find(clog_record->get_trx_id());
if (find_iter == mtr_manager->trx_commited.end()) { if (find_iter == mtr_manager->trx_commited.end()) {
LOG_ERROR("CLog record without commit message! "); // unexpected error LOG_ERROR("CLog record without commit message! "); // unexpected error
delete clog_record; delete clog_record;
return RC::GENERIC_ERROR; return RC::GENERIC_ERROR;
} else if (find_iter->second == false ) { } else if (find_iter->second == false) {
delete clog_record; delete clog_record;
continue; continue;
} }
...@@ -184,7 +186,7 @@ RC Db::recover() ...@@ -184,7 +186,7 @@ RC Db::recover()
continue; continue;
} }
switch(clog_record->get_log_type()) { switch (clog_record->get_log_type()) {
case CLogType::REDO_INSERT: { case CLogType::REDO_INSERT: {
char *record_data = new char[clog_record->log_record_.ins.data_len_]; char *record_data = new char[clog_record->log_record_.ins.data_len_];
memcpy(record_data, clog_record->log_record_.ins.data_, clog_record->log_record_.ins.data_len_); memcpy(record_data, clog_record->log_record_.ins.data_, clog_record->log_record_.ins.data_len_);
...@@ -209,7 +211,7 @@ RC Db::recover() ...@@ -209,7 +211,7 @@ RC Db::recover()
LOG_ERROR("Failed to recover. rc=%d:%s", rc, strrc(rc)); LOG_ERROR("Failed to recover. rc=%d:%s", rc, strrc(rc));
break; break;
} }
if (max_trx_id < clog_record->get_trx_id()) { if (max_trx_id < clog_record->get_trx_id()) {
max_trx_id = clog_record->get_trx_id(); max_trx_id = clog_record->get_trx_id();
} }
...@@ -224,6 +226,7 @@ RC Db::recover() ...@@ -224,6 +226,7 @@ RC Db::recover()
return rc; return rc;
} }
CLogManager *Db::get_clog_manager() { CLogManager *Db::get_clog_manager()
{
return clog_manager_; return clog_manager_;
} }
\ No newline at end of file
...@@ -17,24 +17,35 @@ See the Mulan PSL v2 for more details. */ ...@@ -17,24 +17,35 @@ See the Mulan PSL v2 for more details. */
#include "storage/common/table.h" #include "storage/common/table.h"
#include "storage/common/field_meta.h" #include "storage/common/field_meta.h"
class Field class Field {
{
public: public:
Field() = default; Field() = default;
Field(const Table *table, const FieldMeta *field) : table_(table), field_(field) Field(const Table *table, const FieldMeta *field) : table_(table), field_(field)
{} {}
Field(const Field &) = default; Field(const Field &) = default;
const Table *table() const { return table_; } const Table *table() const
const FieldMeta *meta() const { return field_; } {
return table_;
}
const FieldMeta *meta() const
{
return field_;
}
AttrType attr_type() const AttrType attr_type() const
{ {
return field_->type(); return field_->type();
} }
const char *table_name() const { return table_->name(); } const char *table_name() const
const char *field_name() const { return field_->name(); } {
return table_->name();
}
const char *field_name() const
{
return field_->name();
}
void set_table(const Table *table) void set_table(const Table *table)
{ {
...@@ -44,6 +55,7 @@ public: ...@@ -44,6 +55,7 @@ public:
{ {
this->field_ = field; this->field_ = field;
} }
private: private:
const Table *table_ = nullptr; const Table *table_ = nullptr;
const FieldMeta *field_ = nullptr; const FieldMeta *field_ = nullptr;
......
...@@ -25,7 +25,6 @@ const static Json::StaticString FIELD_OFFSET("offset"); ...@@ -25,7 +25,6 @@ const static Json::StaticString FIELD_OFFSET("offset");
const static Json::StaticString FIELD_LEN("len"); const static Json::StaticString FIELD_LEN("len");
const static Json::StaticString FIELD_VISIBLE("visible"); const static Json::StaticString FIELD_VISIBLE("visible");
FieldMeta::FieldMeta() : attr_type_(AttrType::UNDEFINED), attr_offset_(-1), attr_len_(0), visible_(false) FieldMeta::FieldMeta() : attr_type_(AttrType::UNDEFINED), attr_offset_(-1), attr_len_(0), visible_(false)
{} {}
......
...@@ -51,8 +51,8 @@ Table::~Table() ...@@ -51,8 +51,8 @@ Table::~Table()
LOG_INFO("Table has been closed: %s", name()); LOG_INFO("Table has been closed: %s", name());
} }
RC Table::create( RC Table::create(const char *path, const char *name, const char *base_dir, int attribute_count,
const char *path, const char *name, const char *base_dir, int attribute_count, const AttrInfo attributes[], CLogManager *clog_manager) const AttrInfo attributes[], CLogManager *clog_manager)
{ {
if (common::is_blank(name)) { if (common::is_blank(name)) {
...@@ -273,7 +273,8 @@ RC Table::insert_record(Trx *trx, Record *record) ...@@ -273,7 +273,8 @@ RC Table::insert_record(Trx *trx, Record *record)
if (trx != nullptr) { if (trx != nullptr) {
// append clog record // append clog record
CLogRecord *clog_record = nullptr; CLogRecord *clog_record = nullptr;
rc = clog_manager_->clog_gen_record(CLogType::REDO_INSERT, trx->get_current_id(), clog_record, name(), table_meta_.record_size(), record); rc = clog_manager_->clog_gen_record(
CLogType::REDO_INSERT, trx->get_current_id(), clog_record, name(), table_meta_.record_size(), record);
if (rc != RC::SUCCESS) { if (rc != RC::SUCCESS) {
LOG_ERROR("Failed to create a clog record. rc=%d:%s", rc, strrc(rc)); LOG_ERROR("Failed to create a clog record. rc=%d:%s", rc, strrc(rc));
return rc; return rc;
...@@ -432,16 +433,15 @@ static RC scan_record_reader_adapter(Record *record, void *context) ...@@ -432,16 +433,15 @@ static RC scan_record_reader_adapter(Record *record, void *context)
return RC::SUCCESS; return RC::SUCCESS;
} }
RC Table::scan_record(Trx *trx, ConditionFilter *filter, RC Table::scan_record(
int limit, void *context, Trx *trx, ConditionFilter *filter, int limit, void *context, void (*record_reader)(const char *data, void *context))
void (*record_reader)(const char *data, void *context))
{ {
RecordReaderScanAdapter adapter(record_reader, context); RecordReaderScanAdapter adapter(record_reader, context);
return scan_record(trx, filter, limit, (void *)&adapter, scan_record_reader_adapter); return scan_record(trx, filter, limit, (void *)&adapter, scan_record_reader_adapter);
} }
RC Table::scan_record(Trx *trx, ConditionFilter *filter, int limit, void *context, RC Table::scan_record(
RC (*record_reader)(Record *record, void *context)) Trx *trx, ConditionFilter *filter, int limit, void *context, RC (*record_reader)(Record *record, void *context))
{ {
if (nullptr == record_reader) { if (nullptr == record_reader) {
return RC::INVALID_ARGUMENT; return RC::INVALID_ARGUMENT;
...@@ -489,9 +489,8 @@ RC Table::scan_record(Trx *trx, ConditionFilter *filter, int limit, void *contex ...@@ -489,9 +489,8 @@ RC Table::scan_record(Trx *trx, ConditionFilter *filter, int limit, void *contex
return rc; return rc;
} }
RC Table::scan_record_by_index(Trx *trx, IndexScanner *scanner, ConditionFilter *filter, RC Table::scan_record_by_index(Trx *trx, IndexScanner *scanner, ConditionFilter *filter, int limit, void *context,
int limit, void *context, RC (*record_reader)(Record *, void *))
RC (*record_reader)(Record *, void *))
{ {
RC rc = RC::SUCCESS; RC rc = RC::SUCCESS;
RID rid; RID rid;
...@@ -557,7 +556,9 @@ RC Table::create_index(Trx *trx, const char *index_name, const char *attribute_n ...@@ -557,7 +556,9 @@ RC Table::create_index(Trx *trx, const char *index_name, const char *attribute_n
} }
if (table_meta_.index(index_name) != nullptr || table_meta_.find_index_by_field((attribute_name))) { if (table_meta_.index(index_name) != nullptr || table_meta_.find_index_by_field((attribute_name))) {
LOG_INFO("Invalid input arguments, table name is %s, index %s exist or attribute %s exist index", LOG_INFO("Invalid input arguments, table name is %s, index %s exist or attribute %s exist index",
name(), index_name, attribute_name); name(),
index_name,
attribute_name);
return RC::SCHEMA_INDEX_EXIST; return RC::SCHEMA_INDEX_EXIST;
} }
...@@ -570,8 +571,7 @@ RC Table::create_index(Trx *trx, const char *index_name, const char *attribute_n ...@@ -570,8 +571,7 @@ RC Table::create_index(Trx *trx, const char *index_name, const char *attribute_n
IndexMeta new_index_meta; IndexMeta new_index_meta;
RC rc = new_index_meta.init(index_name, *field_meta); RC rc = new_index_meta.init(index_name, *field_meta);
if (rc != RC::SUCCESS) { if (rc != RC::SUCCESS) {
LOG_INFO("Failed to init IndexMeta in table:%s, index_name:%s, field_name:%s", LOG_INFO("Failed to init IndexMeta in table:%s, index_name:%s, field_name:%s", name(), index_name, attribute_name);
name(), index_name, attribute_name);
return rc; return rc;
} }
...@@ -689,24 +689,27 @@ RC Table::delete_record(Trx *trx, ConditionFilter *filter, int *deleted_count) ...@@ -689,24 +689,27 @@ RC Table::delete_record(Trx *trx, ConditionFilter *filter, int *deleted_count)
RC Table::delete_record(Trx *trx, Record *record) RC Table::delete_record(Trx *trx, Record *record)
{ {
RC rc = RC::SUCCESS; RC rc = RC::SUCCESS;
rc = delete_entry_of_indexes(record->data(), record->rid(), false); // 重复代码 refer to commit_delete rc = delete_entry_of_indexes(record->data(), record->rid(), false); // 重复代码 refer to commit_delete
if (rc != RC::SUCCESS) { if (rc != RC::SUCCESS) {
LOG_ERROR("Failed to delete indexes of record (rid=%d.%d). rc=%d:%s", LOG_ERROR("Failed to delete indexes of record (rid=%d.%d). rc=%d:%s",
record->rid().page_num, record->rid().slot_num, rc, strrc(rc)); record->rid().page_num,
record->rid().slot_num,
rc,
strrc(rc));
return rc; return rc;
} }
rc = record_handler_->delete_record(&record->rid()); rc = record_handler_->delete_record(&record->rid());
if (rc != RC::SUCCESS) { if (rc != RC::SUCCESS) {
LOG_ERROR("Failed to delete record (rid=%d.%d). rc=%d:%s", LOG_ERROR(
record->rid().page_num, record->rid().slot_num, rc, strrc(rc)); "Failed to delete record (rid=%d.%d). rc=%d:%s", record->rid().page_num, record->rid().slot_num, rc, strrc(rc));
return rc; return rc;
} }
if (trx != nullptr) { if (trx != nullptr) {
rc = trx->delete_record(this, record); rc = trx->delete_record(this, record);
CLogRecord *clog_record = nullptr; CLogRecord *clog_record = nullptr;
rc = clog_manager_->clog_gen_record(CLogType::REDO_DELETE, trx->get_current_id(), clog_record, name(), 0, record); rc = clog_manager_->clog_gen_record(CLogType::REDO_DELETE, trx->get_current_id(), clog_record, name(), 0, record);
if (rc != RC::SUCCESS) { if (rc != RC::SUCCESS) {
...@@ -726,7 +729,7 @@ RC Table::recover_delete_record(Record *record) ...@@ -726,7 +729,7 @@ RC Table::recover_delete_record(Record *record)
{ {
RC rc = RC::SUCCESS; RC rc = RC::SUCCESS;
rc = record_handler_->delete_record(&record->rid()); rc = record_handler_->delete_record(&record->rid());
return rc; return rc;
} }
...@@ -741,7 +744,10 @@ RC Table::commit_delete(Trx *trx, const RID &rid) ...@@ -741,7 +744,10 @@ RC Table::commit_delete(Trx *trx, const RID &rid)
rc = delete_entry_of_indexes(record.data(), record.rid(), false); rc = delete_entry_of_indexes(record.data(), record.rid(), false);
if (rc != RC::SUCCESS) { if (rc != RC::SUCCESS) {
LOG_ERROR("Failed to delete indexes of record(rid=%d.%d). rc=%d:%s", LOG_ERROR("Failed to delete indexes of record(rid=%d.%d). rc=%d:%s",
rid.page_num, rid.slot_num, rc, strrc(rc)); // panic? rid.page_num,
rid.slot_num,
rc,
strrc(rc)); // panic?
} }
rc = record_handler_->delete_record(&rid); rc = record_handler_->delete_record(&rid);
...@@ -847,36 +853,31 @@ IndexScanner *Table::find_index_for_scan(const DefaultConditionFilter &filter) ...@@ -847,36 +853,31 @@ IndexScanner *Table::find_index_for_scan(const DefaultConditionFilter &filter)
bool left_inclusive = false; bool left_inclusive = false;
bool right_inclusive = false; bool right_inclusive = false;
switch (filter.comp_op()) { switch (filter.comp_op()) {
case EQUAL_TO: { case EQUAL_TO: {
left_key = (const char *)value_cond_desc->value.data(); left_key = (const char *)value_cond_desc->value.data();
right_key = (const char *)value_cond_desc->value.data(); right_key = (const char *)value_cond_desc->value.data();
left_inclusive = true; left_inclusive = true;
right_inclusive = true; right_inclusive = true;
} } break;
break; case LESS_EQUAL: {
case LESS_EQUAL: { right_key = (const char *)value_cond_desc->value.data();
right_key = (const char *)value_cond_desc->value.data(); right_inclusive = true;
right_inclusive = true; } break;
} case GREAT_EQUAL: {
break; left_key = (const char *)value_cond_desc->value.data();
case GREAT_EQUAL: { left_inclusive = true;
left_key = (const char *)value_cond_desc->value.data(); } break;
left_inclusive = true; case LESS_THAN: {
} right_key = (const char *)value_cond_desc->value.data();
break; right_inclusive = false;
case LESS_THAN: { } break;
right_key = (const char *)value_cond_desc->value.data(); case GREAT_THAN: {
right_inclusive = false; left_key = (const char *)value_cond_desc->value.data();
} left_inclusive = false;
break; } break;
case GREAT_THAN: { default: {
left_key = (const char *)value_cond_desc->value.data(); return nullptr;
left_inclusive = false; }
}
break;
default: {
return nullptr;
}
} }
if (filter.attr_type() == CHARS) { if (filter.attr_type() == CHARS) {
...@@ -918,7 +919,10 @@ RC Table::sync() ...@@ -918,7 +919,10 @@ RC Table::sync()
rc = index->sync(); rc = index->sync();
if (rc != RC::SUCCESS) { if (rc != RC::SUCCESS) {
LOG_ERROR("Failed to flush index's pages. table=%s, index=%s, rc=%d:%s", LOG_ERROR("Failed to flush index's pages. table=%s, index=%s, rc=%d:%s",
name(), index->index_meta().name(), rc, strrc(rc)); name(),
index->index_meta().name(),
rc,
strrc(rc));
return rc; return rc;
} }
} }
......
...@@ -82,7 +82,8 @@ RC TableMeta::init(const char *name, int field_num, const AttrInfo attributes[]) ...@@ -82,7 +82,8 @@ RC TableMeta::init(const char *name, int field_num, const AttrInfo attributes[])
for (int i = 0; i < field_num; i++) { for (int i = 0; i < field_num; i++) {
const AttrInfo &attr_info = attributes[i]; const AttrInfo &attr_info = attributes[i];
rc = fields_[i + sys_fields_.size()].init(attr_info.name.c_str(), attr_info.type, field_offset, attr_info.length, true); rc = fields_[i + sys_fields_.size()].init(
attr_info.name.c_str(), attr_info.type, field_offset, attr_info.length, true);
if (rc != RC::SUCCESS) { if (rc != RC::SUCCESS) {
LOG_ERROR("Failed to init field meta. table name=%s, field name: %s", name, attr_info.name.c_str()); LOG_ERROR("Failed to init field meta. table name=%s, field name: %s", name, attr_info.name.c_str());
return rc; return rc;
......
...@@ -41,8 +41,11 @@ public: ...@@ -41,8 +41,11 @@ public:
const FieldMeta *field(int index) const; const FieldMeta *field(int index) const;
const FieldMeta *field(const char *name) const; const FieldMeta *field(const char *name) const;
const FieldMeta *find_field_by_offset(int offset) const; const FieldMeta *find_field_by_offset(int offset) const;
const std::vector<FieldMeta> *field_metas() const { return &fields_; } const std::vector<FieldMeta> *field_metas() const
int field_num() const; // sys field included {
return &fields_;
}
int field_num() const; // sys field included
int sys_field_num() const; int sys_field_num() const;
const IndexMeta *index(const char *name) const; const IndexMeta *index(const char *name) const;
......
...@@ -37,7 +37,7 @@ BPFrameManager::BPFrameManager(const char *name) : allocator_(name) ...@@ -37,7 +37,7 @@ BPFrameManager::BPFrameManager(const char *name) : allocator_(name)
RC BPFrameManager::init(int pool_num) RC BPFrameManager::init(int pool_num)
{ {
int ret = allocator_.init(false, pool_num); int ret = allocator_.init(false, pool_num);
if (ret == 0) { if (ret == 0) {
return RC::SUCCESS; return RC::SUCCESS;
} }
...@@ -57,12 +57,12 @@ RC BPFrameManager::cleanup() ...@@ -57,12 +57,12 @@ RC BPFrameManager::cleanup()
Frame *BPFrameManager::begin_purge() Frame *BPFrameManager::begin_purge()
{ {
Frame *frame_can_purge = nullptr; Frame *frame_can_purge = nullptr;
auto purge_finder = [&frame_can_purge](const BPFrameId &frame_id, Frame * const frame) { auto purge_finder = [&frame_can_purge](const BPFrameId &frame_id, Frame *const frame) {
if (frame->can_purge()) { if (frame->can_purge()) {
frame_can_purge = frame; frame_can_purge = frame;
return false; // false to break the progress return false; // false to break the progress
} }
return true; // true continue to look up return true; // true continue to look up
}; };
frames_.foreach_reverse(purge_finder); frames_.foreach_reverse(purge_finder);
return frame_can_purge; return frame_can_purge;
...@@ -87,7 +87,7 @@ Frame *BPFrameManager::alloc(int file_desc, PageNum page_num) ...@@ -87,7 +87,7 @@ Frame *BPFrameManager::alloc(int file_desc, PageNum page_num)
bool found = frames_.get(frame_id, frame); bool found = frames_.get(frame_id, frame);
if (found) { if (found) {
// assert (frame != nullptr); // assert (frame != nullptr);
return nullptr; // should use get return nullptr; // should use get
} }
frame = allocator_.alloc(); frame = allocator_.alloc();
...@@ -106,7 +106,10 @@ RC BPFrameManager::free(int file_desc, PageNum page_num, Frame *frame) ...@@ -106,7 +106,10 @@ RC BPFrameManager::free(int file_desc, PageNum page_num, Frame *frame)
bool found = frames_.get(frame_id, frame_source); bool found = frames_.get(frame_id, frame_source);
if (!found || frame != frame_source) { if (!found || frame != frame_source) {
LOG_WARN("failed to find frame or got frame not match. file_desc=%d, PageNum=%d, frame_source=%p, frame=%p", LOG_WARN("failed to find frame or got frame not match. file_desc=%d, PageNum=%d, frame_source=%p, frame=%p",
file_desc, page_num, frame_source, frame); file_desc,
page_num,
frame_source,
frame);
return RC::GENERIC_ERROR; return RC::GENERIC_ERROR;
} }
...@@ -120,13 +123,13 @@ std::list<Frame *> BPFrameManager::find_list(int file_desc) ...@@ -120,13 +123,13 @@ std::list<Frame *> BPFrameManager::find_list(int file_desc)
std::lock_guard<std::mutex> lock_guard(lock_); std::lock_guard<std::mutex> lock_guard(lock_);
std::list<Frame *> frames; std::list<Frame *> frames;
auto fetcher = [&frames, file_desc](const BPFrameId &frame_id, Frame * const frame) -> bool { auto fetcher = [&frames, file_desc](const BPFrameId &frame_id, Frame *const frame) -> bool {
if (file_desc == frame_id.file_desc()) { if (file_desc == frame_id.file_desc()) {
frames.push_back(frame); frames.push_back(frame);
} }
return true; return true;
}; };
frames_.foreach(fetcher); frames_.foreach (fetcher);
return frames; return frames;
} }
...@@ -168,9 +171,8 @@ RC BufferPoolIterator::reset() ...@@ -168,9 +171,8 @@ RC BufferPoolIterator::reset()
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
DiskBufferPool::DiskBufferPool(BufferPoolManager &bp_manager, BPFrameManager &frame_manager) DiskBufferPool::DiskBufferPool(BufferPoolManager &bp_manager, BPFrameManager &frame_manager)
: bp_manager_(bp_manager), frame_manager_(frame_manager) : bp_manager_(bp_manager), frame_manager_(frame_manager)
{ {}
}
DiskBufferPool::~DiskBufferPool() DiskBufferPool::~DiskBufferPool()
{ {
...@@ -255,7 +257,6 @@ RC DiskBufferPool::get_this_page(PageNum page_num, Frame **frame) ...@@ -255,7 +257,6 @@ RC DiskBufferPool::get_this_page(PageNum page_num, Frame **frame)
used_match_frame->pin_count_++; used_match_frame->pin_count_++;
used_match_frame->acc_time_ = current_time(); used_match_frame->acc_time_ = current_time();
*frame = used_match_frame; *frame = used_match_frame;
return RC::SUCCESS; return RC::SUCCESS;
} }
...@@ -296,7 +297,7 @@ RC DiskBufferPool::allocate_page(Frame **frame) ...@@ -296,7 +297,7 @@ RC DiskBufferPool::allocate_page(Frame **frame)
(file_header_->allocated_pages)++; (file_header_->allocated_pages)++;
file_header_->bitmap[byte] |= (1 << bit); file_header_->bitmap[byte] |= (1 << bit);
// TODO, do we need clean the loaded page's data? // TODO, do we need clean the loaded page's data?
hdr_frame_->mark_dirty(); hdr_frame_->mark_dirty();
return get_this_page(i, frame); return get_this_page(i, frame);
} }
} }
...@@ -304,7 +305,8 @@ RC DiskBufferPool::allocate_page(Frame **frame) ...@@ -304,7 +305,8 @@ RC DiskBufferPool::allocate_page(Frame **frame)
if (file_header_->page_count >= BPFileHeader::MAX_PAGE_NUM) { if (file_header_->page_count >= BPFileHeader::MAX_PAGE_NUM) {
LOG_WARN("file buffer pool is full. page count %d, max page count %d", LOG_WARN("file buffer pool is full. page count %d, max page count %d",
file_header_->page_count, BPFileHeader::MAX_PAGE_NUM); file_header_->page_count,
BPFileHeader::MAX_PAGE_NUM);
return BUFFERPOOL_NOBUF; return BUFFERPOOL_NOBUF;
} }
...@@ -343,7 +345,7 @@ RC DiskBufferPool::allocate_page(Frame **frame) ...@@ -343,7 +345,7 @@ RC DiskBufferPool::allocate_page(Frame **frame)
RC DiskBufferPool::unpin_page(Frame *frame) RC DiskBufferPool::unpin_page(Frame *frame)
{ {
assert(frame->pin_count_ >= 1); ASSERT(frame->pin_count_ >= 1, "Page %d 's pin_count is smaller than 1", frame->page_num());
if (--frame->pin_count_ == 0) { if (--frame->pin_count_ == 0) {
PageNum page_num = frame->page_num(); PageNum page_num = frame->page_num();
auto pages_it = disposed_pages.find(page_num); auto pages_it = disposed_pages.find(page_num);
...@@ -383,7 +385,9 @@ RC DiskBufferPool::purge_frame(PageNum page_num, Frame *buf) ...@@ -383,7 +385,9 @@ RC DiskBufferPool::purge_frame(PageNum page_num, Frame *buf)
{ {
if (buf->pin_count_ > 0) { if (buf->pin_count_ > 0) {
LOG_INFO("Begin to free page %d of %d(file id), but it's pinned, pin_count:%d.", LOG_INFO("Begin to free page %d of %d(file id), but it's pinned, pin_count:%d.",
buf->page_num(), buf->file_desc_, buf->pin_count_); buf->page_num(),
buf->file_desc_,
buf->pin_count_);
return RC::LOCKED_UNLOCK; return RC::LOCKED_UNLOCK;
} }
...@@ -423,7 +427,9 @@ RC DiskBufferPool::purge_all_pages() ...@@ -423,7 +427,9 @@ RC DiskBufferPool::purge_all_pages()
Frame *frame = *it; Frame *frame = *it;
if (frame->pin_count_ > 0) { if (frame->pin_count_ > 0) {
LOG_WARN("The page has been pinned, file_desc:%d, pagenum:%d, pin_count=%d", LOG_WARN("The page has been pinned, file_desc:%d, pagenum:%d, pin_count=%d",
frame->file_desc_, frame->page_.page_num, frame->pin_count_); frame->file_desc_,
frame->page_.page_num,
frame->pin_count_);
continue; continue;
} }
if (frame->dirty_) { if (frame->dirty_) {
...@@ -441,13 +447,17 @@ RC DiskBufferPool::purge_all_pages() ...@@ -441,13 +447,17 @@ RC DiskBufferPool::purge_all_pages()
RC DiskBufferPool::check_all_pages_unpinned() RC DiskBufferPool::check_all_pages_unpinned()
{ {
std::list<Frame *> frames = frame_manager_.find_list(file_desc_); std::list<Frame *> frames = frame_manager_.find_list(file_desc_);
for (auto & frame : frames) { for (auto &frame : frames) {
if (frame->page_num() == BP_HEADER_PAGE && frame->pin_count_ > 1) { if (frame->page_num() == BP_HEADER_PAGE && frame->pin_count_ > 1) {
LOG_WARN("This page has been pinned. file desc=%d, page num:%d, pin count=%d", LOG_WARN("This page has been pinned. file desc=%d, page num:%d, pin count=%d",
file_desc_, frame->page_num(), frame->pin_count_); file_desc_,
frame->page_num(),
frame->pin_count_);
} else if (frame->page_num() != BP_HEADER_PAGE && frame->pin_count_ > 0) { } else if (frame->page_num() != BP_HEADER_PAGE && frame->pin_count_ > 0) {
LOG_WARN("This page has been pinned. file desc=%d, page num:%d, pin count=%d", LOG_WARN("This page has been pinned. file desc=%d, page num:%d, pin count=%d",
file_desc_, frame->page_num(), frame->pin_count_); file_desc_,
frame->page_num(),
frame->pin_count_);
} }
} }
LOG_INFO("all pages have been checked of file desc %d", file_desc_); LOG_INFO("all pages have been checked of file desc %d", file_desc_);
...@@ -549,8 +559,7 @@ RC DiskBufferPool::load_page(PageNum page_num, Frame *frame) ...@@ -549,8 +559,7 @@ RC DiskBufferPool::load_page(PageNum page_num, Frame *frame)
{ {
s64_t offset = ((s64_t)page_num) * sizeof(Page); s64_t offset = ((s64_t)page_num) * sizeof(Page);
if (lseek(file_desc_, offset, SEEK_SET) == -1) { if (lseek(file_desc_, offset, SEEK_SET) == -1) {
LOG_ERROR("Failed to load page %s:%d, due to failed to lseek:%s.", LOG_ERROR("Failed to load page %s:%d, due to failed to lseek:%s.", file_name_.c_str(), page_num, strerror(errno));
file_name_.c_str(), page_num, strerror(errno));
return RC::IOERR_SEEK; return RC::IOERR_SEEK;
} }
...@@ -558,7 +567,11 @@ RC DiskBufferPool::load_page(PageNum page_num, Frame *frame) ...@@ -558,7 +567,11 @@ RC DiskBufferPool::load_page(PageNum page_num, Frame *frame)
int ret = readn(file_desc_, &(frame->page_), sizeof(Page)); int ret = readn(file_desc_, &(frame->page_), sizeof(Page));
if (ret != 0) { if (ret != 0) {
LOG_ERROR("Failed to load page %s:%d, due to failed to read data:%s, ret=%d, page count=%d", LOG_ERROR("Failed to load page %s:%d, due to failed to read data:%s, ret=%d, page count=%d",
file_name_.c_str(), page_num, strerror(errno), ret, file_header_->allocated_pages); file_name_.c_str(),
page_num,
strerror(errno),
ret,
file_header_->allocated_pages);
return RC::IOERR_READ; return RC::IOERR_READ;
} }
return RC::SUCCESS; return RC::SUCCESS;
...@@ -583,7 +596,7 @@ BufferPoolManager::~BufferPoolManager() ...@@ -583,7 +596,7 @@ BufferPoolManager::~BufferPoolManager()
{ {
std::unordered_map<std::string, DiskBufferPool *> tmp_bps; std::unordered_map<std::string, DiskBufferPool *> tmp_bps;
tmp_bps.swap(buffer_pools_); tmp_bps.swap(buffer_pools_);
for (auto &iter : tmp_bps) { for (auto &iter : tmp_bps) {
delete iter.second; delete iter.second;
} }
...@@ -634,10 +647,10 @@ RC BufferPoolManager::create_file(const char *file_name) ...@@ -634,10 +647,10 @@ RC BufferPoolManager::create_file(const char *file_name)
return RC::SUCCESS; return RC::SUCCESS;
} }
RC BufferPoolManager::open_file(const char *_file_name, DiskBufferPool *& _bp) RC BufferPoolManager::open_file(const char *_file_name, DiskBufferPool *&_bp)
{ {
std::string file_name(_file_name); std::string file_name(_file_name);
if (buffer_pools_.find(file_name) != buffer_pools_.end()) { if (buffer_pools_.find(file_name) != buffer_pools_.end()) {
LOG_WARN("file already opened. file name=%s", _file_name); LOG_WARN("file already opened. file name=%s", _file_name);
return RC::BUFFERPOOL_OPEN; return RC::BUFFERPOOL_OPEN;
......
...@@ -53,9 +53,9 @@ struct Page { ...@@ -53,9 +53,9 @@ struct Page {
* 效率非常低,你有办法优化吗? * 效率非常低,你有办法优化吗?
*/ */
struct BPFileHeader { struct BPFileHeader {
int32_t page_count; //! 当前文件一共有多少个页面 int32_t page_count; //! 当前文件一共有多少个页面
int32_t allocated_pages; //! 已经分配了多少个页面 int32_t allocated_pages; //! 已经分配了多少个页面
char bitmap[0]; //! 页面分配位图, 第0个页面(就是当前页面),总是1 char bitmap[0]; //! 页面分配位图, 第0个页面(就是当前页面),总是1
/** /**
* 能够分配的最大的页面个数,即bitmap的字节数 乘以8 * 能够分配的最大的页面个数,即bitmap的字节数 乘以8
...@@ -63,8 +63,7 @@ struct BPFileHeader { ...@@ -63,8 +63,7 @@ struct BPFileHeader {
static const int MAX_PAGE_NUM = (BP_PAGE_DATA_SIZE - sizeof(page_count) - sizeof(allocated_pages)) * 8; static const int MAX_PAGE_NUM = (BP_PAGE_DATA_SIZE - sizeof(page_count) - sizeof(allocated_pages)) * 8;
}; };
class Frame class Frame {
{
public: public:
void clear_page() void clear_page()
{ {
...@@ -85,11 +84,13 @@ public: ...@@ -85,11 +84,13 @@ public:
* 标记指定页面为“脏”页。如果修改了页面的内容,则应调用此函数, * 标记指定页面为“脏”页。如果修改了页面的内容,则应调用此函数,
* 以便该页面被淘汰出缓冲区时系统将新的页面数据写入磁盘文件 * 以便该页面被淘汰出缓冲区时系统将新的页面数据写入磁盘文件
*/ */
void mark_dirty() { void mark_dirty()
{
dirty_ = true; dirty_ = true;
} }
char *data() { char *data()
{
return page_.data; return page_.data;
} }
...@@ -106,21 +107,20 @@ public: ...@@ -106,21 +107,20 @@ public:
{ {
return pin_count_ <= 0; return pin_count_ <= 0;
} }
private: private:
friend class DiskBufferPool; friend class DiskBufferPool;
bool dirty_ = false; bool dirty_ = false;
unsigned int pin_count_ = 0; unsigned int pin_count_ = 0;
unsigned long acc_time_ = 0; unsigned long acc_time_ = 0;
int file_desc_ = -1; int file_desc_ = -1;
Page page_; Page page_;
}; };
class BPFrameId class BPFrameId {
{ public:
public: BPFrameId(int file_desc, PageNum page_num) : file_desc_(file_desc), page_num_(page_num)
BPFrameId(int file_desc, PageNum page_num) :
file_desc_(file_desc), page_num_(page_num)
{} {}
bool equal_to(const BPFrameId &other) const bool equal_to(const BPFrameId &other) const
...@@ -128,7 +128,7 @@ public: ...@@ -128,7 +128,7 @@ public:
return file_desc_ == other.file_desc_ && page_num_ == other.page_num_; return file_desc_ == other.file_desc_ && page_num_ == other.page_num_;
} }
bool operator== (const BPFrameId &other) const bool operator==(const BPFrameId &other) const
{ {
return this->equal_to(other); return this->equal_to(other);
} }
...@@ -138,16 +138,21 @@ public: ...@@ -138,16 +138,21 @@ public:
return static_cast<size_t>(file_desc_) << 32L | page_num_; return static_cast<size_t>(file_desc_) << 32L | page_num_;
} }
int file_desc() const { return file_desc_; } int file_desc() const
PageNum page_num() const { return page_num_; } {
return file_desc_;
}
PageNum page_num() const
{
return page_num_;
}
private: private:
int file_desc_; int file_desc_;
PageNum page_num_; PageNum page_num_;
}; };
class BPFrameManager class BPFrameManager {
{
public: public:
BPFrameManager(const char *tag); BPFrameManager(const char *tag);
...@@ -172,17 +177,24 @@ public: ...@@ -172,17 +177,24 @@ public:
*/ */
Frame *begin_purge(); Frame *begin_purge();
size_t frame_num() const { return frames_.count(); } size_t frame_num() const
{
return frames_.count();
}
/** /**
* 测试使用。返回已经从内存申请的个数 * 测试使用。返回已经从内存申请的个数
*/ */
size_t total_frame_num() const { return allocator_.get_size(); } size_t total_frame_num() const
{
return allocator_.get_size();
}
private: private:
class BPFrameIdHasher { class BPFrameIdHasher {
public: public:
size_t operator() (const BPFrameId &frame_id) const { size_t operator()(const BPFrameId &frame_id) const
{
return frame_id.hash(); return frame_id.hash();
} }
}; };
...@@ -194,8 +206,7 @@ private: ...@@ -194,8 +206,7 @@ private:
FrameAllocator allocator_; FrameAllocator allocator_;
}; };
class BufferPoolIterator class BufferPoolIterator {
{
public: public:
BufferPoolIterator(); BufferPoolIterator();
~BufferPoolIterator(); ~BufferPoolIterator();
...@@ -204,13 +215,13 @@ public: ...@@ -204,13 +215,13 @@ public:
bool has_next(); bool has_next();
PageNum next(); PageNum next();
RC reset(); RC reset();
private: private:
common::Bitmap bitmap_; common::Bitmap bitmap_;
PageNum current_page_num_ = -1; PageNum current_page_num_ = -1;
}; };
class DiskBufferPool class DiskBufferPool {
{
public: public:
DiskBufferPool(BufferPoolManager &bp_manager, BPFrameManager &frame_manager); DiskBufferPool(BufferPoolManager &bp_manager, BPFrameManager &frame_manager);
~DiskBufferPool(); ~DiskBufferPool();
...@@ -288,6 +299,7 @@ public: ...@@ -288,6 +299,7 @@ public:
* 回放日志时处理page0中已被认定为不存在的page * 回放日志时处理page0中已被认定为不存在的page
*/ */
RC recover_page(PageNum page_num); RC recover_page(PageNum page_num);
protected: protected:
protected: protected:
RC allocate_frame(PageNum page_num, Frame **buf); RC allocate_frame(PageNum page_num, Frame **buf);
...@@ -305,19 +317,18 @@ protected: ...@@ -305,19 +317,18 @@ protected:
private: private:
BufferPoolManager &bp_manager_; BufferPoolManager &bp_manager_;
BPFrameManager & frame_manager_; BPFrameManager &frame_manager_;
std::string file_name_; std::string file_name_;
int file_desc_ = -1; int file_desc_ = -1;
Frame * hdr_frame_ = nullptr; Frame *hdr_frame_ = nullptr;
BPFileHeader * file_header_ = nullptr; BPFileHeader *file_header_ = nullptr;
std::set<PageNum> disposed_pages; std::set<PageNum> disposed_pages;
private: private:
friend class BufferPoolIterator; friend class BufferPoolIterator;
}; };
class BufferPoolManager class BufferPoolManager {
{
public: public:
BufferPoolManager(); BufferPoolManager();
~BufferPoolManager(); ~BufferPoolManager();
...@@ -331,7 +342,7 @@ public: ...@@ -331,7 +342,7 @@ public:
public: public:
static void set_instance(BufferPoolManager *bpm); static void set_instance(BufferPoolManager *bpm);
static BufferPoolManager &instance(); static BufferPoolManager &instance();
private: private:
BPFrameManager frame_manager_{"BufPool"}; BPFrameManager frame_manager_{"BufPool"};
std::unordered_map<std::string, DiskBufferPool *> buffer_pools_; std::unordered_map<std::string, DiskBufferPool *> buffer_pools_;
......
...@@ -79,8 +79,7 @@ RC BplusTreeIndex::open(const char *file_name, const IndexMeta &index_meta, cons ...@@ -79,8 +79,7 @@ RC BplusTreeIndex::open(const char *file_name, const IndexMeta &index_meta, cons
RC BplusTreeIndex::close() RC BplusTreeIndex::close()
{ {
if (inited_) { if (inited_) {
LOG_INFO("Begin to close index, index:%s, field:%s", LOG_INFO("Begin to close index, index:%s, field:%s", index_meta_.name(), index_meta_.field());
index_meta_.name(), index_meta_.field());
index_handler_.close(); index_handler_.close();
inited_ = false; inited_ = false;
} }
...@@ -98,8 +97,8 @@ RC BplusTreeIndex::delete_entry(const char *record, const RID *rid) ...@@ -98,8 +97,8 @@ RC BplusTreeIndex::delete_entry(const char *record, const RID *rid)
return index_handler_.delete_entry(record + field_meta_.offset(), rid); return index_handler_.delete_entry(record + field_meta_.offset(), rid);
} }
IndexScanner *BplusTreeIndex::create_scanner(const char *left_key, int left_len, bool left_inclusive, IndexScanner *BplusTreeIndex::create_scanner(
const char *right_key, int right_len, bool right_inclusive) const char *left_key, int left_len, bool left_inclusive, const char *right_key, int right_len, bool right_inclusive)
{ {
BplusTreeIndexScanner *index_scanner = new BplusTreeIndexScanner(index_handler_); BplusTreeIndexScanner *index_scanner = new BplusTreeIndexScanner(index_handler_);
RC rc = index_scanner->open(left_key, left_len, left_inclusive, right_key, right_len, right_inclusive); RC rc = index_scanner->open(left_key, left_len, left_inclusive, right_key, right_len, right_inclusive);
...@@ -125,8 +124,8 @@ BplusTreeIndexScanner::~BplusTreeIndexScanner() noexcept ...@@ -125,8 +124,8 @@ BplusTreeIndexScanner::~BplusTreeIndexScanner() noexcept
tree_scanner_.close(); tree_scanner_.close();
} }
RC BplusTreeIndexScanner::open(const char *left_key, int left_len, bool left_inclusive, RC BplusTreeIndexScanner::open(
const char *right_key, int right_len, bool right_inclusive) const char *left_key, int left_len, bool left_inclusive, const char *right_key, int right_len, bool right_inclusive)
{ {
return tree_scanner_.open(left_key, left_len, left_inclusive, right_key, right_len, right_inclusive); return tree_scanner_.open(left_key, left_len, left_inclusive, right_key, right_len, right_inclusive);
} }
......
...@@ -33,8 +33,8 @@ public: ...@@ -33,8 +33,8 @@ public:
/** /**
* 扫描指定范围的数据 * 扫描指定范围的数据
*/ */
IndexScanner *create_scanner(const char *left_key, int left_len, bool left_inclusive, IndexScanner *create_scanner(const char *left_key, int left_len, bool left_inclusive, const char *right_key,
const char *right_key, int right_len, bool right_inclusive) override; int right_len, bool right_inclusive) override;
RC sync() override; RC sync() override;
...@@ -51,8 +51,9 @@ public: ...@@ -51,8 +51,9 @@ public:
RC next_entry(RID *rid) override; RC next_entry(RID *rid) override;
RC destroy() override; RC destroy() override;
RC open(const char *left_key, int left_len, bool left_inclusive, RC open(const char *left_key, int left_len, bool left_inclusive, const char *right_key, int right_len,
const char *right_key, int right_len, bool right_inclusive); bool right_inclusive);
private: private:
BplusTreeScanner tree_scanner_; BplusTreeScanner tree_scanner_;
}; };
......
...@@ -46,8 +46,8 @@ public: ...@@ -46,8 +46,8 @@ public:
virtual RC insert_entry(const char *record, const RID *rid) = 0; virtual RC insert_entry(const char *record, const RID *rid) = 0;
virtual RC delete_entry(const char *record, const RID *rid) = 0; virtual RC delete_entry(const char *record, const RID *rid) = 0;
virtual IndexScanner *create_scanner(const char *left_key, int left_len, bool left_inclusive, virtual IndexScanner *create_scanner(const char *left_key, int left_len, bool left_inclusive, const char *right_key,
const char *right_key, int right_len, bool right_inclusive) = 0; int right_len, bool right_inclusive) = 0;
virtual RC sync() = 0; virtual RC sync() = 0;
......
...@@ -35,7 +35,7 @@ RC PersistHandler::create_file(const char *file_name) ...@@ -35,7 +35,7 @@ RC PersistHandler::create_file(const char *file_name)
} else if (!file_name_.empty()) { } else if (!file_name_.empty()) {
LOG_ERROR("Failed to create %s, because a file is already bound.", file_name); LOG_ERROR("Failed to create %s, because a file is already bound.", file_name);
rc = RC::FILE_BOUND; rc = RC::FILE_BOUND;
} else if (access(file_name, F_OK) != -1){ } else if (access(file_name, F_OK) != -1) {
LOG_WARN("Failed to create %s, because file already exist.", file_name); LOG_WARN("Failed to create %s, because file already exist.", file_name);
rc = RC::FILE_EXIST; rc = RC::FILE_EXIST;
} else { } else {
...@@ -43,7 +43,7 @@ RC PersistHandler::create_file(const char *file_name) ...@@ -43,7 +43,7 @@ RC PersistHandler::create_file(const char *file_name)
fd = open(file_name, O_RDWR | O_CREAT | O_EXCL, S_IREAD | S_IWRITE); fd = open(file_name, O_RDWR | O_CREAT | O_EXCL, S_IREAD | S_IWRITE);
if (fd < 0) { if (fd < 0) {
LOG_ERROR("Failed to create %s, due to %s.", file_name, strerror(errno)); LOG_ERROR("Failed to create %s, due to %s.", file_name, strerror(errno));
rc = RC::FILE_CREATE; rc = RC::FILE_CREATE;
} else { } else {
file_name_ = file_name; file_name_ = file_name;
close(fd); close(fd);
...@@ -61,7 +61,7 @@ RC PersistHandler::open_file(const char *file_name) ...@@ -61,7 +61,7 @@ RC PersistHandler::open_file(const char *file_name)
if (file_name == nullptr) { if (file_name == nullptr) {
if (file_name_.empty()) { if (file_name_.empty()) {
LOG_ERROR("Failed to open file, because no file name."); LOG_ERROR("Failed to open file, because no file name.");
rc = RC::FILE_NAME; rc = RC::FILE_NAME;
} else { } else {
if ((fd = open(file_name_.c_str(), O_RDWR)) < 0) { if ((fd = open(file_name_.c_str(), O_RDWR)) < 0) {
LOG_ERROR("Failed to open file %s, because %s.", file_name_.c_str(), strerror(errno)); LOG_ERROR("Failed to open file %s, because %s.", file_name_.c_str(), strerror(errno));
...@@ -86,7 +86,7 @@ RC PersistHandler::open_file(const char *file_name) ...@@ -86,7 +86,7 @@ RC PersistHandler::open_file(const char *file_name)
} }
} }
} }
return rc; return rc;
} }
...@@ -143,7 +143,11 @@ RC PersistHandler::write_file(int size, const char *data, int64_t *out_size) ...@@ -143,7 +143,11 @@ RC PersistHandler::write_file(int size, const char *data, int64_t *out_size)
} else { } else {
int64_t write_size = 0; int64_t write_size = 0;
if ((write_size = write(file_desc_, data, size)) != size) { if ((write_size = write(file_desc_, data, size)) != size) {
LOG_ERROR("Failed to write %d:%s due to %s. Write size: %lld", file_desc_, file_name_.c_str(), strerror(errno), write_size); LOG_ERROR("Failed to write %d:%s due to %s. Write size: %lld",
file_desc_,
file_name_.c_str(),
strerror(errno),
write_size);
rc = RC::FILE_WRITE; rc = RC::FILE_WRITE;
} }
if (out_size != nullptr) { if (out_size != nullptr) {
...@@ -154,7 +158,6 @@ RC PersistHandler::write_file(int size, const char *data, int64_t *out_size) ...@@ -154,7 +158,6 @@ RC PersistHandler::write_file(int size, const char *data, int64_t *out_size)
return rc; return rc;
} }
RC PersistHandler::write_at(uint64_t offset, int size, const char *data, int64_t *out_size) RC PersistHandler::write_at(uint64_t offset, int size, const char *data, int64_t *out_size)
{ {
RC rc = RC::SUCCESS; RC rc = RC::SUCCESS;
...@@ -166,12 +169,21 @@ RC PersistHandler::write_at(uint64_t offset, int size, const char *data, int64_t ...@@ -166,12 +169,21 @@ RC PersistHandler::write_at(uint64_t offset, int size, const char *data, int64_t
rc = RC::FILE_NOT_OPENED; rc = RC::FILE_NOT_OPENED;
} else { } else {
if (lseek(file_desc_, offset, SEEK_SET) == off_t(-1)) { if (lseek(file_desc_, offset, SEEK_SET) == off_t(-1)) {
LOG_ERROR("Failed to write %lld of %d:%s due to failed to seek %s.", offset, file_desc_, file_name_.c_str(), strerror(errno)); LOG_ERROR("Failed to write %lld of %d:%s due to failed to seek %s.",
offset,
file_desc_,
file_name_.c_str(),
strerror(errno));
rc = RC::FILE_SEEK; rc = RC::FILE_SEEK;
} else { } else {
int64_t write_size = 0; int64_t write_size = 0;
if ((write_size = write(file_desc_, data, size)) != size) { if ((write_size = write(file_desc_, data, size)) != size) {
LOG_ERROR("Failed to write %llu of %d:%s due to %s. Write size: %lld", offset, file_desc_, file_name_.c_str(), strerror(errno), write_size); LOG_ERROR("Failed to write %llu of %d:%s due to %s. Write size: %lld",
offset,
file_desc_,
file_name_.c_str(),
strerror(errno),
write_size);
rc = RC::FILE_WRITE; rc = RC::FILE_WRITE;
} }
if (out_size != nullptr) { if (out_size != nullptr) {
...@@ -194,12 +206,17 @@ RC PersistHandler::append(int size, const char *data, int64_t *out_size) ...@@ -194,12 +206,17 @@ RC PersistHandler::append(int size, const char *data, int64_t *out_size)
rc = RC::FILE_NOT_OPENED; rc = RC::FILE_NOT_OPENED;
} else { } else {
if (lseek(file_desc_, 0, SEEK_END) == off_t(-1)) { if (lseek(file_desc_, 0, SEEK_END) == off_t(-1)) {
LOG_ERROR("Failed to append file %d:%s due to failed to seek: %s.", file_desc_, file_name_.c_str(), strerror(errno)); LOG_ERROR(
"Failed to append file %d:%s due to failed to seek: %s.", file_desc_, file_name_.c_str(), strerror(errno));
rc = RC::FILE_SEEK; rc = RC::FILE_SEEK;
} else { } else {
int64_t write_size = 0; int64_t write_size = 0;
if ((write_size = write(file_desc_, data, size)) != size) { if ((write_size = write(file_desc_, data, size)) != size) {
LOG_ERROR("Failed to append file %d:%s due to %s. Write size: %lld", file_desc_, file_name_.c_str(), strerror(errno), write_size); LOG_ERROR("Failed to append file %d:%s due to %s. Write size: %lld",
file_desc_,
file_name_.c_str(),
strerror(errno),
write_size);
rc = RC::FILE_WRITE; rc = RC::FILE_WRITE;
} }
if (out_size != nullptr) { if (out_size != nullptr) {
...@@ -245,7 +262,11 @@ RC PersistHandler::read_at(uint64_t offset, int size, char *data, int64_t *out_s ...@@ -245,7 +262,11 @@ RC PersistHandler::read_at(uint64_t offset, int size, char *data, int64_t *out_s
rc = RC::FILE_NOT_OPENED; rc = RC::FILE_NOT_OPENED;
} else { } else {
if (lseek(file_desc_, offset, SEEK_SET) == off_t(-1)) { if (lseek(file_desc_, offset, SEEK_SET) == off_t(-1)) {
LOG_ERROR("Failed to read %llu of %d:%s due to failed to seek %s.", offset, file_desc_, file_name_.c_str(), strerror(errno)); LOG_ERROR("Failed to read %llu of %d:%s due to failed to seek %s.",
offset,
file_desc_,
file_name_.c_str(),
strerror(errno));
return RC::FILE_SEEK; return RC::FILE_SEEK;
} else { } else {
int64_t read_size = 0; int64_t read_size = 0;
......
...@@ -24,8 +24,7 @@ See the Mulan PSL v2 for more details. */ ...@@ -24,8 +24,7 @@ See the Mulan PSL v2 for more details. */
#include "rc.h" #include "rc.h"
class PersistHandler class PersistHandler {
{
public: public:
PersistHandler(); PersistHandler();
~PersistHandler(); ~PersistHandler();
...@@ -61,8 +60,8 @@ public: ...@@ -61,8 +60,8 @@ public:
RC seek(uint64_t offset); RC seek(uint64_t offset);
private: private:
std::string file_name_; std::string file_name_;
int file_desc_ = -1; int file_desc_ = -1;
}; };
#endif //__OBSERVER_STORAGE_PERSIST_HANDLER_H_ #endif //__OBSERVER_STORAGE_PERSIST_HANDLER_H_
...@@ -32,8 +32,7 @@ struct RID { ...@@ -32,8 +32,7 @@ struct RID {
// bool valid; // true means a valid record // bool valid; // true means a valid record
RID() = default; RID() = default;
RID(const PageNum _page_num, const SlotNum _slot_num) RID(const PageNum _page_num, const SlotNum _slot_num) : page_num(_page_num), slot_num(_slot_num)
: page_num(_page_num), slot_num(_slot_num)
{} {}
const std::string to_string() const const std::string to_string() const
...@@ -82,25 +81,46 @@ struct RID { ...@@ -82,25 +81,46 @@ struct RID {
} }
}; };
class Record class Record {
{
public: public:
Record() = default; Record() = default;
~Record() = default; ~Record() = default;
void set_data(char *data) { this->data_ = data; } void set_data(char *data)
char *data() { return this->data_; } {
const char *data() const { return this->data_; } this->data_ = data;
}
char *data()
{
return this->data_;
}
const char *data() const
{
return this->data_;
}
void set_rid(const RID &rid) { this->rid_ = rid; } void set_rid(const RID &rid)
void set_rid(const PageNum page_num, const SlotNum slot_num) { this->rid_.page_num = page_num; this->rid_.slot_num = slot_num; } {
RID & rid() { return rid_; } this->rid_ = rid;
const RID &rid() const { return rid_; }; }
void set_rid(const PageNum page_num, const SlotNum slot_num)
{
this->rid_.page_num = page_num;
this->rid_.slot_num = slot_num;
}
RID &rid()
{
return rid_;
}
const RID &rid() const
{
return rid_;
};
private: private:
RID rid_; RID rid_;
// the data buffer // the data buffer
// record will not release the memory // record will not release the memory
char * data_ = nullptr; char *data_ = nullptr;
}; };
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册