未验证 提交 3fa7cfc4 编写于 作者: Y Yang Zhao 提交者: GitHub

fix: taos shell support ws client (#13974)

* fix: taos shell support ws client

* fix: ping/pong

* fix: shell ctrl c stop query
上级 cf9ffe30
......@@ -40,6 +40,18 @@ typedef struct SShellHistory {
int hend;
} SShellHistory;
typedef enum enumWebSocketFrameType {
TEXT_FRAME = 0x81,
PING_FRAME = 0x19,
PONG_FRAME = 0x8A,
} WebSocketFrameType;
typedef struct SWSParser {
int offset;
int payload_length;
WebSocketFrameType frame;
} SWSParser;
typedef struct SShellArguments {
char* host;
char* password;
......@@ -75,7 +87,13 @@ typedef struct SShellArguments {
char* cloudToken;
} SShellArguments;
typedef enum WS_ACTION_TYPE_S { WS_CONN, WS_QUERY, WS_FETCH, WS_FETCH_BLOCK } WS_ACTION_TYPE;
typedef enum WS_ACTION_TYPE_S {
WS_CONN,
WS_QUERY,
WS_FETCH,
WS_FETCH_BLOCK,
WS_CLOSE,
} WS_ACTION_TYPE;
/**************** Function declarations ****************/
extern void shellParseArgument(int argc, char* argv[], SShellArguments* arguments);
......@@ -101,6 +119,7 @@ int isCommentLine(char* line);
int wsclient_handshake();
int wsclient_conn();
void wsclient_query(char* command);
int wsclient_send_sql(char *command, WS_ACTION_TYPE type, int64_t id);
int tcpConnect(char* host, int port);
int parse_cloud_dsn();
......@@ -115,5 +134,7 @@ extern int get_old_terminal_mode(struct termios* tio);
extern void reset_terminal_mode();
extern SShellArguments args;
extern int64_t result;
extern int64_t ws_id;
extern bool stop_fetch;
#endif
......@@ -136,7 +136,7 @@ void shellInit(SShellArguments *_args) {
exit(EXIT_SUCCESS);
}
#endif
return;
}
......@@ -1210,12 +1210,11 @@ int wsclient_handshake() {
return 0;
}
int wsclient_send(char *strdata) {
int wsclient_send(char *strdata, WebSocketFrameType frame) {
struct timeval tv;
unsigned char mask[4];
unsigned int mask_int;
unsigned long long payload_len;
unsigned char finNopcode;
unsigned int payload_len_small;
unsigned int payload_offset = 6;
unsigned int len_size;
......@@ -1229,7 +1228,6 @@ int wsclient_send(char *strdata) {
mask_int = rand();
memcpy(mask, &mask_int, 4);
payload_len = strlen(strdata);
finNopcode = 0x81;
if (payload_len <= 125) {
frame_size = 6 + payload_len;
payload_len_small = payload_len;
......@@ -1247,7 +1245,7 @@ int wsclient_send(char *strdata) {
}
data = (char *)malloc(frame_size);
memset(data, 0, frame_size);
*data = finNopcode;
*data = frame;
*(data + 1) = payload_len_small | 0x80;
if (payload_len_small == 126) {
payload_len &= 0xffff;
......@@ -1280,7 +1278,8 @@ int wsclient_send(char *strdata) {
return 0;
}
int wsclient_send_sql(char *command, WS_ACTION_TYPE type, int id) {
int wsclient_send_sql(char *command, WS_ACTION_TYPE type, int64_t id) {
int code = 1;
cJSON *json = cJSON_CreateObject();
cJSON *_args = cJSON_CreateObject();
cJSON_AddNumberToObject(_args, "req_id", 1);
......@@ -1304,15 +1303,22 @@ int wsclient_send_sql(char *command, WS_ACTION_TYPE type, int id) {
cJSON_AddStringToObject(json, "action", "fetch_block");
cJSON_AddNumberToObject(_args, "id", id);
break;
case WS_CLOSE:
cJSON_AddStringToObject(json, "action", "close");
cJSON_AddNumberToObject(_args, "id", id);
break;
}
cJSON_AddItemToObject(json, "args", _args);
char *strdata = NULL;
strdata = cJSON_Print(json);
if (wsclient_send(strdata)) {
free(strdata);
return -1;
}
return 0;
if (wsclient_send(strdata, TEXT_FRAME)) {
goto OVER;
}
code = 0;
OVER:
free(strdata);
cJSON_Delete(json);
return code;
}
int wsclient_conn() {
......@@ -1326,7 +1332,7 @@ int wsclient_conn() {
fprintf(stderr, "failed to receive from socket\n");
return -1;
}
char *received_json = strstr(recv_buffer, "{");
cJSON *root = cJSON_Parse(received_json);
if (root == NULL) {
......@@ -1362,45 +1368,71 @@ int wsclient_conn() {
return -1;
}
cJSON *wsclient_parse_response() {
char *recv_buffer = calloc(1, 4096);
int start = 0;
bool found = false;
int received = 0;
int bytes;
int recv_length = 4095;
do {
bytes = recv(args.socket, recv_buffer + received, recv_length - received, 0);
if (bytes == -1) {
free(recv_buffer);
fprintf(stderr, "websocket recv failed with bytes: %d\n", bytes);
return NULL;
}
if (!found) {
for (; start < recv_length - received; start++) {
if ((recv_buffer + start)[0] == '{') {
found = true;
break;
}
}
void wsclient_parse_frame(SWSParser * parser, uint8_t * recv_buffer) {
unsigned char msg_opcode = recv_buffer[0] & 0x0F;
unsigned char msg_masked = (recv_buffer[1] >> 7) & 0x01;
int payload_length = 0;
int pos = 2;
int length_field = recv_buffer[1] &(~0x80);
unsigned int mask = 0;
if (length_field <= 125) {
payload_length = length_field;
} else if (length_field == 126) {
payload_length = recv_buffer[2];
for (int i = 0; i < 1; i++) {
payload_length = (payload_length << 8) + recv_buffer[3 + i];
}
if (NULL != strstr(recv_buffer + start, "}")) {
break;
pos += 2;
} else if (length_field == 127) {
payload_length = recv_buffer[2];
for (int i = 0; i < 7; i++) {
payload_length = (payload_length << 8) + recv_buffer[3 + i];
}
received += bytes;
if (received >= recv_length) {
recv_length += 4096;
recv_buffer = realloc(recv_buffer + start, recv_length);
pos += 8;
}
if (msg_masked) {
mask = *((unsigned int *) (recv_buffer + pos));
pos += 4;
const uint8_t *c = recv_buffer + pos;
for (int i = 0; i < payload_length; i++) {
recv_buffer[i] = c[i] ^ ((unsigned char *) (&mask))[i % 4];
}
} while (1);
cJSON *res = cJSON_Parse(recv_buffer + start);
if (res == NULL) {
fprintf(stderr, "fail to parse response into json: %s\n", recv_buffer + start);
free(recv_buffer);
}
if (msg_opcode == 0x9) {
parser->frame = PING_FRAME;
}
parser->offset = pos;
parser->payload_length = payload_length;
}
char *wsclient_get_response() {
uint8_t recv_buffer[1024]= {0};
int received = 0;
SWSParser parser;
int bytes = recv(args.socket, recv_buffer + received, 1023, 0);
if (bytes <= 0) {
fprintf(stderr, "websocket recv failed with bytes: %d\n", bytes);
return NULL;
}
return res;
wsclient_parse_frame(&parser, recv_buffer);
if (parser.frame == PING_FRAME) {
if (wsclient_send("pong", PONG_FRAME)) {
return NULL;
}
return wsclient_get_response();
}
char* response = calloc(1, parser.payload_length + 1);
int pos = bytes - parser.offset;
memcpy(response, recv_buffer + parser.offset, pos);
while (pos < parser.payload_length) {
bytes = recv(args.socket, response + pos, parser.payload_length - pos, 0);
pos += bytes;
}
response[pos] = '\0';
if (NULL != strstr(response, "unexpected")) {
printf("motherfucker");
}
return response;
}
TAOS_FIELD *wsclient_print_header(cJSON *query, int *pcols, int *pprecison) {
......@@ -1462,43 +1494,33 @@ int wsclient_check(cJSON *root, int64_t st, int64_t et) {
}
int wsclient_print_data(int rows, TAOS_FIELD *fields, int cols, int64_t id, int precision, int* pshowed_rows) {
char *recv_buffer = calloc(1, 4096);
int col_length = 0;
for (int i = 0; i < cols; i++) {
col_length += fields[i].bytes;
}
int total_recv_len = col_length * rows + 12;
int received = 0;
int recv_length = 4095;
int start = 0;
int pos;
do {
int bytes = recv(args.socket, recv_buffer + received, recv_length - received, 0);
received += bytes;
if (received >= recv_length) {
recv_length += 4096;
recv_buffer = realloc(recv_buffer, recv_length);
}
} while (received < total_recv_len);
char* response = wsclient_get_response();
if (response == NULL) {
return 0;
}
while (1) {
if (*(int64_t *)(recv_buffer + start) == id) {
break;
}
start++;
if (*(int64_t *)response != id) {
fprintf(stderr, "Mismatch id with %"PRId64" expect %"PRId64"\n", *(int64_t *)response, id);
free(response);
return 0;
}
start += 8;
int pos;
int width[TSDB_MAX_COLUMNS];
for (int c = 0; c < cols; c++) {
width[c] = calcColWidth(fields + c, precision);
}
for (int i = 0; i < rows; i++) {
if (*pshowed_rows == DEFAULT_RES_SHOW_NUM) {
free(recv_buffer);
printf("\n");
printf(" Notice: The result shows only the first %d rows.\n", DEFAULT_RES_SHOW_NUM);
printf("\n");
printf(" You can use Ctrl+C to stop the underway fetching.\n");
printf("\n");
free(response);
return 0;
}
}
for (int c = 0; c < cols; c++) {
pos = start;
pos = 8;
pos += i * fields[c].bytes;
for (int j = 0; j < c; j++) {
pos += fields[j].bytes * rows;
......@@ -1507,17 +1529,17 @@ int wsclient_print_data(int rows, TAOS_FIELD *fields, int cols, int64_t id, int
int16_t length = 0;
if (fields[c].type == TSDB_DATA_TYPE_NCHAR || fields[c].type == TSDB_DATA_TYPE_BINARY ||
fields[c].type == TSDB_DATA_TYPE_JSON) {
length = *(int16_t *)(recv_buffer + pos);
length = *(int16_t *)(response + pos);
pos += 2;
}
printField((const char *)(recv_buffer + pos), fields + c, width[c], (int32_t)length, precision);
printField((const char *)(response + pos), fields + c, width[c], (int32_t)length, precision);
putchar(' ');
putchar('|');
}
putchar('\n');
*pshowed_rows += 1;
}
free(recv_buffer);
free(response);
return 0;
}
......@@ -1528,13 +1550,20 @@ void wsclient_query(char *command) {
return;
}
et = taosGetTimestampUs();
cJSON *query = wsclient_parse_response();
char *query_buffer = wsclient_get_response();
if (query_buffer == NULL) {
return;
}
cJSON* query = cJSON_Parse(query_buffer);
if (query == NULL) {
free(query_buffer);
fprintf(stderr, "Failed to parse response into json: %s\n", query_buffer);
return;
}
et = taosGetTimestampUs();
free(query_buffer);
if (wsclient_check(query, st, et)) {
goto OVER;
return;
}
cJSON *is_update = cJSON_GetObjectItem(query, "is_update");
......@@ -1542,6 +1571,7 @@ void wsclient_query(char *command) {
if (is_update->valueint) {
cJSON *affected_rows = cJSON_GetObjectItem(query, "affected_rows");
if (cJSON_IsNumber(affected_rows)) {
et = taosGetTimestampUs();
printf("Update OK, %d row(s) in set (%.6fs)\n\n", (int)affected_rows->valueint, (et - st) / 1E6);
} else {
fprintf(stderr, "Invalid affected_rows key in json\n");
......@@ -1555,15 +1585,19 @@ void wsclient_query(char *command) {
if (fields != NULL) {
cJSON *id = cJSON_GetObjectItem(query, "id");
if (cJSON_IsNumber(id)) {
ws_id = id->valueint;
bool completed = false;
while (!completed) {
if (wsclient_send_sql(NULL, WS_FETCH, (int)id->valueint) == 0) {
cJSON *fetch = wsclient_parse_response();
while (!completed && !stop_fetch) {
if (wsclient_send_sql(NULL, WS_FETCH, id->valueint) == 0) {
char *fetch_buffer = wsclient_get_response();
cJSON* fetch = cJSON_Parse(fetch_buffer);
if (fetch != NULL) {
free(fetch_buffer);
if (wsclient_check(fetch, st, et) == 0) {
cJSON *_completed = cJSON_GetObjectItem(fetch, "completed");
if (cJSON_IsBool(_completed)) {
if (_completed->valueint) {
cJSON_Delete(fetch);
completed = true;
continue;
}
......@@ -1576,10 +1610,11 @@ void wsclient_query(char *command) {
fields[i].bytes = (int16_t)(cJSON_GetArrayItem(lengths, i)->valueint);
}
if (showed_rows < DEFAULT_RES_SHOW_NUM) {
if (wsclient_send_sql(NULL, WS_FETCH_BLOCK, (int)id->valueint) == 0) {
if (wsclient_send_sql(NULL, WS_FETCH_BLOCK, id->valueint) == 0) {
wsclient_print_data((int)rows->valueint, fields, cols, id->valueint, precision, &showed_rows);
}
}
cJSON_Delete(fetch);
continue;
} else {
fprintf(stderr, "Invalid lengths key in json\n");
......@@ -1591,17 +1626,19 @@ void wsclient_query(char *command) {
fprintf(stderr, "Invalid completed key in json\n");
}
}
cJSON_Delete(fetch);
} else {
fprintf(stderr, "failed to parse response into json: %s\n", fetch_buffer);
free(fetch_buffer);
break;
}
}
fprintf(stderr, "err occured in fetch/fetch_block ws actions\n");
break;
}
if (showed_rows == DEFAULT_RES_SHOW_NUM) {
printf("\n");
printf(" Notice: The result shows only the first %d rows.\n", DEFAULT_RES_SHOW_NUM);
printf("\n");
}
et = taosGetTimestampUs();
printf("Query OK, %" PRId64 " row(s) in set (%.6fs)\n\n", total_rows, (et - st) / 1E6);
stop_fetch = false;
} else {
fprintf(stderr, "Invalid id key in json\n");
}
......@@ -1611,6 +1648,7 @@ void wsclient_query(char *command) {
} else {
fprintf(stderr, "Invalid is_update key in json\n");
}
OVER:
cJSON_Delete(query);
return;
}
\ No newline at end of file
......@@ -20,6 +20,8 @@
pthread_t pid;
static tsem_t cancelSem;
bool stop_fetch = false;
int64_t ws_id = 0;
void shellQueryInterruptHandler(int32_t signum, void *sigInfo, void *context) {
tsem_post(&cancelSem);
......@@ -33,7 +35,12 @@ void *cancelHandler(void *arg) {
taosMsleep(10);
continue;
}
if (args.restful || args.cloud) {
stop_fetch = true;
if (wsclient_send_sql(NULL, WS_CLOSE, ws_id)) {
exit(EXIT_FAILURE);
}
}
#ifdef LINUX
int64_t rid = atomic_val_compare_exchange_64(&result, result, 0);
SSqlObj* pSql = taosAcquireRef(tscObjRef, rid);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册