未验证 提交 699d10bd 编写于 作者: Y Yusuke Sakurai 提交者: GitHub

fix: make WebSocket.send() exclusive (#3885)

上级 ed680552
......@@ -10,6 +10,7 @@ import { readLong, readShort, sliceLongToBytes } from "../io/ioutil.ts";
import { Sha1 } from "./sha1.ts";
import { writeResponse } from "../http/server.ts";
import { TextProtoReader } from "../textproto/mod.ts";
import { Deferred, deferred } from "../util/async.ts";
export enum OpCode {
Continue = 0x0,
......@@ -193,21 +194,30 @@ function createMask(): Uint8Array {
}
class WebSocketImpl implements WebSocket {
readonly conn: Conn;
private readonly mask?: Uint8Array;
private readonly bufReader: BufReader;
private readonly bufWriter: BufWriter;
private sendQueue: Array<{
frame: WebSocketFrame;
d: Deferred<void>;
}> = [];
constructor(
readonly conn: Conn,
opts: {
bufReader?: BufReader;
bufWriter?: BufWriter;
mask?: Uint8Array;
}
) {
this.mask = opts.mask;
this.bufReader = opts.bufReader || new BufReader(conn);
this.bufWriter = opts.bufWriter || new BufWriter(conn);
constructor({
conn,
bufReader,
bufWriter,
mask
}: {
conn: Conn;
bufReader?: BufReader;
bufWriter?: BufWriter;
mask?: Uint8Array;
}) {
this.conn = conn;
this.mask = mask;
this.bufReader = bufReader || new BufReader(conn);
this.bufWriter = bufWriter || new BufWriter(conn);
}
async *receive(): AsyncIterableIterator<WebSocketEvent> {
......@@ -250,14 +260,11 @@ class WebSocketImpl implements WebSocket {
yield { code, reason };
return;
case OpCode.Ping:
await writeFrame(
{
opcode: OpCode.Pong,
payload: frame.payload,
isLastFrame: true
},
this.bufWriter
);
await this.enqueue({
opcode: OpCode.Pong,
payload: frame.payload,
isLastFrame: true
});
yield ["ping", frame.payload] as WebSocketPingEvent;
break;
case OpCode.Pong:
......@@ -268,6 +275,27 @@ class WebSocketImpl implements WebSocket {
}
}
private dequeue(): void {
const [e] = this.sendQueue;
if (!e) return;
writeFrame(e.frame, this.bufWriter)
.then(() => e.d.resolve())
.catch(e => e.d.reject(e))
.finally(() => {
this.sendQueue.shift();
this.dequeue();
});
}
private enqueue(frame: WebSocketFrame): Promise<void> {
const d = deferred<void>();
this.sendQueue.push({ d, frame });
if (this.sendQueue.length === 1) {
this.dequeue();
}
return d;
}
async send(data: WebSocketMessage): Promise<void> {
if (this.isClosed) {
throw new SocketClosedError("socket has been closed");
......@@ -276,28 +304,24 @@ class WebSocketImpl implements WebSocket {
typeof data === "string" ? OpCode.TextFrame : OpCode.BinaryFrame;
const payload = typeof data === "string" ? encode(data) : data;
const isLastFrame = true;
await writeFrame(
{
isLastFrame,
opcode,
payload,
mask: this.mask
},
this.bufWriter
);
const frame = {
isLastFrame,
opcode,
payload,
mask: this.mask
};
return this.enqueue(frame);
}
async ping(data: WebSocketMessage = ""): Promise<void> {
const payload = typeof data === "string" ? encode(data) : data;
await writeFrame(
{
isLastFrame: true,
opcode: OpCode.Ping,
mask: this.mask,
payload
},
this.bufWriter
);
const frame = {
isLastFrame: true,
opcode: OpCode.Ping,
mask: this.mask,
payload
};
return this.enqueue(frame);
}
private _isClosed = false;
......@@ -317,15 +341,12 @@ class WebSocketImpl implements WebSocket {
} else {
payload = new Uint8Array(header);
}
await writeFrame(
{
isLastFrame: true,
opcode: OpCode.Close,
mask: this.mask,
payload
},
this.bufWriter
);
await this.enqueue({
isLastFrame: true,
opcode: OpCode.Close,
mask: this.mask,
payload
});
} catch (e) {
throw e;
} finally {
......@@ -380,7 +401,7 @@ export async function acceptWebSocket(req: {
}): Promise<WebSocket> {
const { conn, headers, bufReader, bufWriter } = req;
if (acceptable(req)) {
const sock = new WebSocketImpl(conn, { bufReader, bufWriter });
const sock = new WebSocketImpl({ conn, bufReader, bufWriter });
const secKey = headers.get("sec-websocket-key");
if (typeof secKey !== "string") {
throw new Error("sec-websocket-key is not provided");
......@@ -499,9 +520,19 @@ export async function connectWebSocket(
conn.close();
throw err;
}
return new WebSocketImpl(conn, {
return new WebSocketImpl({
conn,
bufWriter,
bufReader,
mask: createMask()
});
}
export function createWebSocket(params: {
conn: Conn;
bufWriter?: BufWriter;
bufReader?: BufReader;
mask?: Uint8Array;
}): WebSocket {
return new WebSocketImpl(params);
}
......@@ -3,6 +3,7 @@ import { BufReader, BufWriter } from "../io/bufio.ts";
import { assert, assertEquals, assertThrowsAsync } from "../testing/asserts.ts";
import { runIfMain, test } from "../testing/mod.ts";
import { TextProtoReader } from "../textproto/mod.ts";
import * as bytes from "../bytes/mod.ts";
import {
acceptable,
connectWebSocket,
......@@ -11,10 +12,13 @@ import {
OpCode,
readFrame,
unmask,
writeFrame
writeFrame,
createWebSocket
} from "./mod.ts";
import { encode } from "../strings/mod.ts";
import { encode, decode } from "../strings/mod.ts";
type Writer = Deno.Writer;
type Reader = Deno.Reader;
type Conn = Deno.Conn;
const { Buffer } = Deno;
test(async function wsReadUnmaskedTextFrame(): Promise<void> {
......@@ -30,7 +34,7 @@ test(async function wsReadUnmaskedTextFrame(): Promise<void> {
});
test(async function wsReadMaskedTextFrame(): Promise<void> {
//a masked single text frame with payload "Hello"
// a masked single text frame with payload "Hello"
const buf = new BufReader(
new Buffer(
new Uint8Array([
......@@ -272,4 +276,55 @@ test("handshake should send search correctly", async function wsHandshakeWithSea
assertEquals(statusLine, "GET /?a=1 HTTP/1.1");
});
function dummyConn(r: Reader, w: Writer): Conn {
return {
rid: -1,
closeRead: (): void => {},
closeWrite: (): void => {},
read: (x): Promise<number | Deno.EOF> => r.read(x),
write: (x): Promise<number> => w.write(x),
close: (): void => {},
localAddr: { transport: "tcp", hostname: "0.0.0.0", port: 0 },
remoteAddr: { transport: "tcp", hostname: "0.0.0.0", port: 0 }
};
}
function delayedWriter(ms: number, dest: Writer): Writer {
return {
write(p: Uint8Array): Promise<number> {
return new Promise<number>(resolve => {
setTimeout(async (): Promise<void> => {
resolve(await dest.write(p));
}, ms);
});
}
};
}
test("WebSocket.send(), WebSocket.ping() should be exclusive", async (): Promise<
void
> => {
const buf = new Buffer();
const conn = dummyConn(new Buffer(), delayedWriter(1, buf));
const sock = createWebSocket({ conn });
// Ensure send call
await Promise.all([
sock.send("first"),
sock.send("second"),
sock.ping(),
sock.send(new Uint8Array([3]))
]);
const bufr = new BufReader(buf);
const first = await readFrame(bufr);
const second = await readFrame(bufr);
const ping = await readFrame(bufr);
const third = await readFrame(bufr);
assertEquals(first.opcode, OpCode.TextFrame);
assertEquals(decode(first.payload), "first");
assertEquals(first.opcode, OpCode.TextFrame);
assertEquals(decode(second.payload), "second");
assertEquals(ping.opcode, OpCode.Ping);
assertEquals(third.opcode, OpCode.BinaryFrame);
assertEquals(bytes.equal(third.payload, new Uint8Array([3])), true);
});
runIfMain(import.meta);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册