未验证 提交 3fa7cfc4 编写于 作者: Y Yang Zhao 提交者: GitHub

fix: taos shell support ws client (#13974)

* fix: taos shell support ws client

* fix: ping/pong

* fix: shell ctrl c stop query
上级 cf9ffe30
...@@ -40,6 +40,18 @@ typedef struct SShellHistory { ...@@ -40,6 +40,18 @@ typedef struct SShellHistory {
int hend; int hend;
} SShellHistory; } SShellHistory;
typedef enum enumWebSocketFrameType {
TEXT_FRAME = 0x81,
PING_FRAME = 0x19,
PONG_FRAME = 0x8A,
} WebSocketFrameType;
typedef struct SWSParser {
int offset;
int payload_length;
WebSocketFrameType frame;
} SWSParser;
typedef struct SShellArguments { typedef struct SShellArguments {
char* host; char* host;
char* password; char* password;
...@@ -75,7 +87,13 @@ typedef struct SShellArguments { ...@@ -75,7 +87,13 @@ typedef struct SShellArguments {
char* cloudToken; char* cloudToken;
} SShellArguments; } SShellArguments;
typedef enum WS_ACTION_TYPE_S { WS_CONN, WS_QUERY, WS_FETCH, WS_FETCH_BLOCK } WS_ACTION_TYPE; typedef enum WS_ACTION_TYPE_S {
WS_CONN,
WS_QUERY,
WS_FETCH,
WS_FETCH_BLOCK,
WS_CLOSE,
} WS_ACTION_TYPE;
/**************** Function declarations ****************/ /**************** Function declarations ****************/
extern void shellParseArgument(int argc, char* argv[], SShellArguments* arguments); extern void shellParseArgument(int argc, char* argv[], SShellArguments* arguments);
...@@ -101,6 +119,7 @@ int isCommentLine(char* line); ...@@ -101,6 +119,7 @@ int isCommentLine(char* line);
int wsclient_handshake(); int wsclient_handshake();
int wsclient_conn(); int wsclient_conn();
void wsclient_query(char* command); void wsclient_query(char* command);
int wsclient_send_sql(char *command, WS_ACTION_TYPE type, int64_t id);
int tcpConnect(char* host, int port); int tcpConnect(char* host, int port);
int parse_cloud_dsn(); int parse_cloud_dsn();
...@@ -115,5 +134,7 @@ extern int get_old_terminal_mode(struct termios* tio); ...@@ -115,5 +134,7 @@ extern int get_old_terminal_mode(struct termios* tio);
extern void reset_terminal_mode(); extern void reset_terminal_mode();
extern SShellArguments args; extern SShellArguments args;
extern int64_t result; extern int64_t result;
extern int64_t ws_id;
extern bool stop_fetch;
#endif #endif
...@@ -1210,12 +1210,11 @@ int wsclient_handshake() { ...@@ -1210,12 +1210,11 @@ int wsclient_handshake() {
return 0; return 0;
} }
int wsclient_send(char *strdata) { int wsclient_send(char *strdata, WebSocketFrameType frame) {
struct timeval tv; struct timeval tv;
unsigned char mask[4]; unsigned char mask[4];
unsigned int mask_int; unsigned int mask_int;
unsigned long long payload_len; unsigned long long payload_len;
unsigned char finNopcode;
unsigned int payload_len_small; unsigned int payload_len_small;
unsigned int payload_offset = 6; unsigned int payload_offset = 6;
unsigned int len_size; unsigned int len_size;
...@@ -1229,7 +1228,6 @@ int wsclient_send(char *strdata) { ...@@ -1229,7 +1228,6 @@ int wsclient_send(char *strdata) {
mask_int = rand(); mask_int = rand();
memcpy(mask, &mask_int, 4); memcpy(mask, &mask_int, 4);
payload_len = strlen(strdata); payload_len = strlen(strdata);
finNopcode = 0x81;
if (payload_len <= 125) { if (payload_len <= 125) {
frame_size = 6 + payload_len; frame_size = 6 + payload_len;
payload_len_small = payload_len; payload_len_small = payload_len;
...@@ -1247,7 +1245,7 @@ int wsclient_send(char *strdata) { ...@@ -1247,7 +1245,7 @@ int wsclient_send(char *strdata) {
} }
data = (char *)malloc(frame_size); data = (char *)malloc(frame_size);
memset(data, 0, frame_size); memset(data, 0, frame_size);
*data = finNopcode; *data = frame;
*(data + 1) = payload_len_small | 0x80; *(data + 1) = payload_len_small | 0x80;
if (payload_len_small == 126) { if (payload_len_small == 126) {
payload_len &= 0xffff; payload_len &= 0xffff;
...@@ -1280,7 +1278,8 @@ int wsclient_send(char *strdata) { ...@@ -1280,7 +1278,8 @@ int wsclient_send(char *strdata) {
return 0; return 0;
} }
int wsclient_send_sql(char *command, WS_ACTION_TYPE type, int id) { int wsclient_send_sql(char *command, WS_ACTION_TYPE type, int64_t id) {
int code = 1;
cJSON *json = cJSON_CreateObject(); cJSON *json = cJSON_CreateObject();
cJSON *_args = cJSON_CreateObject(); cJSON *_args = cJSON_CreateObject();
cJSON_AddNumberToObject(_args, "req_id", 1); cJSON_AddNumberToObject(_args, "req_id", 1);
...@@ -1304,15 +1303,22 @@ int wsclient_send_sql(char *command, WS_ACTION_TYPE type, int id) { ...@@ -1304,15 +1303,22 @@ int wsclient_send_sql(char *command, WS_ACTION_TYPE type, int id) {
cJSON_AddStringToObject(json, "action", "fetch_block"); cJSON_AddStringToObject(json, "action", "fetch_block");
cJSON_AddNumberToObject(_args, "id", id); cJSON_AddNumberToObject(_args, "id", id);
break; break;
case WS_CLOSE:
cJSON_AddStringToObject(json, "action", "close");
cJSON_AddNumberToObject(_args, "id", id);
break;
} }
cJSON_AddItemToObject(json, "args", _args); cJSON_AddItemToObject(json, "args", _args);
char *strdata = NULL; char *strdata = NULL;
strdata = cJSON_Print(json); strdata = cJSON_Print(json);
if (wsclient_send(strdata)) { if (wsclient_send(strdata, TEXT_FRAME)) {
free(strdata); goto OVER;
return -1;
} }
return 0; code = 0;
OVER:
free(strdata);
cJSON_Delete(json);
return code;
} }
int wsclient_conn() { int wsclient_conn() {
...@@ -1362,45 +1368,71 @@ int wsclient_conn() { ...@@ -1362,45 +1368,71 @@ int wsclient_conn() {
return -1; return -1;
} }
cJSON *wsclient_parse_response() { void wsclient_parse_frame(SWSParser * parser, uint8_t * recv_buffer) {
char *recv_buffer = calloc(1, 4096); unsigned char msg_opcode = recv_buffer[0] & 0x0F;
int start = 0; unsigned char msg_masked = (recv_buffer[1] >> 7) & 0x01;
bool found = false; int payload_length = 0;
int received = 0; int pos = 2;
int bytes; int length_field = recv_buffer[1] &(~0x80);
int recv_length = 4095; unsigned int mask = 0;
do { if (length_field <= 125) {
bytes = recv(args.socket, recv_buffer + received, recv_length - received, 0); payload_length = length_field;
if (bytes == -1) { } else if (length_field == 126) {
free(recv_buffer); payload_length = recv_buffer[2];
fprintf(stderr, "websocket recv failed with bytes: %d\n", bytes); for (int i = 0; i < 1; i++) {
return NULL; payload_length = (payload_length << 8) + recv_buffer[3 + i];
} }
pos += 2;
if (!found) { } else if (length_field == 127) {
for (; start < recv_length - received; start++) { payload_length = recv_buffer[2];
if ((recv_buffer + start)[0] == '{') { for (int i = 0; i < 7; i++) {
found = true; payload_length = (payload_length << 8) + recv_buffer[3 + i];
break;
} }
pos += 8;
} }
if (msg_masked) {
mask = *((unsigned int *) (recv_buffer + pos));
pos += 4;
const uint8_t *c = recv_buffer + pos;
for (int i = 0; i < payload_length; i++) {
recv_buffer[i] = c[i] ^ ((unsigned char *) (&mask))[i % 4];
} }
if (NULL != strstr(recv_buffer + start, "}")) {
break;
} }
received += bytes; if (msg_opcode == 0x9) {
if (received >= recv_length) { parser->frame = PING_FRAME;
recv_length += 4096;
recv_buffer = realloc(recv_buffer + start, recv_length);
} }
} while (1); parser->offset = pos;
cJSON *res = cJSON_Parse(recv_buffer + start); parser->payload_length = payload_length;
if (res == NULL) { }
fprintf(stderr, "fail to parse response into json: %s\n", recv_buffer + start);
free(recv_buffer); char *wsclient_get_response() {
uint8_t recv_buffer[1024]= {0};
int received = 0;
SWSParser parser;
int bytes = recv(args.socket, recv_buffer + received, 1023, 0);
if (bytes <= 0) {
fprintf(stderr, "websocket recv failed with bytes: %d\n", bytes);
return NULL;
}
wsclient_parse_frame(&parser, recv_buffer);
if (parser.frame == PING_FRAME) {
if (wsclient_send("pong", PONG_FRAME)) {
return NULL; return NULL;
} }
return res; return wsclient_get_response();
}
char* response = calloc(1, parser.payload_length + 1);
int pos = bytes - parser.offset;
memcpy(response, recv_buffer + parser.offset, pos);
while (pos < parser.payload_length) {
bytes = recv(args.socket, response + pos, parser.payload_length - pos, 0);
pos += bytes;
}
response[pos] = '\0';
if (NULL != strstr(response, "unexpected")) {
printf("motherfucker");
}
return response;
} }
TAOS_FIELD *wsclient_print_header(cJSON *query, int *pcols, int *pprecison) { TAOS_FIELD *wsclient_print_header(cJSON *query, int *pcols, int *pprecison) {
...@@ -1462,43 +1494,33 @@ int wsclient_check(cJSON *root, int64_t st, int64_t et) { ...@@ -1462,43 +1494,33 @@ int wsclient_check(cJSON *root, int64_t st, int64_t et) {
} }
int wsclient_print_data(int rows, TAOS_FIELD *fields, int cols, int64_t id, int precision, int* pshowed_rows) { int wsclient_print_data(int rows, TAOS_FIELD *fields, int cols, int64_t id, int precision, int* pshowed_rows) {
char *recv_buffer = calloc(1, 4096); char* response = wsclient_get_response();
int col_length = 0; if (response == NULL) {
for (int i = 0; i < cols; i++) { return 0;
col_length += fields[i].bytes;
}
int total_recv_len = col_length * rows + 12;
int received = 0;
int recv_length = 4095;
int start = 0;
int pos;
do {
int bytes = recv(args.socket, recv_buffer + received, recv_length - received, 0);
received += bytes;
if (received >= recv_length) {
recv_length += 4096;
recv_buffer = realloc(recv_buffer, recv_length);
} }
} while (received < total_recv_len);
while (1) { if (*(int64_t *)response != id) {
if (*(int64_t *)(recv_buffer + start) == id) { fprintf(stderr, "Mismatch id with %"PRId64" expect %"PRId64"\n", *(int64_t *)response, id);
break; free(response);
} return 0;
start++;
} }
start += 8; int pos;
int width[TSDB_MAX_COLUMNS]; int width[TSDB_MAX_COLUMNS];
for (int c = 0; c < cols; c++) { for (int c = 0; c < cols; c++) {
width[c] = calcColWidth(fields + c, precision); width[c] = calcColWidth(fields + c, precision);
} }
for (int i = 0; i < rows; i++) { for (int i = 0; i < rows; i++) {
if (*pshowed_rows == DEFAULT_RES_SHOW_NUM) { if (*pshowed_rows == DEFAULT_RES_SHOW_NUM) {
free(recv_buffer); printf("\n");
printf(" Notice: The result shows only the first %d rows.\n", DEFAULT_RES_SHOW_NUM);
printf("\n");
printf(" You can use Ctrl+C to stop the underway fetching.\n");
printf("\n");
free(response);
return 0; return 0;
} }
for (int c = 0; c < cols; c++) { for (int c = 0; c < cols; c++) {
pos = start; pos = 8;
pos += i * fields[c].bytes; pos += i * fields[c].bytes;
for (int j = 0; j < c; j++) { for (int j = 0; j < c; j++) {
pos += fields[j].bytes * rows; pos += fields[j].bytes * rows;
...@@ -1507,17 +1529,17 @@ int wsclient_print_data(int rows, TAOS_FIELD *fields, int cols, int64_t id, int ...@@ -1507,17 +1529,17 @@ int wsclient_print_data(int rows, TAOS_FIELD *fields, int cols, int64_t id, int
int16_t length = 0; int16_t length = 0;
if (fields[c].type == TSDB_DATA_TYPE_NCHAR || fields[c].type == TSDB_DATA_TYPE_BINARY || if (fields[c].type == TSDB_DATA_TYPE_NCHAR || fields[c].type == TSDB_DATA_TYPE_BINARY ||
fields[c].type == TSDB_DATA_TYPE_JSON) { fields[c].type == TSDB_DATA_TYPE_JSON) {
length = *(int16_t *)(recv_buffer + pos); length = *(int16_t *)(response + pos);
pos += 2; pos += 2;
} }
printField((const char *)(recv_buffer + pos), fields + c, width[c], (int32_t)length, precision); printField((const char *)(response + pos), fields + c, width[c], (int32_t)length, precision);
putchar(' '); putchar(' ');
putchar('|'); putchar('|');
} }
putchar('\n'); putchar('\n');
*pshowed_rows += 1; *pshowed_rows += 1;
} }
free(recv_buffer); free(response);
return 0; return 0;
} }
...@@ -1528,13 +1550,20 @@ void wsclient_query(char *command) { ...@@ -1528,13 +1550,20 @@ void wsclient_query(char *command) {
return; return;
} }
et = taosGetTimestampUs(); char *query_buffer = wsclient_get_response();
cJSON *query = wsclient_parse_response(); if (query_buffer == NULL) {
return;
}
cJSON* query = cJSON_Parse(query_buffer);
if (query == NULL) { if (query == NULL) {
free(query_buffer);
fprintf(stderr, "Failed to parse response into json: %s\n", query_buffer);
return; return;
} }
et = taosGetTimestampUs();
free(query_buffer);
if (wsclient_check(query, st, et)) { if (wsclient_check(query, st, et)) {
goto OVER;
return; return;
} }
cJSON *is_update = cJSON_GetObjectItem(query, "is_update"); cJSON *is_update = cJSON_GetObjectItem(query, "is_update");
...@@ -1542,6 +1571,7 @@ void wsclient_query(char *command) { ...@@ -1542,6 +1571,7 @@ void wsclient_query(char *command) {
if (is_update->valueint) { if (is_update->valueint) {
cJSON *affected_rows = cJSON_GetObjectItem(query, "affected_rows"); cJSON *affected_rows = cJSON_GetObjectItem(query, "affected_rows");
if (cJSON_IsNumber(affected_rows)) { if (cJSON_IsNumber(affected_rows)) {
et = taosGetTimestampUs();
printf("Update OK, %d row(s) in set (%.6fs)\n\n", (int)affected_rows->valueint, (et - st) / 1E6); printf("Update OK, %d row(s) in set (%.6fs)\n\n", (int)affected_rows->valueint, (et - st) / 1E6);
} else { } else {
fprintf(stderr, "Invalid affected_rows key in json\n"); fprintf(stderr, "Invalid affected_rows key in json\n");
...@@ -1555,15 +1585,19 @@ void wsclient_query(char *command) { ...@@ -1555,15 +1585,19 @@ void wsclient_query(char *command) {
if (fields != NULL) { if (fields != NULL) {
cJSON *id = cJSON_GetObjectItem(query, "id"); cJSON *id = cJSON_GetObjectItem(query, "id");
if (cJSON_IsNumber(id)) { if (cJSON_IsNumber(id)) {
ws_id = id->valueint;
bool completed = false; bool completed = false;
while (!completed) { while (!completed && !stop_fetch) {
if (wsclient_send_sql(NULL, WS_FETCH, (int)id->valueint) == 0) { if (wsclient_send_sql(NULL, WS_FETCH, id->valueint) == 0) {
cJSON *fetch = wsclient_parse_response(); char *fetch_buffer = wsclient_get_response();
cJSON* fetch = cJSON_Parse(fetch_buffer);
if (fetch != NULL) { if (fetch != NULL) {
free(fetch_buffer);
if (wsclient_check(fetch, st, et) == 0) { if (wsclient_check(fetch, st, et) == 0) {
cJSON *_completed = cJSON_GetObjectItem(fetch, "completed"); cJSON *_completed = cJSON_GetObjectItem(fetch, "completed");
if (cJSON_IsBool(_completed)) { if (cJSON_IsBool(_completed)) {
if (_completed->valueint) { if (_completed->valueint) {
cJSON_Delete(fetch);
completed = true; completed = true;
continue; continue;
} }
...@@ -1576,10 +1610,11 @@ void wsclient_query(char *command) { ...@@ -1576,10 +1610,11 @@ void wsclient_query(char *command) {
fields[i].bytes = (int16_t)(cJSON_GetArrayItem(lengths, i)->valueint); fields[i].bytes = (int16_t)(cJSON_GetArrayItem(lengths, i)->valueint);
} }
if (showed_rows < DEFAULT_RES_SHOW_NUM) { if (showed_rows < DEFAULT_RES_SHOW_NUM) {
if (wsclient_send_sql(NULL, WS_FETCH_BLOCK, (int)id->valueint) == 0) { if (wsclient_send_sql(NULL, WS_FETCH_BLOCK, id->valueint) == 0) {
wsclient_print_data((int)rows->valueint, fields, cols, id->valueint, precision, &showed_rows); wsclient_print_data((int)rows->valueint, fields, cols, id->valueint, precision, &showed_rows);
} }
} }
cJSON_Delete(fetch);
continue; continue;
} else { } else {
fprintf(stderr, "Invalid lengths key in json\n"); fprintf(stderr, "Invalid lengths key in json\n");
...@@ -1591,17 +1626,19 @@ void wsclient_query(char *command) { ...@@ -1591,17 +1626,19 @@ void wsclient_query(char *command) {
fprintf(stderr, "Invalid completed key in json\n"); fprintf(stderr, "Invalid completed key in json\n");
} }
} }
cJSON_Delete(fetch);
} else {
fprintf(stderr, "failed to parse response into json: %s\n", fetch_buffer);
free(fetch_buffer);
break;
} }
} }
fprintf(stderr, "err occured in fetch/fetch_block ws actions\n"); fprintf(stderr, "err occured in fetch/fetch_block ws actions\n");
break; break;
} }
if (showed_rows == DEFAULT_RES_SHOW_NUM) { et = taosGetTimestampUs();
printf("\n");
printf(" Notice: The result shows only the first %d rows.\n", DEFAULT_RES_SHOW_NUM);
printf("\n");
}
printf("Query OK, %" PRId64 " row(s) in set (%.6fs)\n\n", total_rows, (et - st) / 1E6); printf("Query OK, %" PRId64 " row(s) in set (%.6fs)\n\n", total_rows, (et - st) / 1E6);
stop_fetch = false;
} else { } else {
fprintf(stderr, "Invalid id key in json\n"); fprintf(stderr, "Invalid id key in json\n");
} }
...@@ -1611,6 +1648,7 @@ void wsclient_query(char *command) { ...@@ -1611,6 +1648,7 @@ void wsclient_query(char *command) {
} else { } else {
fprintf(stderr, "Invalid is_update key in json\n"); fprintf(stderr, "Invalid is_update key in json\n");
} }
OVER:
cJSON_Delete(query); cJSON_Delete(query);
return; return;
} }
\ No newline at end of file
...@@ -20,6 +20,8 @@ ...@@ -20,6 +20,8 @@
pthread_t pid; pthread_t pid;
static tsem_t cancelSem; static tsem_t cancelSem;
bool stop_fetch = false;
int64_t ws_id = 0;
void shellQueryInterruptHandler(int32_t signum, void *sigInfo, void *context) { void shellQueryInterruptHandler(int32_t signum, void *sigInfo, void *context) {
tsem_post(&cancelSem); tsem_post(&cancelSem);
...@@ -33,7 +35,12 @@ void *cancelHandler(void *arg) { ...@@ -33,7 +35,12 @@ void *cancelHandler(void *arg) {
taosMsleep(10); taosMsleep(10);
continue; continue;
} }
if (args.restful || args.cloud) {
stop_fetch = true;
if (wsclient_send_sql(NULL, WS_CLOSE, ws_id)) {
exit(EXIT_FAILURE);
}
}
#ifdef LINUX #ifdef LINUX
int64_t rid = atomic_val_compare_exchange_64(&result, result, 0); int64_t rid = atomic_val_compare_exchange_64(&result, result, 0);
SSqlObj* pSql = taosAcquireRef(tscObjRef, rid); SSqlObj* pSql = taosAcquireRef(tscObjRef, rid);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册