提交 07a88f1e 编写于 作者: Z zhangtian 提交者: Jinliang Li

BugID:26273920:[mdns]improve code,add some value check

Change-Id: I8c831c413a7cda81f1285119a0b88ec31c838a28
上级 96e5b158
......@@ -420,16 +420,17 @@ static void callback_recv(void *p_cookie, int status, const struct mdns_entry *e
case RR_SRV: {
char *name = NULL;
char *type = NULL;
char *last = NULL;
int num = get_char_num(entry->name, '.');
if (num == 2) {
name = UMESH_SRV_DEFAULT_NANE;
char *type = (char *)strtok(entry->name, ".");
char *type = (char *)strtok_r(entry->name, ".", &last);
if (type == NULL || strlen(type) == 0) {
break;
}
} else if (num == 3) {
name = (char *)strtok(entry->name, ".");
type = (char *)strtok(NULL, ".");
name = (char *)strtok_r(entry->name, ".", &last);
type = (char *)strtok_r(NULL, ".", &last);
if (name == NULL || strlen(name) == 0) {
break;
}
......
......@@ -48,15 +48,27 @@ static inline socklen_t ss_len(const struct sockaddr_storage *ss)
: sizeof(struct sockaddr_in6));
}
static inline uint8_t *write_u16(uint8_t *p, const uint16_t v)
static inline uint8_t *write_u16(uint8_t *p, uint16_t *left_len, const uint16_t v)
{
if (left_len != NULL) {
if (*left_len < 2) {
return NULL;
}
left_len -= 2;
}
*p++ = (v >> 8) & 0xFF;
*p++ = (v >> 0) & 0xFF;
return (p);
}
static inline uint8_t *write_u32(uint8_t *p, const uint32_t v)
static inline uint8_t *write_u32(uint8_t *p, uint16_t *left_len, const uint32_t v)
{
if (left_len != NULL) {
if (*left_len < 4) {
return NULL;
}
left_len -= 4;
}
*p++ = (v >> 24) & 0xFF;
*p++ = (v >> 16) & 0xFF;
*p++ = (v >> 8) & 0xFF;
......@@ -64,11 +76,17 @@ static inline uint8_t *write_u32(uint8_t *p, const uint32_t v)
return (p);
}
static inline uint8_t *write_raw(uint8_t *p, const uint8_t *v)
static inline uint8_t *write_raw(uint8_t *p, uint16_t *left_len, const uint8_t *v)
{
uint32_t len;
len = strlen((const char *) v) + 1;
if (left_len != NULL) {
if (*left_len < len) {
return NULL;
}
left_len -= len;
}
memcpy(p, v, len);
p += len;
return (p);
......@@ -76,20 +94,26 @@ static inline uint8_t *write_raw(uint8_t *p, const uint8_t *v)
static inline const uint8_t *read_u16(const uint8_t *p, uint32_t *s, uint16_t *v)
{
if (*s < 2 || p == NULL) {
return NULL;
}
*v = 0;
*v |= *p++ << 8;
*v |= *p++ << 0;
*v |= (uint16_t) * p++ << 8;
*v |= (uint16_t) * p++ << 0;
*s -= 2;
return (p);
}
static inline const uint8_t *read_u32(const uint8_t *p, uint32_t *s, uint32_t *v)
{
if (*s < 4 || p == NULL) {
return NULL;
}
*v = 0;
*v |= *p++ << 24;
*v |= *p++ << 16;
*v |= *p++ << 8;
*v |= *p++ << 0;
*v |= (uint32_t) * p++ << 24;
*v |= (uint32_t) * p++ << 16;
*v |= (uint32_t) * p++ << 8;
*v |= (uint32_t) * p++ << 0;
*s -= 4;
return (p);
}
......
......@@ -16,12 +16,6 @@
typedef void *multicast_if;
static inline int os_wouldblock(void)
{
return (errno == EWOULDBLOCK);
}
static uint64_t hal_now_ms()
{
return aos_now_ms();
......@@ -55,9 +49,9 @@ struct mdns_ctx {
struct mdns_svc *services;
};
static int mdns_resolve(struct mdns_ctx *ctx, const char *addr, unsigned short port);
static uint32_t mdns_write_hdr(uint8_t *, const struct mdns_hdr *);
static int strrcmp(const char *, const char *);
static int32_t mdns_resolve(struct mdns_ctx *ctx, const char *addr, uint16_t port);
static int32_t mdns_write_hdr(uint8_t *buf, uint16_t *left, const struct mdns_hdr *hdr);
static int32_t strrcmp(const char *, const char *);
static uint32_t mdns_list_interfaces(multicast_if **pp_intfs, struct mdns_ip **pp_mdns_ips, uint32_t *p_nb_intf,
int ai_family)
......@@ -71,7 +65,8 @@ static uint32_t mdns_list_interfaces(multicast_if **pp_intfs, struct mdns_ip **p
memset(intfs, 0, sizeof(*intfs));
*pp_mdns_ips = mdns_ips = hal_malloc(sizeof(*mdns_ips));
if (mdns_ips == NULL) {
hal_free(mdns_ips);
hal_free(intfs);
*pp_intfs = NULL;
return (MDNS_ERROR);
}
memset(mdns_ips, 0, sizeof(*mdns_ips));
......@@ -79,9 +74,10 @@ static uint32_t mdns_list_interfaces(multicast_if **pp_intfs, struct mdns_ip **p
return (0);
}
static int mdns_resolve(struct mdns_ctx *ctx, const char *addr, unsigned short port)
static int32_t mdns_resolve(struct mdns_ctx *ctx, const char *addr, uint16_t port)
{
char buf[6];
int ret;
struct addrinfo hints, *res = NULL;
multicast_if *ifaddrs = NULL;
struct mdns_ip *mdns_ips = NULL;
......@@ -93,8 +89,8 @@ static int mdns_resolve(struct mdns_ctx *ctx, const char *addr, unsigned short p
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_DGRAM;
hints.ai_flags = AI_NUMERICHOST | AI_NUMERICSERV;
errno = getaddrinfo(addr, buf, &hints, &res);
if (errno != 0) {
ret = getaddrinfo(addr, buf, &hints, &res);
if (ret != 0) {
return (MDNS_LKPERR);
}
......@@ -112,6 +108,7 @@ static int mdns_resolve(struct mdns_ctx *ctx, const char *addr, unsigned short p
ctx->conns = hal_malloc(ctx->nb_conns * sizeof(*ctx->conns));
if (ctx->conns == NULL) {
hal_free(ifaddrs);
hal_free(mdns_ips);
freeaddrinfo(res);
return (MDNS_ERROR);
}
......@@ -128,7 +125,7 @@ static int mdns_resolve(struct mdns_ctx *ctx, const char *addr, unsigned short p
}
int mdns_init(struct mdns_ctx **p_ctx, const char *addr, unsigned short port)
int mdns_init(struct mdns_ctx **p_ctx, const char *addr, uint16_t port)
{
int res;
......@@ -200,43 +197,73 @@ int mdns_destroy(struct mdns_ctx *ctx)
return (0);
}
static uint32_t mdns_write_hdr(uint8_t *ptr, const struct mdns_hdr *hdr)
static int32_t mdns_write_hdr(uint8_t *ptr, uint16_t *left, const struct mdns_hdr *hdr)
{
uint8_t *p = ptr;
p = write_u16(p, hdr->id);
p = write_u16(p, hdr->flags);
p = write_u16(p, hdr->num_qn);
p = write_u16(p, hdr->num_ans_rr);
p = write_u16(p, hdr->num_auth_rr);
p = write_u16(p, hdr->num_add_rr);
p = write_u16(p, left, hdr->id);
if (p == NULL) {
return (MDNS_ERROR);
}
p = write_u16(p, left, hdr->flags);
if (p == NULL) {
return (MDNS_ERROR);
}
p = write_u16(p, left, hdr->num_qn);
if (p == NULL) {
return (MDNS_ERROR);
}
p = write_u16(p, left, hdr->num_ans_rr);
if (p == NULL) {
return (MDNS_ERROR);
}
p = write_u16(p, left, hdr->num_auth_rr);
if (p == NULL) {
return (MDNS_ERROR);
}
p = write_u16(p, left, hdr->num_add_rr);
if (p == NULL) {
return (MDNS_ERROR);
}
return (p - ptr);
}
int mdns_send(const struct mdns_ctx *ctx, const struct mdns_hdr *hdr, const struct mdns_entry *entries)
{
//uint8_t buf[MDNS_PKT_MAXSZ] = {0};
const struct mdns_entry *entry = entries;
uint32_t n = 0, l, r;
uint8_t *buf;
uint16_t buf_len = MDNS_PKT_MAXSZ;
if (!entries) {
return (MDNS_ERROR);
}
buf = hal_malloc(MDNS_PKT_MAXSZ);
buf = hal_malloc(buf_len);
if (buf == NULL) {
return MDNS_STDERR;
}
memset(buf, 0, MDNS_PKT_MAXSZ);
l = mdns_write_hdr(buf, hdr);
memset(buf, 0, buf_len);
l = mdns_write_hdr(buf, &buf_len, hdr);
if (l < 0) {
hal_free(buf);
return (MDNS_ERROR);
}
n += l;
for (entry = entries; entry; entry = entry->next) {
l = mdns_write(buf + n, entry, (hdr->flags & FLAG_QR) > 0);
l = mdns_write(buf + n, &buf_len, entry, (hdr->flags & FLAG_QR) > 0);
if (l < 0) {
hal_free(buf);
return (MDNS_STDERR);
}
n += l;
if (n > MDNS_PKT_MAXSZ) {
log_e("mdns packet too large!give up");
hal_free(buf);
return (MDNS_STDERR);
}
}
for (uint32_t i = 0; i < ctx->nb_conns; ++i) {
r = lwip_sendto(ctx->conns[i].sock, (const char *) buf, n, 0,
......@@ -264,22 +291,22 @@ static void mdns_entries_free(struct mdns_entry *entries)
}
}
static const uint8_t *mdns_read_header(const uint8_t *ptr, uint32_t n, struct mdns_hdr *hdr)
static const uint8_t *mdns_read_header(const uint8_t *ptr, uint32_t *n, struct mdns_hdr *hdr)
{
if (n <= sizeof(struct mdns_hdr)) {
errno = ENOSPC;
//errno = ENOSPC;
return NULL;
}
ptr = read_u16(ptr, &n, &hdr->id);
ptr = read_u16(ptr, &n, &hdr->flags);
ptr = read_u16(ptr, &n, &hdr->num_qn);
ptr = read_u16(ptr, &n, &hdr->num_ans_rr);
ptr = read_u16(ptr, &n, &hdr->num_auth_rr);
ptr = read_u16(ptr, &n, &hdr->num_add_rr);
ptr = read_u16(ptr, n, &hdr->id);
ptr = read_u16(ptr, n, &hdr->flags);
ptr = read_u16(ptr, n, &hdr->num_qn);
ptr = read_u16(ptr, n, &hdr->num_ans_rr);
ptr = read_u16(ptr, n, &hdr->num_auth_rr);
ptr = read_u16(ptr, n, &hdr->num_add_rr);
return ptr;
}
static int mdns_recv(const struct mdns_conn *conn, struct mdns_hdr *hdr, struct mdns_entry **entries)
static int32_t mdns_recv(const struct mdns_conn *conn, struct mdns_hdr *hdr, struct mdns_entry **entries)
{
uint8_t *buf;
uint32_t num_entry, n;
......@@ -297,8 +324,13 @@ static int mdns_recv(const struct mdns_conn *conn, struct mdns_hdr *hdr, struct
hal_free(buf);
return (MDNS_NETERR);
}
const uint8_t *ptr = mdns_read_header(buf, length, hdr);
n = length;
n = (uint32_t)length;
const uint8_t *ptr = mdns_read_header(buf, &n, hdr);
if (ptr == NULL) {
hal_free(buf);
return (MDNS_NETERR);
}
num_entry = hdr->num_qn + hdr->num_ans_rr + hdr->num_add_rr;
for (uint32_t i = 0; i < num_entry; ++i) {
......@@ -306,12 +338,13 @@ static int mdns_recv(const struct mdns_conn *conn, struct mdns_hdr *hdr, struct
if (!entry) {
goto err;
}
memset(entry, 0, sizeof(struct mdns_svc));
memset(entry, 0, sizeof(struct mdns_entry));
ptr = mdns_read(ptr, &n, buf, entry, i >= hdr->num_qn);
if (!ptr) {
log_e("mdns_read err");
mdns_free(entry);
hal_free(entry);
errno = ENOSPC;
goto err;
}
......@@ -348,7 +381,7 @@ void mdns_print(const struct mdns_entry *entry)
}
static int strrcmp(const char *s1, const char *s2)
static int32_t strrcmp(const char *s1, const char *s2)
{
uint32_t m, n;
......@@ -363,19 +396,19 @@ static int strrcmp(const char *s1, const char *s2)
return (strncmp(s1 + m - n, s2, n));
}
static int mdns_listen_probe_network(const struct mdns_ctx *ctx, const char *const names[],
unsigned int nb_names, enum mdns_match_type match_type, mdns_listen_callback callback,
void *p_cookie)
static int32_t mdns_listen_probe_network(const struct mdns_ctx *ctx, const char *const names[],
uint32_t nb_names, enum mdns_match_type match_type, mdns_listen_callback callback,
void *p_cookie)
{
struct mdns_hdr ahdr = {0};
struct mdns_entry *entries;
struct mdns_svc *svc;
fd_set working_set;
struct timeval timeout;
int max_fd = 0;
int r;
for (uint32_t i = 0; i < ctx->nb_conns; ++i) {
int32_t max_fd = 0;
int32_t r;
int i;
for (i = 0; i < ctx->nb_conns; ++i) {
if (max_fd < ctx->conns[i].sock) {
max_fd = ctx->conns[i].sock;
}
......@@ -394,13 +427,13 @@ static int mdns_listen_probe_network(const struct mdns_ctx *ctx, const char *con
return r;
}
for (uint32_t i = 0; i < ctx->nb_conns; ++i) {
for (i = 0; i < ctx->nb_conns; ++i) {
if (FD_ISSET(ctx->conns[i].sock, &working_set) == 0) {
continue;
}
r = mdns_recv(&ctx->conns[i], &ahdr, &entries);
if (r == MDNS_NETERR && os_wouldblock()) {
if (r < 0) {
log_e("---MDNS_NETERR -----");
mdns_entries_free(entries);
continue;
......@@ -414,7 +447,7 @@ static int mdns_listen_probe_network(const struct mdns_ctx *ctx, const char *con
if (match_type == MDNS_MATCH_ALL) {
callback(p_cookie, r, entries);
} else {
for (unsigned int i = 0; i < nb_names; ++i) {
for (i = 0; i < nb_names; ++i) {
for (struct mdns_entry *entry = entries; entry; entry = entry->next) {
if (!strrcmp(entry->name, names[i])) {
callback(p_cookie, r, entries);
......@@ -456,7 +489,7 @@ int mdns_announce(struct mdns_ctx *ctx, const char *service, enum mdns_type type
}
int mdns_start(const struct mdns_ctx *ctx, const char *const names[],
unsigned int nb_names, enum mdns_type type, unsigned int interval, enum mdns_match_type match_type,
uint32_t nb_names, enum mdns_type type, uint32_t interval, enum mdns_match_type match_type,
mdns_stop_func stop, mdns_listen_callback callback, void *p_cookie)
{
if (ctx->nb_conns == 0) {
......
......@@ -7,11 +7,11 @@
typedef const uint8_t *(*mdns_reader)(const uint8_t *, uint32_t *, const uint8_t *, struct mdns_entry *);
typedef uint32_t (*mdns_writer)(uint8_t *, const struct mdns_entry *);
typedef int32_t (*mdns_writer)(uint8_t *, uint16_t *, const struct mdns_entry *);
typedef void (*mdns_printer)(const union mdns_data *);
static const uint8_t *mdns_decode(const uint8_t *ptr, uint32_t *n, const uint8_t *root, char **ss);
static uint8_t *mdns_encode(char *s);
static const uint8_t *mdns_decode(const uint8_t *ptr, uint32_t *n, const uint8_t *root, char **ss, uint8_t nb_times);
static uint8_t *mdns_encode(const char *s);
const uint8_t *mdns_read(const uint8_t *ptr, uint32_t *n, const uint8_t *root, struct mdns_entry *entry, int8_t ans);
static const uint8_t *mdns_read_SRV(const uint8_t *, uint32_t *, const uint8_t *, struct mdns_entry *);
......@@ -20,12 +20,11 @@ static const uint8_t *mdns_read_TXT(const uint8_t *, uint32_t *, const uint8_t *
static const uint8_t *mdns_read_AAAA(const uint8_t *, uint32_t *, const uint8_t *, struct mdns_entry *);
static const uint8_t *mdns_read_A(const uint8_t *, uint32_t *, const uint8_t *, struct mdns_entry *);
uint32_t mdns_write(uint8_t *ptr, const struct mdns_entry *entry, int8_t ans);
static uint32_t mdns_write_SRV(uint8_t *, const struct mdns_entry *);
static uint32_t mdns_write_PTR(uint8_t *, const struct mdns_entry *);
static uint32_t mdns_write_TXT(uint8_t *, const struct mdns_entry *);
static uint32_t mdns_write_AAAA(uint8_t *, const struct mdns_entry *);
static uint32_t mdns_write_A(uint8_t *, const struct mdns_entry *);
static int32_t mdns_write_SRV(uint8_t *, uint16_t *, const struct mdns_entry *);
static int32_t mdns_write_PTR(uint8_t *, uint16_t *, const struct mdns_entry *);
static int32_t mdns_write_TXT(uint8_t *, uint16_t *, const struct mdns_entry *);
static int32_t mdns_write_AAAA(uint8_t *, uint16_t *, const struct mdns_entry *);
static int32_t mdns_write_A(uint8_t *, uint16_t *, const struct mdns_entry *);
void mdns_print(const struct mdns_entry *entry);
static void mdns_print_SRV(const union mdns_data *);
......@@ -72,24 +71,40 @@ static const uint8_t *mdns_read_SRV(const uint8_t *ptr, uint32_t *n, const uint8
ptr = read_u16(ptr, n, &data->SRV.priority);
ptr = read_u16(ptr, n, &data->SRV.weight);
ptr = read_u16(ptr, n, &data->SRV.port);
if ((ptr = mdns_decode(ptr, n, root, &data->SRV.target)) == NULL) {
if ((ptr = mdns_decode(ptr, n, root, &data->SRV.target, 0)) == NULL) {
return (NULL);
}
return (ptr);
}
static uint32_t mdns_write_SRV(uint8_t *ptr, const struct mdns_entry *entry)
static int32_t mdns_write_SRV(uint8_t *ptr, uint16_t *left, const struct mdns_entry *entry)
{
uint8_t *target, *p = ptr;
if ((target = mdns_encode(entry->data.SRV.target)) == NULL) {
return (0);
return (MDNS_ERROR);
}
p = write_u16(p, entry->data.SRV.priority);
p = write_u16(p, entry->data.SRV.weight);
p = write_u16(p, entry->data.SRV.port);
p = write_raw(p, target);
p = write_u16(p, left, entry->data.SRV.priority);
if (p == NULL) {
hal_free(target);
return (MDNS_ERROR);
}
p = write_u16(p, left, entry->data.SRV.weight);
if (p == NULL) {
hal_free(target);
return (MDNS_ERROR);
}
p = write_u16(p, left, entry->data.SRV.port);
if (p == NULL) {
hal_free(target);
return (MDNS_ERROR);
}
p = write_raw(p, left, target);
if (p == NULL) {
hal_free(target);
return (MDNS_ERROR);
}
hal_free(target);
return (p - ptr);
}
......@@ -112,20 +127,24 @@ static const uint8_t *mdns_read_PTR(const uint8_t *ptr, uint32_t *n, const uint8
return (NULL);
}
if ((ptr = mdns_decode(ptr, n, root, &data->PTR.domain)) == NULL) {
if ((ptr = mdns_decode(ptr, n, root, &data->PTR.domain, 0)) == NULL) {
log_e("mdns_decode failed");
return (NULL);
}
return (ptr);
}
static uint32_t mdns_write_PTR(uint8_t *ptr, const struct mdns_entry *entry)
static int32_t mdns_write_PTR(uint8_t *ptr, uint16_t *left, const struct mdns_entry *entry)
{
uint8_t *domain, *p = ptr;
if ((domain = mdns_encode(entry->data.PTR.domain)) == NULL) {
return (0);
return (MDNS_ERROR);
}
p = write_raw(p, left, domain);
if (p == NULL) {
hal_free(domain);
return (MDNS_ERROR);
}
p = write_raw(p, domain);
hal_free(domain);
return (p - ptr);
}
......@@ -168,17 +187,21 @@ static const uint8_t *mdns_read_TXT(const uint8_t *ptr, uint32_t *n, const uint8
return (ptr);
}
static uint32_t mdns_write_TXT(uint8_t *ptr, const struct mdns_entry *entry)
static int32_t mdns_write_TXT(uint8_t *ptr, uint16_t *left, const struct mdns_entry *entry)
{
uint8_t *p = ptr;
uint8_t l;
uint16_t l;
struct mdns_data_txt *text = entry->data.TXT;
while (text) {
l = strlen(text->txt);
if (*left < l + 1) {
return MDNS_ERROR;
}
memcpy(p, &l, 1);
memcpy(p + 1, text->txt, l);
p += l + 1;
*left -= l + 1;
text = text->next;
}
return (p - ptr);
......@@ -213,9 +236,13 @@ static const uint8_t *mdns_read_AAAA(const uint8_t *ptr, uint32_t *n, const uint
return (ptr);
}
static uint32_t mdns_write_AAAA(uint8_t *ptr, const struct mdns_entry *entry)
static int32_t mdns_write_AAAA(uint8_t *ptr, uint16_t *left, const struct mdns_entry *entry)
{
uint32_t len = sizeof(entry->data.AAAA.addr);
if (*left < len) {
return MDNS_ERROR;
}
*left -= len;
memcpy(ptr, &entry->data.AAAA.addr, len);
return len;
}
......@@ -242,9 +269,13 @@ static const uint8_t *mdns_read_A(const uint8_t *ptr, uint32_t *n, const uint8_t
return (ptr);
}
static uint32_t mdns_write_A(uint8_t *ptr, const struct mdns_entry *entry)
static int32_t mdns_write_A(uint8_t *ptr, uint16_t *left, const struct mdns_entry *entry)
{
uint32_t len = sizeof(entry->data.A.addr);
if (*left < len) {
return MDNS_ERROR;
}
*left -= len;
memcpy(ptr, &entry->data.A.addr, sizeof(entry->data.A.addr));
return len;
}
......@@ -258,10 +289,16 @@ static void mdns_print_A(const union mdns_data *data)
* Decodes a DN compressed format (RFC 1035)
* e.g "\x03foo\x03bar\x00" gives "foo.bar"
*/
static const uint8_t *mdns_decode(const uint8_t *ptr, uint32_t *n, const uint8_t *root, char **ss)
static const uint8_t *mdns_decode(const uint8_t *ptr, uint32_t *n, const uint8_t *root, char **ss, uint8_t nb_times)
{
char *s;
char *s = NULL;
const uint8_t *orig_ptr = ptr;
if (*n == 0) {
return (NULL);
}
if (nb_times > 16) {
return (NULL);
}
s = *ss = hal_malloc(MDNS_DN_MAXSZ);
if (!s) {
return (NULL);
......@@ -284,7 +321,7 @@ static const uint8_t *mdns_decode(const uint8_t *ptr, uint32_t *n, const uint8_t
/* resolve the offset of the pointer (RFC 1035-4.1.4) */
if ((len & 0xC0) == 0xC0) {
const uint8_t *p;
char *buf;
char *buf = NULL;
uint32_t m;
if (*n < sizeof(len)) {
......@@ -296,8 +333,19 @@ static const uint8_t *mdns_decode(const uint8_t *ptr, uint32_t *n, const uint8_t
advance(1);
p = root + len;
if (p > (ptr - 2)) {
log_e("mdns_decode err, buf too short");
goto err;
}
m = ptr - p + *n;
mdns_decode(p, &m, root, &buf);
/* avoid recursing on the same element */
if (p == orig_ptr) {
goto err;
}
if (mdns_decode(p, &m, root, &buf, nb_times + 1) == NULL) {
log_e("mdns_decode err");
goto err;
}
if (free_space <= strlen(buf)) {
hal_free(buf);
log_e("free_space <= strlen(buf)");
......@@ -328,11 +376,13 @@ err:
* Encode a DN into its compressed format (RFC 1035)
* e.g "foo.bar" gives "\x03foo\x03bar\x00"
*/
static uint8_t *mdns_encode(char *s)
static uint8_t *mdns_encode(const char *s)
{
uint8_t *buf, *b, l = 0;
char *p = s;
if (s == NULL) {
return NULL;
}
buf = hal_malloc(strlen(s) + 2);
if (!buf) {
return (NULL);
......@@ -353,7 +403,7 @@ static const uint8_t *mdns_read_RR(const uint8_t *ptr, uint32_t *n, const uint8_
{
uint16_t tmp;
ptr = mdns_decode(ptr, n, root, &entry->name);
ptr = mdns_decode(ptr, n, root, &entry->name, 0);
if (!ptr || *n < 4) {
return (NULL);
}
......@@ -374,21 +424,40 @@ static const uint8_t *mdns_read_RR(const uint8_t *ptr, uint32_t *n, const uint8_
return ptr;
}
static uint32_t mdns_write_RR(uint8_t *ptr, const struct mdns_entry *entry, int8_t ans)
static int32_t mdns_write_RR(uint8_t *ptr, uint16_t *left, const struct mdns_entry *entry, int8_t ans)
{
uint8_t *name, *p = ptr;
if ((name = mdns_encode(entry->name)) == NULL) {
return (0);
return MDNS_ERROR;
}
p = write_raw(p, name);
p = write_u16(p, entry->type);
p = write_u16(p, (entry->class & ~0x8000) | (entry->msbit << 15));
p = write_raw(p, left, name);
if (p == NULL) {
hal_free(name);
return MDNS_ERROR;
}
p = write_u16(p, left, entry->type);
if (p == NULL) {
hal_free(name);
return MDNS_ERROR;
}
p = write_u16(p, left, (entry->class & ~0x8000) | (entry->msbit << 15));
if (p == NULL) {
hal_free(name);
return MDNS_ERROR;
}
if (ans) {
p = write_u32(p, entry->ttl);
p = write_u16(p, entry->data_len);
p = write_u32(p, left, entry->ttl);
if (p == NULL) {
hal_free(name);
return MDNS_ERROR;
}
p = write_u16(p, left, entry->data_len);
if (p == NULL) {
hal_free(name);
return MDNS_ERROR;
}
}
hal_free(name);
return (p - ptr);
......@@ -403,6 +472,9 @@ const uint8_t *mdns_read(const uint8_t *ptr, uint32_t *n, const uint8_t *root, s
if (ans == 0) {
return ptr;
}
if (ptr == NULL) {
return NULL;
}
for (uint32_t i = 0; i < mdns_num; ++i) {
if (rrs[i].type == entry->type) {
......@@ -423,13 +495,16 @@ const uint8_t *mdns_read(const uint8_t *ptr, uint32_t *n, const uint8_t *root, s
return (ptr);
}
uint32_t mdns_write(uint8_t *ptr, const struct mdns_entry *entry, int8_t ans)
int32_t mdns_write(uint8_t *ptr, uint16_t *left, const struct mdns_entry *entry, int8_t ans)
{
uint8_t *p = ptr;
uint32_t n = 0;
uint16_t l = 0;
int32_t l = 0;
l = mdns_write_RR(p, entry, ans);
l = mdns_write_RR(p, left, entry, ans);
if (l < 0) {
MDNS_ERROR;
}
n += l;
if (ans == 0) {
......@@ -438,9 +513,15 @@ uint32_t mdns_write(uint8_t *ptr, const struct mdns_entry *entry, int8_t ans)
for (uint32_t i = 0; i < mdns_num; ++i) {
if (rrs[i].type == entry->type) {
l = (*rrs[i].write)(p + n, entry);
l = (*rrs[i].write)(p + n, left, entry);
if (l < 0) {
MDNS_ERROR;
}
if (l == 0) {
continue;
}
// fill in data length after its computed
write_u16(p + n - 2, l);
write_u16(p + n - 2, NULL, l);
n += l;
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册