(* rust/src/boot/me/type.ml *)

(* An ltype is the type of a segment of an lvalue. Ltypes are used only by
 * [check_lval] and friends, and are the closest Rust ever comes to
 * polymorphic types. All ltypes must be resolved into monotypes by the time
 * an outer lvalue (i.e. an lvalue whose parent is not also an lvalue) has
 * finished typechecking. *)

type ltype =
  | LTYPE_mono of Ast.ty
      (* a plain monotype *)
  | LTYPE_poly of Ast.ty_param array * Ast.ty
      (* "big lambda": a type abstracted over type parameters *)
  | LTYPE_module of Ast.mod_items
      (* type of a module, represented by its items *)

type fn_ctx = {
  (* Type that values in 'ret', 'put' and 'be' statements are checked
   * against. *)
  fnctx_return_type: Ast.ty;
  (* Whether the function being checked is an iterator. *)
  fnctx_is_iter: bool;
  (* Reassigned for every statement checked: true iff that statement was a
   * ret, be, fail or yield. *)
  mutable fnctx_just_saw_ret: bool;
}

(* Raised on a type mismatch. Carries a textual description of what was
 * expected together with the type that was actually found. *)
exception Type_error of string * Ast.ty

(* Logging entry point for this pass: forwards to [Session.log] under the
 * "type" tag, using the session's type-logging flag and log output. *)
let log cx =
  let sess = cx.Semant.ctxt_sess in
  Session.log "type" sess.Session.sess_log_type sess.Session.sess_log_out

(* [type_error expected actual] never returns: it raises [Type_error] with
 * the description [expected] and the offending type [actual]. *)
let type_error expected actual = raise (Type_error (expected, actual))

(* We explicitly curry [cx] like this to avoid threading it through all the
 * inner functions. *)
let check_stmt (cx:Semant.ctxt) : (fn_ctx -> Ast.stmt -> unit) =
  (* Peels every outer [TY_constrained] and [TY_mutable] wrapper off [ty],
   * leaving the part of the type that matters for typechecking. *)
  let rec fundamental_ty (ty:Ast.ty) : Ast.ty =
    match ty with
        Ast.TY_mutable inner -> fundamental_ty inner
      | Ast.TY_constrained (inner, _) -> fundamental_ty inner
      | _ -> ty
  in

40 41 42 43
  (* "%a"-style printer for ltypes. Polytypes are rendered by their body
   * type alone; modules are rendered by their items. *)
  let sprintf_ltype _ (lty:ltype) : string =
    match lty with
        LTYPE_module items -> Ast.sprintf_mod_items () items
      | LTYPE_mono ty -> Ast.sprintf_ty () ty
      | LTYPE_poly (_, ty) -> Ast.sprintf_ty () ty
  in
45

46 47
  (* Returns the type recorded in [slot]; a slot without a type is reported
   * as a compiler bug. *)
  let get_slot_ty (slot:Ast.slot) : Ast.ty =
    match slot.Ast.slot_ty with
        None -> Common.bug () "get_slot_ty: no type in slot"
      | Some slot_ty -> slot_ty
  in
51

52 53 54 55 56 57
  (* [unbox ty] strips off all boxes in [ty] and returns a pair made of the
   * stripped type and the number of boxes that were removed. Mutability and
   * constraint wrappers met along the way are dropped without being
   * counted. *)
  let unbox (ty:Ast.ty) : (Ast.ty * int) =
    let rec strip n_boxes cur =
      match cur with
          Ast.TY_box inner -> strip (n_boxes + 1) inner
        | Ast.TY_mutable inner -> strip n_boxes inner
        | Ast.TY_constrained (inner, _) -> strip n_boxes inner
        | _ -> (cur, n_boxes)
    in
    strip 0 ty
  in

64
  (* Wraps [ty] in [TY_mutable] when [mutability] is [MUT_mutable] and
   * returns it unchanged otherwise. The conversion is logged. *)
  let maybe_mutable (mutability:Ast.mutability) (ty:Ast.ty) : Ast.ty =
    let wrapped =
      match mutability with
          Ast.MUT_mutable -> Ast.TY_mutable ty
        | _ -> ty
    in
    log cx "maybe_mutable: %a -> %a" Ast.sprintf_ty ty Ast.sprintf_ty wrapped;
    wrapped
  in

72 73 74 75 76 77 78 79 80 81 82 83
  (*
   * Type assertions
   *)

  (* True for [int], [uint] and the 8/16/32/64-bit signed and unsigned
   * machine types; false for every other type. *)
  let is_integer (ty:Ast.ty) =
    match ty with
        Ast.TY_int | Ast.TY_uint -> true
      | Ast.TY_mach mach ->
          begin
            match mach with
                Common.TY_u8 | Common.TY_u16
              | Common.TY_u32 | Common.TY_u64
              | Common.TY_i8 | Common.TY_i16
              | Common.TY_i32 | Common.TY_i64 -> true
              | _ -> false
          end
      | _ -> false
  in

86 87 88 89
  (* [demand expected actual] raises [Type_error] unless the two types are
   * structurally equal once reduced with [fundamental_ty]. The error carries
   * the reduced forms of both types. *)
  let demand (expected:Ast.ty) (actual:Ast.ty) : unit =
    let want = fundamental_ty expected in
    let got = fundamental_ty actual in
    if want <> got then
      type_error (Ast.sprintf_ty () want) got
  in
91 92 93
  (* Raises [Type_error] unless the fundamental form of [actual] is an
   * integer type; the error reports [actual] itself. *)
  let demand_integer (actual:Ast.ty) : unit =
    if is_integer (fundamental_ty actual) then
      ()
    else
      type_error "integer" actual
  in
95 96 97 98 99
  (* Raises [Type_error] unless the fundamental form of [actual] is [bool],
   * [char] or an integer type; the error reports [actual] itself. *)
  let demand_bool_or_char_or_integer (actual:Ast.ty) : unit =
    let ty = fundamental_ty actual in
    let acceptable =
      match ty with
          Ast.TY_bool | Ast.TY_char -> true
        | _ -> is_integer ty
    in
    if not acceptable then
      type_error "bool, char, or integer" actual
  in
101 102 103 104
  (* Raises [Type_error] unless the fundamental form of [actual] is [int],
   * [uint] or any machine type; the error reports the fundamental form. *)
  let demand_number (actual:Ast.ty) : unit =
    let ty = fundamental_ty actual in
    match ty with
        Ast.TY_mach _ | Ast.TY_uint | Ast.TY_int -> ()
      | _ -> type_error "number" ty
  in
106 107 108 109 110 111
  (* Raises [Type_error] unless the fundamental form of [actual] is [int],
   * [uint], a machine type, a string or a vector; the error reports the
   * fundamental form. *)
  let demand_number_or_str_or_vector (actual:Ast.ty) : unit =
    let ty = fundamental_ty actual in
    match ty with
        Ast.TY_vec _ | Ast.TY_str -> ()
      | Ast.TY_mach _ | Ast.TY_uint | Ast.TY_int -> ()
      | _ -> type_error "number, string, or vector" ty
  in
113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184
  (* Requires [actual] to be a vector type and returns its element type;
   * raises [Type_error] with the fundamental form otherwise. *)
  let demand_vec (actual:Ast.ty) : Ast.ty =
    let ty = fundamental_ty actual in
    match ty with
        Ast.TY_vec elem_ty -> elem_ty
      | _ -> type_error "vector" ty
  in
  (* Requires [actual] to be a vector whose element mutability agrees with
   * [mut] and returns the element type (still wrapped in [TY_mutable] in the
   * mutable case). Raises [Type_error] with the fundamental form of [actual]
   * otherwise. *)
  let demand_vec_with_mutability
      (mut:Ast.mutability)
      (actual:Ast.ty)
      : Ast.ty =
    let vec_ty = fundamental_ty actual in
    match mut with
        Ast.MUT_mutable ->
          begin
            match vec_ty with
                Ast.TY_vec ((Ast.TY_mutable _) as elem_ty) -> elem_ty
              | _ -> type_error "mutable vector" vec_ty
          end
      | Ast.MUT_immutable ->
          begin
            match vec_ty with
                Ast.TY_vec (Ast.TY_mutable _) ->
                  type_error "immutable vector" vec_ty
              | Ast.TY_vec elem_ty -> elem_ty
              | _ -> type_error "immutable vector" vec_ty
          end
  in
  (* Requires [actual] to be a vector or a string and returns the element
   * type: the vector's element type, or [u8] for a string. *)
  let demand_vec_or_str (actual:Ast.ty) : Ast.ty =
    let ty = fundamental_ty actual in
    match ty with
        Ast.TY_str -> Ast.TY_mach Common.TY_u8
      | Ast.TY_vec elem_ty -> elem_ty
      | _ -> type_error "vector or string" ty
  in
  (* Requires [actual] to be a record type and returns its field array;
   * raises [Type_error] with the fundamental form otherwise. *)
  let demand_rec (actual:Ast.ty) : Ast.ty_rec =
    let ty = fundamental_ty actual in
    match ty with
        Ast.TY_rec fields -> fields
      | _ -> type_error "record" ty
  in
  (* [demand_fn arg_tys actual] requires [actual] to be a function type that
   * takes exactly [Array.length arg_tys] arguments, and returns the type of
   * its output slot. Every [Some ty] entry of [arg_tys] must agree with the
   * type of the corresponding input slot; [None] entries are not checked.
   *
   * Raises [Type_error] when [actual] is not a function, when the arity
   * differs, or (through [demand]) when an argument type disagrees. A slot
   * without a type is reported as a bug by [get_slot_ty]. *)
  let demand_fn (arg_tys:Ast.ty option array) (actual:Ast.ty) : Ast.ty =
    (* Rendering of the expected signature, e.g. "fn(int, <?>)". It is lazy
     * because it is only needed on the arity-mismatch path. *)
    let expected = lazy begin
      Format.fprintf Format.str_formatter "fn(";
      let print_arg_ty i arg_ty_opt =
        if i > 0 then Format.fprintf Format.str_formatter ", ";
        match arg_ty_opt with
            None -> Format.fprintf Format.str_formatter "<?>"
          | Some arg_ty -> Ast.fmt_ty Format.str_formatter arg_ty
      in
      Array.iteri print_arg_ty arg_tys;
      Format.fprintf Format.str_formatter ")";
      Format.flush_str_formatter()
    end in
    match fundamental_ty actual with
        Ast.TY_fn (ty_sig, _) as ty ->
          let in_slots = ty_sig.Ast.sig_input_slots in
          (* Compare the lengths structurally; [!=] is physical inequality. *)
          if Array.length arg_tys <> Array.length in_slots then
            type_error (Lazy.force expected) ty;
          let in_slot_tys = Array.map get_slot_ty in_slots in
          let maybe_demand a_opt b =
            match a_opt with None -> () | Some a -> demand a b
          in
          Common.arr_iter2 maybe_demand arg_tys in_slot_tys;
          get_slot_ty (ty_sig.Ast.sig_output_slot)
      | ty -> type_error "function" ty
  in
  (* Requires [actual] to be a channel type and returns the type it carries;
   * raises [Type_error] with the fundamental form otherwise. *)
  let demand_chan (actual:Ast.ty) : Ast.ty =
    let ty = fundamental_ty actual in
    match ty with
        Ast.TY_chan carried -> carried
      | _ -> type_error "channel" ty
  in
  (* Requires [actual] to be a port type and returns the type it carries;
   * raises [Type_error] with the fundamental form otherwise. *)
  let demand_port (actual:Ast.ty) : Ast.ty =
    let ty = fundamental_ty actual in
    match ty with
        Ast.TY_port carried -> carried
      | _ -> type_error "port" ty
  in
  (* [demand_all tys] requires every type in [tys] to agree with the first
   * one and returns the fundamental form of that first type. An empty array
   * is reported as a compiler bug; a disagreement raises [Type_error]
   * through [demand]. *)
  let demand_all (tys:Ast.ty array) : Ast.ty =
    (* Integer test: structural [=] rather than physical [==]. *)
    if Array.length tys = 0 then
      Common.bug () "demand_all called with an empty array of types";
    let pivot = fundamental_ty tys.(0) in
    for i = 1 to Array.length tys - 1 do
      demand pivot tys.(i)
    done;
    pivot
  in

187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204
  (* Performs beta-reduction, i.e. substitutes the type arguments [args] for
   * the type parameters of the polytype [lty]. The arguments are recorded
   * against [lval_id] in the context; if arguments were already recorded
   * for that lval they are asserted to be equal to [args]. Any ltype other
   * than a polytype is a user error. *)
  let beta_reduce
      (lval_id:Common.node_id)
      (lty:ltype)
      (args:Ast.ty array)
      : ltype =
    let recorded = cx.Semant.ctxt_call_lval_params in
    if not (Hashtbl.mem recorded lval_id) then
      Hashtbl.add recorded lval_id args
    else
      assert (args = Hashtbl.find recorded lval_id);

    match lty with
        LTYPE_mono _ | LTYPE_module _ ->
          Common.err None "expected polymorphic type but found %a"
            sprintf_ltype lty
      | LTYPE_poly (params, body_ty) ->
          LTYPE_mono (Semant.rebuild_ty_under_params body_ty params args true)
  in
G
Graydon Hoare 已提交
205

206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221
  (*
   * Lvalue and slot checking.
   *
   * We use a slightly different type language here which includes polytypes;
   * see [ltype] above.
   *)

  (* Maps a literal to its type. Machine literals carry their own machine
   * type; every other literal has a fixed type. *)
  let check_literal (lit:Ast.lit) : Ast.ty =
    match lit with
        Ast.LIT_mach (mach_ty, _, _) -> Ast.TY_mach mach_ty
      | Ast.LIT_int _ -> Ast.TY_int
      | Ast.LIT_uint _ -> Ast.TY_uint
      | Ast.LIT_char _ -> Ast.TY_char
      | Ast.LIT_bool _ -> Ast.TY_bool
      | Ast.LIT_nil -> Ast.TY_nil
  in
G
Graydon Hoare 已提交
222

223 224 225 226 227 228 229 230 231 232 233 234 235
  (* This is where inference actually happens. [internal_check_slot infer
   * defn_id] returns the type of the slot defined by [defn_id]:
   *   - a typed slot keeps its type, which must agree with [infer] when an
   *     expectation is supplied;
   *   - an untyped slot takes the expected type, and the updated slot is
   *     stored back into the context's definition table;
   *   - an untyped slot with no expectation is a user error.
   * A definition that is not a slot is reported as a compiler bug. *)
  let internal_check_slot
      (infer:Ast.ty option)
      (defn_id:Common.node_id)
      : Ast.ty =
    let defns = cx.Semant.ctxt_all_defns in
    let slot =
      match Hashtbl.find defns defn_id with
          Semant.DEFN_slot slot -> slot
        | _ ->
          Common.bug
            ()
            "internal_check_slot: supplied defn wasn't a slot at all"
    in
    match slot.Ast.slot_ty, infer with
        Some actual, None -> actual
      | Some actual, Some expected ->
          demand expected actual;
          actual
      | None, Some inferred ->
          log cx "setting auto slot #%d = %a to type %a"
            (Common.int_of_node defn_id)
            Ast.sprintf_slot_key
              (Hashtbl.find cx.Semant.ctxt_slot_keys defn_id)
            Ast.sprintf_ty inferred;
          let typed_slot = { slot with Ast.slot_ty = Some inferred } in
          Hashtbl.replace defns defn_id (Semant.DEFN_slot typed_slot);
          inferred
      | None, None ->
          Common.err None "can't infer any type for this slot"
  in
G
Graydon Hoare 已提交
254

255 256 257 258 259 260 261 262 263 264 265 266 267 268
  (* Returns the ltype of the module item [mid], whose node id is [mid_id]:
   *   - a module yields [LTYPE_module] of its items;
   *   - a fn, obj or tag yields its recorded item type, as a monotype when
   *     the item declares no type parameters and as a polytype otherwise;
   *   - a type item is a user error in this (non-type) context.
   * The item type is looked up with [Hashtbl.find], so [Not_found] escapes
   * if no type was recorded for [mid_id]. *)
  let internal_check_mod_item_decl
      (mid:Ast.mod_item_decl)
      (mid_id:Common.node_id)
      : ltype =
    match mid.Ast.decl_item with
        Ast.MOD_ITEM_mod (_, items) -> LTYPE_module items
      | Ast.MOD_ITEM_fn _ | Ast.MOD_ITEM_obj _ | Ast.MOD_ITEM_tag _ ->
          let ty = Hashtbl.find cx.Semant.ctxt_all_item_types mid_id in
          let params = mid.Ast.decl_params in
          (* Integer test: structural [=] rather than physical [==]. *)
          if Array.length params = 0 then
            LTYPE_mono ty
          else
            LTYPE_poly ((Array.map (fun p -> p.Common.node) params), ty)
      | Ast.MOD_ITEM_type _ ->
          Common.err None "Type-item used in non-type context"
  in
G
Graydon Hoare 已提交
271

272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300
  (* Checks the base of an lvalue. The referent of the base is looked up in
   * the definition table: a slot is checked (and possibly inferred) against
   * [infer], an item yields its declared ltype, and anything else is a
   * compiler bug. A base of the form [BASE_app] additionally applies its
   * type arguments to the resulting ltype. *)
  let rec internal_check_base_lval
      (infer:Ast.ty option)
      (nbi:Ast.name_base Common.identified)
      : ltype =
    let lval_id = nbi.Common.id in
    let referent = Semant.lval_to_referent cx lval_id in
    let referent_lty =
      match Hashtbl.find cx.Semant.ctxt_all_defns referent with
          Semant.DEFN_item mid -> internal_check_mod_item_decl mid referent
        | Semant.DEFN_slot _ ->
            LTYPE_mono (internal_check_slot infer referent)
        | _ -> Common.bug () "internal_check_base_lval: unexpected defn type"
    in
    match nbi.Common.node with
        Ast.BASE_app (_, args) -> beta_reduce lval_id referent_lty args
      | Ast.BASE_ident _ | Ast.BASE_temp _ -> referent_lty

  (* Checks an extended lvalue, i.e. [base] indexed by the component [comp],
   * and returns the ltype of the selected part.
   *
   * The base is checked first, with no expected type. A polymorphic base is
   * a user error; otherwise the base is either a (fundamental) type or a
   * module, and [comp] is interpreted accordingly:
   *   - records and objects are indexed by identifier;
   *   - tuples are indexed by tuple index;
   *   - vectors and strings are indexed by an [int] atom;
   *   - boxes accept an explicit dereference and are otherwise looked
   *     through automatically;
   *   - modules are indexed by item name, optionally applied to type
   *     arguments.
   * Every other combination is reported through [Common.err]. *)
  and internal_check_ext_lval
      (base:Ast.lval)
      (comp:Ast.lval_component)
      : ltype =
    let base_ity =
      match internal_check_lval None base with
          LTYPE_poly (_, ty) ->
            Common.err None "can't index the polymorphic type '%a'"
              Ast.sprintf_ty ty
        | LTYPE_mono ty -> `Type (fundamental_ty ty)
        | LTYPE_module items -> `Module items
    in

    (* "%a"-style printer for the base's type. It always prints the type of
     * the original base, even while [typecheck] below is looking through a
     * box. *)
    let sprintf_itype chan () =
      match base_ity with
          `Type ty -> Ast.sprintf_ty chan ty
        | `Module items -> Ast.sprintf_mod_items chan items
    in

    let _ = log cx "base lval %a, base type %a"
      Ast.sprintf_lval base sprintf_itype ()
    in

    let rec typecheck base_ity =
      match base_ity, comp with
          `Type (Ast.TY_rec ty_rec), Ast.COMP_named (Ast.COMP_ident id) ->
            let find _ (k, v) = if id = k then Some v else None in
            let comp_ty =
              match Common.arr_search ty_rec find with
                  Some elem_ty -> elem_ty
                | None ->
                    Common.err
                      None
                      "field '%s' is not one of the fields of '%a'"
                      id
                      sprintf_itype ()
            in
            LTYPE_mono comp_ty

        | `Type (Ast.TY_rec _), _ ->
            Common.err None "the record type '%a' must be indexed by name"
              sprintf_itype ()

        | `Type (Ast.TY_obj ty_obj), Ast.COMP_named (Ast.COMP_ident id) ->
            let comp_ty =
              try
                Ast.TY_fn (Hashtbl.find (snd ty_obj) id)
              with Not_found ->
                Common.err
                  None
                  "method '%s' is not one of the methods of '%a'"
                  id
                  sprintf_itype ()
            in
            LTYPE_mono comp_ty

        | `Type (Ast.TY_obj _), _ ->
            Common.err
              None
              "the object type '%a' must be indexed by name"
              sprintf_itype ()

        | `Type (Ast.TY_tup ty_tup), Ast.COMP_named (Ast.COMP_idx idx)
              when idx < Array.length ty_tup ->
            LTYPE_mono (ty_tup.(idx))

        | `Type (Ast.TY_tup _), Ast.COMP_named (Ast.COMP_idx idx) ->
            Common.err
              None
              "member '_%d' is not one of the members of '%a'"
              idx
              sprintf_itype ()

        | `Type (Ast.TY_tup _), _ ->
            Common.err
              None
              "the tuple type '%a' must be indexed by tuple index"
              sprintf_itype ()

        | `Type (Ast.TY_vec ty_vec), Ast.COMP_atom atom ->
            demand Ast.TY_int (check_atom atom);
            LTYPE_mono ty_vec

        | `Type (Ast.TY_vec _), _ ->
            Common.err None "the vector type '%a' must be indexed via an int"
              sprintf_itype ()

        | `Type Ast.TY_str, Ast.COMP_atom atom ->
            demand Ast.TY_int (check_atom atom);
            LTYPE_mono (Ast.TY_mach Common.TY_u8)

        | `Type Ast.TY_str, _ ->
            Common.err None "strings must be indexed via an int"

        | `Type (Ast.TY_box ty_box), Ast.COMP_deref -> LTYPE_mono ty_box

        | `Type (Ast.TY_box ty_box), _ ->
            (* Automatically dereference! The box's content is normalized
             * the same way as the initial base type, so that a box holding
             * a mutable or constrained record, tuple, vector, ... can still
             * be indexed. *)
            typecheck (`Type (fundamental_ty ty_box))

        | `Type ty, Ast.COMP_named (Ast.COMP_ident _) ->
            Common.err None "the type '%a' can't be indexed by name"
              Ast.sprintf_ty ty

        | `Type ty, Ast.COMP_named (Ast.COMP_app _) ->
            Common.err
              None
              "the type '%a' has no type parameters, so it can't be applied \
              to types"
              Ast.sprintf_ty ty

        | `Module items, Ast.COMP_named ((Ast.COMP_ident id) as name_comp)
        | `Module items, Ast.COMP_named ((Ast.COMP_app (id, _))
                                           as name_comp) ->
            let mod_item =
              try
                Hashtbl.find items id
              with Not_found ->
                Common.bug
                  ()
                  "internal_check_ext_lval: ident %s not found in mod item"
                  id
            in
            let lty =
              internal_check_mod_item_decl
                mod_item.Common.node
                mod_item.Common.id
            in
            begin
              match name_comp with
                  Ast.COMP_ident _ -> lty
                | Ast.COMP_app (_, args) ->
                    beta_reduce (Semant.lval_base_id base) lty args
                | _ ->
                  Common.bug ()
                    "internal_check_ext_lval: unexpected name_comp"
            end

        | _, Ast.COMP_named (Ast.COMP_idx _) ->
            Common.err
              None
              "%a isn't a tuple, so it can't be indexed by tuple index"
              sprintf_itype ()

        | _, Ast.COMP_atom atom ->
            Common.err
              None
              "%a can't be indexed by the type '%a'"
              sprintf_itype ()
              Ast.sprintf_ty (check_atom atom)

        | _, Ast.COMP_deref ->
            Common.err
              None
              "%a isn't a box and can't be dereferenced"
              sprintf_itype ()
    in
    typecheck base_ity

  (* Dispatches on the shape of [lval]: a base lvalue is checked against
   * [infer]; an extended lvalue is checked component-wise, and [infer] is
   * not used for it. *)
  and internal_check_lval (infer:Ast.ty option) (lval:Ast.lval) : ltype =
    match lval with
        Ast.LVAL_ext (base, comp) -> internal_check_ext_lval base comp
      | Ast.LVAL_base nbi -> internal_check_base_lval infer nbi

  (* Checks the outermost lvalue and returns the resulting monotype together
   * with the number of layers of indirection needed to reach it, i.e. the
   * number of boxes that were dereferenced automatically (always zero when
   * [deref] is false). The returned type is wrapped according to [mut].
   * This is the bridge between the polymorphically typed world of lvalues
   * and the monomorphically typed world of values: polytypes and modules
   * are rejected here. *)
  and internal_check_outer_lval
      ~mut:(mut:Ast.mutability)
      ~deref:(deref:bool)
      (infer:Ast.ty option)
      (lval:Ast.lval)
      : (Ast.ty * int) =
    let yield_ty ty =
      let (stripped, n_boxes) = if deref then unbox ty else (ty, 0) in
        (maybe_mutable mut stripped, n_boxes)
    in
    let lty = internal_check_lval infer lval in
    match lty with
        LTYPE_mono actual ->
          begin
            match infer with
                Some expected -> demand expected actual
              | None -> ()
          end;
          yield_ty actual
      | LTYPE_poly _ ->
          begin
            match infer with
                None ->
                  Common.err
                    None
                    "not enough context to automatically instantiate the \
                      polymorphic type '%a'; supply type parameters explicitly"
                    sprintf_ltype lty
              | Some _ ->
                  (* FIXME: auto-instantiate *)
                  Common.unimpl
                    None
                    "sorry, automatic polymorphic instantiation isn't \
                      supported yet; please supply type parameters explicitly"
          end
      | LTYPE_module _ ->
          Common.err None "can't refer to a module as a first-class value"

  (* Checks a complete lvalue and returns its type as a value, i.e. with
   * outer mutability and constraints stripped off.
   *
   * [mut] decides whether the recorded type is wrapped as mutable, [deref]
   * enables automatic dereferencing of boxes, and [infer] is the expected
   * type, if any. As side effects the function records, keyed by the id of
   * the lvalue's base, the lvalue's type and whether it was auto-
   * dereferenced. When an entry already exists, the type is asserted to be
   * unchanged and a change of auto-dereference status is reported as a
   * bug. More than one level of automatic dereference is unimplemented. *)
  and generic_check_lval
      ~mut:(mut:Ast.mutability)
      ~deref:(deref:bool)
      (infer:Ast.ty option)
      (lval:Ast.lval)
      : Ast.ty =
    (* The lval we got is an impostor (it may contain unresolved TY_nameds).
     * Get the real one. *)
    let lval_id = Semant.lval_base_id lval in
    let lval = Hashtbl.find cx.Semant.ctxt_all_lvals lval_id in
    let _ = log cx "generic_check_lval %a mut=%s deref=%s infer=%s"
      Ast.sprintf_lval lval
      (if mut = Ast.MUT_mutable then "mutable" else "immutable")
      (if deref then "true" else "false")
      (match infer with
           None -> "<none>"
         | Some t -> Fmt.fmt_to_str Ast.fmt_ty t)
    in
    let (lval_ty, n_boxes) =
      internal_check_outer_lval ~mut:mut ~deref:deref infer lval
    in
    let _ = log cx "checked lval %a with type %a"
      Ast.sprintf_lval lval
      Ast.sprintf_ty lval_ty
    in

    if Hashtbl.mem cx.Semant.ctxt_all_lval_types lval_id then
      assert ((Hashtbl.find cx.Semant.ctxt_all_lval_types lval_id) = lval_ty)
    else
      Hashtbl.replace cx.Semant.ctxt_all_lval_types lval_id lval_ty;

    if Hashtbl.mem cx.Semant.ctxt_auto_deref_lval lval_id then begin
      let prev_autoderef =
        Hashtbl.find cx.Semant.ctxt_auto_deref_lval lval_id
      in
      (* Integer test: structural [=] rather than physical [==]. *)
      if n_boxes = 0 && prev_autoderef then
        Common.bug () "generic_check_lval: lval was previously marked as \
          deref but isn't any longer";
      if n_boxes > 0 && not prev_autoderef then
        Common.bug () "generic_check_lval: lval wasn't marked as autoderef \
          before, but now it is"
    end;
    if n_boxes > 1 then
      (* TODO: allow more than one level of automatic dereference *)
      Common.unimpl (Some (Semant.lval_base_id lval))
        "sorry, only one level of automatic dereference is \
        implemented; please add explicit dereference operators";
    Hashtbl.replace cx.Semant.ctxt_auto_deref_lval lval_id (n_boxes > 0);

    (* Before demoting the lval to a value, strip off mutability. *)
    fundamental_ty lval_ty
539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554

  (* Checks [lval] with no expected type and returns its type. [mut]
   * defaults to immutable and [deref] to false.
   *
   * Avoid this function when possible: it cannot perform type inference.
   * In general [infer_lval] is the better choice. *)
  and check_lval
      ?mut:(mut=Ast.MUT_immutable)
      ?deref:(deref=false)
      (lval:Ast.lval)
      : Ast.ty =
    let no_expectation = None in
    generic_check_lval ~mut ~deref no_expectation lval

  (* Returns the type of [atom]: a literal has its literal type, an lvalue
   * is checked as an immutable lvalue, auto-dereferenced iff [deref] is
   * set (it defaults to false). *)
  and check_atom ?deref:(deref=false) (atom:Ast.atom) : Ast.ty =
    match atom with
        Ast.ATOM_literal lit_id -> check_literal lit_id.Common.node
      | Ast.ATOM_lval lval -> check_lval ~deref lval
  in
G
Graydon Hoare 已提交
555

556 557 558
  (* Checks the slot [slot_id] against the expected type [ty], inferring the
   * slot's type from [ty] if it has none yet. The resulting type is
   * dropped. *)
  let infer_slot (ty:Ast.ty) (slot_id:Common.node_id) : unit =
    let expected = Some ty in
    ignore (internal_check_slot expected slot_id)
  in
G
Graydon Hoare 已提交
559

560
  (* [infer_lval ?mut ty lval] checks [lval] against the expected type [ty]
   * (passed on wrapped in [TY_mutable]), without automatic dereferencing.
   * [mut] defaults to immutable. The resulting type is dropped; the call is
   * made for its checking and recording effects only. *)
  let infer_lval
      ?mut:(mut=Ast.MUT_immutable)
      (ty:Ast.ty)
      (lval:Ast.lval)
      : unit =
    (* [generic_check_lval] takes a mandatory [~mut] argument and [mut] has
     * already been defaulted above, so it is passed as a plain labelled
     * argument; the optional-argument form [?mut:] does not match the
     * callee's label kind. *)
    ignore (generic_check_lval ~mut:mut ~deref:false
              (Some (Ast.TY_mutable ty)) lval)
  in
G
Graydon Hoare 已提交
568

569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590
  (*
   * AST fragment checking
   *)

  (* Returns the result type of applying [binop] to operands of type
   * [operand_ty]. Comparisons yield [bool] without constraining the
   * operands; all other operators constrain the operand type and yield it
   * unchanged. [BINOP_send] must never appear in an expression. *)
  let check_binop (binop:Ast.binop) (operand_ty:Ast.ty) =
    (* Applies the given demand to the operand type, then yields that
     * type. *)
    let same_as_operand requirement =
      requirement operand_ty;
      operand_ty
    in
    match binop with
        Ast.BINOP_eq | Ast.BINOP_ne | Ast.BINOP_lt | Ast.BINOP_le
      | Ast.BINOP_ge | Ast.BINOP_gt ->
          Ast.TY_bool
      | Ast.BINOP_add ->
          same_as_operand demand_number_or_str_or_vector
      | Ast.BINOP_sub | Ast.BINOP_mul | Ast.BINOP_div | Ast.BINOP_mod ->
          same_as_operand demand_number
      | Ast.BINOP_or | Ast.BINOP_and | Ast.BINOP_xor | Ast.BINOP_lsl
      | Ast.BINOP_lsr | Ast.BINOP_asr ->
          same_as_operand demand_integer
      | Ast.BINOP_send ->
          Common.bug () "check_binop: BINOP_send found in expr"
  in
G
Graydon Hoare 已提交
591

592 593 594 595 596 597
  (* Typechecks [expr] and returns its type.
   *   - An atom has the type of the atom.
   *   - A binary expression needs both (auto-dereferenced) operands to have
   *     the same type; the result type is decided by [check_binop].
   *   - Logical negation needs a [bool] operand and yields [bool].
   *   - Bitwise negation and arithmetic negation need an integer operand
   *     and yield the operand's type.
   *   - A cast needs both the operand and the destination type to be bool,
   *     char or integer, and yields the destination type. *)
  let check_expr (expr:Ast.expr) : Ast.ty =
    match expr with
        Ast.EXPR_atom atom -> check_atom atom
      | Ast.EXPR_binary (binop, lhs, rhs) ->
          let lhs_ty = check_atom ~deref:true lhs in
          let rhs_ty = check_atom ~deref:true rhs in
          demand lhs_ty rhs_ty;
          check_binop binop lhs_ty
      | Ast.EXPR_unary (unop, atom) ->
          begin
            match unop with
                Ast.UNOP_not ->
                  demand Ast.TY_bool (check_atom ~deref:true atom);
                  Ast.TY_bool
              | Ast.UNOP_bitnot | Ast.UNOP_neg ->
                  let operand_ty = check_atom atom in
                  demand_integer operand_ty;
                  operand_ty
              | Ast.UNOP_cast dst_ty_id ->
                  (* TODO: probably we want to handle more cases here *)
                  demand_bool_or_char_or_integer (check_atom atom);
                  let dst_ty = dst_ty_id.Common.node in
                  demand_bool_or_char_or_integer dst_ty;
                  dst_ty
          end
  in
614

615 616 617 618 619
  (* Checks a call of [callee] on the atoms [args] and returns the callee's
   * return type. The arguments are checked first, then the callee; every
   * argument type must agree with the matching parameter. *)
  let check_fn (callee:Ast.lval) (args:Ast.atom array) : Ast.ty =
    let expected_args = Array.map (fun arg -> Some (check_atom arg)) args in
    let callee_ty = check_lval callee in
    demand_fn expected_args callee_ty
  in
622

623 624 625 626 627 628 629 630 631 632 633 634 635 636 637
  (* Checks the pattern [pat] against the type [expected].
   *   - A literal pattern must have the expected type.
   *   - A tag pattern names a constructor, which must be a function; each
   *     sub-pattern is checked against the matching constructor argument.
   *   - A slot pattern has the expected type inferred for its slot.
   *   - The wildcard matches anything. *)
  let rec check_pat (expected:Ast.ty) (pat:Ast.pat) : unit =
    match pat with
        Ast.PAT_wild -> ()
      | Ast.PAT_slot (slot, _) -> infer_slot expected slot.Common.id
      | Ast.PAT_lit lit -> demand expected (check_literal lit)
      | Ast.PAT_tag (constr_fn, arg_pats) ->
          let constr_ty = check_lval constr_fn in
          begin
            match constr_ty with
                Ast.TY_fn (ty_sig, _) ->
                  let arg_tys =
                    Array.map get_slot_ty ty_sig.Ast.sig_input_slots
                  in
                  Common.arr_iter2 check_pat arg_tys arg_pats
              | _ -> type_error "constructor function" constr_ty
          end
  in
638

639 640 641 642 643 644
  (* Checks every predicate call in [calls]: each one must typecheck as a
   * function call that returns [bool]. *)
  let check_check_calls (calls:Ast.check_calls) : unit =
    Array.iter
      (fun (callee, args) -> demand Ast.TY_bool (check_fn callee args))
      calls
  in
645

646 647 648
  (*
   * Statement checking
   *)
649

650 651 652
  (* Again as above, we explicitly curry [fn_ctx] to avoid threading it
   * through these functions. *)
  let check_stmt (fn_ctx:fn_ctx) : (Ast.stmt -> unit) =
653 654 655 656
    (* Records in [fn_ctx] whether [stmt] is a ret, be, fail or yield. The
     * flag is overwritten on every call. *)
    let check_ret (stmt:Ast.stmt) : unit =
      let leaves_function =
        match stmt.Common.node with
            Ast.STMT_ret _ | Ast.STMT_be _ -> true
          | Ast.STMT_fail | Ast.STMT_yield -> true
          | _ -> false
      in
      fn_ctx.fnctx_just_saw_ret <- leaves_function
    in

661 662
    (* Checks the statements of [block] in order. *)
    let rec check_block (block:Ast.block) : unit =
      Array.iter check_stmt block.Common.node

    (* Typechecks a single statement. The return-tracking flag of [fn_ctx]
     * is updated first; the statement is then checked according to its
     * kind, with nested blocks checked recursively. Mismatches surface as
     * [Type_error]. *)
    and check_stmt (stmt:Ast.stmt) : unit =
      check_ret stmt;
      match stmt.Common.node with
          Ast.STMT_spawn (dst, _, callee, args) ->
            infer_lval Ast.TY_task dst;
            demand Ast.TY_nil (check_fn callee args)

        | Ast.STMT_new_rec (dst, fields, Some base) ->
            (* Record update: every listed field must exist in the base
             * record and agree with its type; the result has the base's
             * type. *)
            let base_ty = check_lval base in
            let base_fields = demand_rec base_ty in
            let field_tys = Hashtbl.create (Array.length base_fields) in
            Array.iter
              (fun (id, field_ty) -> Hashtbl.add field_tys id field_ty)
              base_fields;
            let check_field (name, mut, atom) =
              let declared_ty =
                try
                  Hashtbl.find field_tys name
                with Not_found ->
                  Common.err None
                    "field '%s' is not one of the base record's fields" name
              in
              demand declared_ty (maybe_mutable mut (check_atom atom))
            in
            Array.iter check_field fields;
            infer_lval base_ty dst

        | Ast.STMT_new_rec (dst, fields, None) ->
            let field_ty (name, mut, atom) =
              (name, maybe_mutable mut (check_atom atom))
            in
            infer_lval (Ast.TY_rec (Array.map field_ty fields)) dst

        | Ast.STMT_new_tup (dst, members) ->
            let member_ty (mut, atom) =
              maybe_mutable mut (check_atom atom)
            in
            infer_lval (Ast.TY_tup (Array.map member_ty members)) dst

        | Ast.STMT_new_vec (dst, mut, [| |]) ->
            (* no inference allowed here *)
            let dst_ty = check_lval ~mut:Ast.MUT_mutable dst in
            ignore (demand_vec_with_mutability mut dst_ty)

        | Ast.STMT_new_vec (dst, mut, elems) ->
            let elem_ty = demand_all (Array.map check_atom elems) in
            infer_lval (Ast.TY_vec (maybe_mutable mut elem_ty)) dst

        | Ast.STMT_new_str (dst, _) -> infer_lval Ast.TY_str dst

        | Ast.STMT_new_port dst ->
            (* we can't actually typecheck this *)
            ignore (check_lval dst)

        | Ast.STMT_new_chan (dst, Some port) ->
            let carried_ty = demand_port (check_lval port) in
            infer_lval (Ast.TY_chan carried_ty) dst

        | Ast.STMT_new_chan (dst, None) ->
            (* can't check this either *)
            ignore (check_lval dst)

        | Ast.STMT_new_box (dst, mut, src) ->
            let content_ty = maybe_mutable mut (check_atom src) in
            infer_lval (Ast.TY_box content_ty) dst

        | Ast.STMT_copy (dst, src) ->
            infer_lval (check_expr src) dst

        | Ast.STMT_copy_binop (dst, binop, src) ->
            let src_ty = check_atom ~deref:true src in
            infer_lval src_ty dst;
            demand src_ty (check_binop binop src_ty)

        | Ast.STMT_call (dst, callee, args) ->
            infer_lval (check_fn callee args) dst

        | Ast.STMT_bind (bound, callee, args) ->
            let check_arg arg_opt =
              match arg_opt with
                  Some arg -> Some (check_atom arg)
                | None -> None
            in
            (* Rebuilds the callee's type keeping only the input slots whose
             * argument was left unbound. *)
            let rec replace_args ty =
              match ty with
                  Ast.TY_fn (ty_sig, ty_fn_aux) ->
                    let orig_slots = ty_sig.Ast.sig_input_slots in
                    let unbound_slot i =
                      match args.(i) with
                          Some _ -> None
                        | None -> Some orig_slots.(i)
                    in
                    let remaining =
                      Common.arr_filter_some
                        (Array.init (Array.length args) unbound_slot)
                    in
                    let bound_sig =
                      { ty_sig with Ast.sig_input_slots = remaining }
                    in
                    Ast.TY_fn (bound_sig, ty_fn_aux)
                | Ast.TY_mutable ty' -> Ast.TY_mutable (replace_args ty')
                | Ast.TY_constrained (ty', constrs) ->
                    Ast.TY_constrained (replace_args ty', constrs)
                | _ -> Common.bug () "replace_args: unexpected type"
            in
            let callee_ty = check_lval callee in
            ignore (demand_fn (Array.map check_arg args) callee_ty);
            infer_lval (replace_args callee_ty) bound

        | Ast.STMT_recv (dst, src) ->
            infer_lval (demand_port (check_lval src)) dst

        | Ast.STMT_slice (dst, src, slice) ->
            let src_ty = check_lval src in
            ignore (demand_vec src_ty);
            infer_lval src_ty dst;
            let check_index index = demand Ast.TY_int (check_atom index) in
            Common.may check_index slice.Ast.slice_start;
            Common.may check_index slice.Ast.slice_len

        | Ast.STMT_while w | Ast.STMT_do_while w ->
            let (head_stmts, cond) = w.Ast.while_lval in
            Array.iter check_stmt head_stmts;
            demand Ast.TY_bool (check_expr cond);
            check_block w.Ast.while_body

        | Ast.STMT_for sf ->
            let elem_ty = demand_vec_or_str (check_lval sf.Ast.for_seq) in
            infer_slot elem_ty (fst sf.Ast.for_slot).Common.id;
            check_block sf.Ast.for_body

        | Ast.STMT_for_each sfe ->
            let (callee, args) = sfe.Ast.for_each_call in
            let elem_ty = check_fn callee args in
            infer_slot elem_ty (fst (sfe.Ast.for_each_slot)).Common.id;
            check_block sfe.Ast.for_each_head;
            check_block sfe.Ast.for_each_body

        | Ast.STMT_if si ->
            demand Ast.TY_bool (check_expr si.Ast.if_test);
            check_block si.Ast.if_then;
            Common.may check_block si.Ast.if_else

        | Ast.STMT_put atom_opt ->
            if not fn_ctx.fnctx_is_iter then
              Common.err None "'put' may only be used in an iterator function";
            begin
              match atom_opt with
                  Some atom ->
                    demand fn_ctx.fnctx_return_type (check_atom atom)
                | None -> () (* always well-typed *)
            end

        | Ast.STMT_put_each (callee, args) -> ignore (check_fn callee args)

        | Ast.STMT_ret (Some atom) ->
            if fn_ctx.fnctx_is_iter then
              Common.err None
                "iterators can't return values; did you mean 'put'?";
            demand fn_ctx.fnctx_return_type (check_atom atom)

        | Ast.STMT_ret None ->
            if not fn_ctx.fnctx_is_iter then
              demand Ast.TY_nil fn_ctx.fnctx_return_type

        | Ast.STMT_be (callee, args) ->
            demand fn_ctx.fnctx_return_type (check_fn callee args)

        | Ast.STMT_alt_tag alt_tag ->
            let pats =
              Array.map
                (fun arm -> fst arm.Common.node)
                alt_tag.Ast.alt_tag_arms
            in
            let scrutinee_ty = check_lval alt_tag.Ast.alt_tag_lval in
            Array.iter (check_pat scrutinee_ty) pats

        | Ast.STMT_alt_type _ -> () (* TODO *)

        | Ast.STMT_alt_port _ -> () (* TODO *)

        | Ast.STMT_fail | Ast.STMT_yield -> ()  (* always well-typed *)

        | Ast.STMT_join lval -> infer_lval Ast.TY_task lval

        | Ast.STMT_send (chan, value) ->
            let carried_ty = demand_chan (check_lval chan) in
            infer_lval ~mut:Ast.MUT_immutable carried_ty value

        | Ast.STMT_log x | Ast.STMT_note x ->
            (* always well-typed, just record type in passing. *)
            ignore (check_atom x)

        | Ast.STMT_prove _ ->
            (* always well-typed, static. Ignore. *)
            ()

        | Ast.STMT_check (_, calls) -> check_check_calls calls

        | Ast.STMT_check_expr expr -> demand Ast.TY_bool (check_expr expr)

        | Ast.STMT_check_if (_, calls, block) ->
            check_check_calls calls;
            check_block block

        | Ast.STMT_block block -> check_block block

        | Ast.STMT_decl _ -> () (* always well-typed *)
    in

870 871 872 873 874 875 876 877 878
    (* Entry point handed back to callers: checks [stmt] and converts any
     * [Type_error] raised underneath into an error attached to the
     * statement's node id. *)
    let check_stmt' stmt =
      try
        check_stmt stmt
      with Type_error (wanted, found) ->
        let stmt_id = stmt.Common.id in
        Common.err
          (Some stmt_id)
          "mismatched types: expected %s but found %a"
          wanted
          Ast.sprintf_ty found
    in
880 881 882
    check_stmt'
  in
  check_stmt
G
Graydon Hoare 已提交
883

884 885 886
let process_crate (cx:Semant.ctxt) (crate:Ast.crate) : unit =
  let path = Stack.create () in
  let fn_ctx_stack = Stack.create () in
G
Graydon Hoare 已提交
887

888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903
  (* Check the signature of the crate's main item, if [item_id] is it.
   *
   * The item is considered to be main when the name built from the current
   * [path] equals [cx.Semant.ctxt_main_name]. Two signatures are accepted:
   * a fn with no input slots, or a fn with exactly one local-mode input slot
   * of type vec[str]. Anything else is reported through [Common.err], as is
   * a main item with no entry in [ctxt_all_item_types]. *)
  let verify_main (item_id:Common.node_id) : unit =
    let this_name = Semant.string_of_name (Semant.path_to_name path) in
    let is_main =
      match cx.Semant.ctxt_main_name with
          Some main_name -> main_name = this_name
        | None -> false
    in
    if is_main then
      begin
        (* Narrow the lookup to an option so only the table miss is caught. *)
        let item_ty =
          try Some (Hashtbl.find cx.Semant.ctxt_all_item_types item_id)
          with Not_found -> None
        in
        match item_ty with
            None ->
              Common.err (Some item_id)
                "main item has no type (is it a function?)"
          | Some (Ast.TY_fn ({ Ast.sig_input_slots = [| |] }, _))
          | Some (Ast.TY_fn ({ Ast.sig_input_slots = [| {
                Ast.slot_mode = Ast.MODE_local;
                Ast.slot_ty = Some (Ast.TY_vec Ast.TY_str)
              } |] }, _)) ->
              ()
          | Some _ ->
              Common.err (Some item_id) "main fn has bad type signature"
      end
  in
G
Graydon Hoare 已提交
904

905 906
  let visitor (cx:Semant.ctxt) (inner:Walk.visitor) : Walk.visitor =
    (* Push a fresh function context onto [fn_ctx_stack], recording the
     * given return type and iterator flag, with [fnctx_just_saw_ret]
     * initially false. *)
    let push_fn_ctx (ret_ty:Ast.ty) (is_iter:bool) =
      Stack.push
        { fnctx_return_type = ret_ty;
          fnctx_is_iter = is_iter;
          fnctx_just_saw_ret = false }
        fn_ctx_stack
    in
G
Graydon Hoare 已提交
914

915 916 917 918 919 920 921
    (* Push a function context derived from a fn type: the return type is
     * taken from the signature's output slot (unwrapped with
     * [Common.option_get]) and the iterator flag from the fn's aux record. *)
    let push_fn_ctx_of_ty_fn (ty_fn:Ast.ty_fn) : unit =
      let (fn_sig, fn_aux) = ty_fn in
      let out_slot = fn_sig.Ast.sig_output_slot in
      push_fn_ctx
        (Common.option_get out_slot.Ast.slot_ty)
        fn_aux.Ast.fn_is_iter
    in

922 923 924 925 926 927 928 929
    (* Pop the context of the function being left and report an error
     * against [item_id] unless one of the following holds: the
     * [fnctx_just_saw_ret] flag is set, the return type is nil, or the
     * function is an iterator. *)
    let finish_function (item_id:Common.node_id) =
      let ctx = Stack.pop fn_ctx_stack in
      let needs_no_value =
        ctx.fnctx_just_saw_ret
        || ctx.fnctx_return_type = Ast.TY_nil
        || ctx.fnctx_is_iter
      in
      if not needs_no_value then
        Common.err (Some item_id) "this function must return a value"
    in

930
    (* Pre-visit hook for module items. For a fn item that is not listed in
     * [ctxt_required_items], look up its recorded type and push a matching
     * function context; all other items are ignored. It is a compiler bug
     * for such a fn item to have a non-fn type. *)
    let visit_mod_item_pre _ _ item =
      let item_id = item.Common.id in
      match item.Common.node.Ast.decl_item with
          Ast.MOD_ITEM_fn _ ->
            if not (Hashtbl.mem cx.Semant.ctxt_required_items item_id) then
              begin
                match Hashtbl.find cx.Semant.ctxt_all_item_types item_id with
                    Ast.TY_fn ty_fn -> push_fn_ctx_of_ty_fn ty_fn
                  | _ ->
                      Common.bug ()
                        "Type.visit_mod_item_pre: fn item didn't have a fn type"
              end
        | _ -> ()
    in
    (* Post-visit hook for module items. Every item is first run through
     * [verify_main]; then, for fn items not listed in
     * [ctxt_required_items], the function context pushed by the pre hook is
     * popped and checked by [finish_function]. *)
    let visit_mod_item_post _ _ item =
      let { Common.node = decl; Common.id = item_id } = item in
      verify_main item_id;
      begin
        match decl.Ast.decl_item with
            Ast.MOD_ITEM_fn _ ->
              if not (Hashtbl.mem cx.Semant.ctxt_required_items item_id) then
                finish_function item_id
          | _ -> ()
      end
    in
954

955 956 957 958 959 960 961 962 963 964
    (* Pre-visit hook for an object method. The enclosing object's recorded
     * type must be a fn type whose output slot holds an object type; the
     * method's own fn type is looked up by [ident] in that object type's
     * method table and a function context is pushed for it. Any other shape
     * is a compiler bug. *)
    let visit_obj_fn_pre obj ident _ =
      let ctor_ty =
        Hashtbl.find cx.Semant.ctxt_all_item_types obj.Common.id
      in
      let not_an_obj () =
        Common.bug ()
          "Type.visit_obj_fn_pre: item doesn't have an object type (%a)"
          Ast.sprintf_ty ctor_ty
      in
      match ctor_ty with
          Ast.TY_fn (ctor_sig, _) ->
            begin
              match ctor_sig.Ast.sig_output_slot.Ast.slot_ty with
                  Some (Ast.TY_obj (_, methods)) ->
                    push_fn_ctx_of_ty_fn (Hashtbl.find methods ident)
                | _ -> not_an_obj ()
            end
        | _ -> not_an_obj ()
    in
966
    (* Post-visit hook for an object method: pop and check the function
     * context, reporting any error against the method's node id. *)
    let visit_obj_fn_post _ _ meth =
      finish_function meth.Common.id
    in
G
Graydon Hoare 已提交
967

968 969 970 971 972
    (* Hooks for an object's drop block: it is checked in a function context
     * with a nil return type and the iterator flag off. The context is
     * simply discarded afterwards, without the [finish_function] check. *)
    let visit_obj_drop_pre _ _ =
      let is_iter = false in
      push_fn_ctx Ast.TY_nil is_iter
    in
    let visit_obj_drop_post _ _ =
      let _ = Stack.pop fn_ctx_stack in
      ()
    in

    (* TODO: make sure you can't fall off the end of a function if it doesn't
     * return void *)
G
Graydon Hoare 已提交
973 974
    (* Pre-visit hook for statements: typecheck [stmt] against the function
     * context on top of [fn_ctx_stack], logging before and after. A
     * [Common.Semant_err] that carries no node id is re-raised with this
     * statement's id filled in; errors that already carry an id pass through
     * untouched. *)
    let visit_stmt_pre (stmt:Ast.stmt) : unit =
      let typecheck () =
        log cx "";
        log cx "typechecking stmt: %a" Ast.sprintf_stmt stmt;
        log cx "";
        let fn_ctx = Stack.top fn_ctx_stack in
        check_stmt cx fn_ctx stmt;
        log cx "finished typechecking stmt: %a" Ast.sprintf_stmt stmt
      in
      try typecheck ()
      with Common.Semant_err (None, msg) ->
        let stmt_id = Some stmt.Common.id in
        raise (Common.Semant_err (stmt_id, msg))
    in

984 985 986 987 988 989 990
    (* Post-visit hook for the crate: every lval id in [ctxt_all_lvals] that
     * has no entry yet in [ctxt_auto_deref_lval] gets one with value false.
     * Existing entries are left as they are. *)
    let visit_crate_post _ : unit =
      let auto_deref = cx.Semant.ctxt_auto_deref_lval in
      Hashtbl.iter
        (fun lval_id _ ->
           if Hashtbl.mem auto_deref lval_id then ()
           else Hashtbl.add auto_deref lval_id false)
        cx.Semant.ctxt_all_lvals
    in

993 994 995 996 997 998 999 1000 1001 1002 1003
    (* Override only the hooks defined above; every other hook is inherited
     * unchanged from [inner]. *)
    { inner with
        Walk.visit_crate_post = visit_crate_post;
        Walk.visit_mod_item_pre = visit_mod_item_pre;
        Walk.visit_mod_item_post = visit_mod_item_post;
        Walk.visit_obj_fn_pre = visit_obj_fn_pre;
        Walk.visit_obj_fn_post = visit_obj_fn_post;
        Walk.visit_obj_drop_pre = visit_obj_drop_pre;
        Walk.visit_obj_drop_post = visit_obj_drop_post;
        Walk.visit_stmt_pre = visit_stmt_pre }
G
Graydon Hoare 已提交
1004 1005
  in

1006 1007 1008 1009 1010 1011 1012 1013
  (* The type pass is a single walk over the crate, built on top of the
   * empty visitor and run under the "type" pass name and log flag. *)
  let typecheck_pass = visitor cx Walk.empty_visitor in
  let should_log = cx.Semant.ctxt_sess.Session.sess_log_type in
  Semant.run_passes cx "type" path [| typecheck_pass |] should_log log crate
;;
G
Graydon Hoare 已提交
1014 1015 1016 1017 1018 1019 1020 1021 1022 1023

(*
 * Local Variables:
 * fill-column: 78;
 * indent-tabs-mode: nil
 * buffer-file-coding-system: utf-8-unix
 * compile-command: "make -k -C ../.. 2>&1 | sed -e 's/\\/x\\//x:\\//g'";
 * End:
 *)