package neural_nets_lib

  1. Overview
  2. Docs

Construction of runtime-compiled code supporting backpropagation.

type init_op = Arrayjit.Ops.init_op
type diff = {
  grad : tn;
  zero_grads : asgns;
      (** Prepares for backpropagation. Always compile as: [Seq (zero_grads, backprop)]. *)
  backprop : asgns;
      (** Backpropagates for the tensor and its descendants, which typically means adding partial gradients to the gradient tensors of the subtensors, then of the sub-subtensors, etc. *)
}
type t = {
  forward : asgns;
  diff : diff Base.option;
  id : Base.int;  (** Same as [value.id]. *)
  value : tn;
  shape : Shape.t;
      (** The eventual shape of [t.value] and [t.diff.grad], incorporating the current state of shape inference. *)
  children : subtensor Base.list;
}
(** Information needed for compositional code generation. *)

and subtensor = {
  subtensor : t;
  embedded : Base.bool;
}
val sexp_of_t : t -> Sexplib0.Sexp.t
val sexp_of_subtensor : subtensor -> Sexplib0.Sexp.t
type comparator_witness
val comparator : (t, comparator_witness) Base.Comparator.t
val is_fwd_root : t -> Base.bool
val remove_fwd_root : t -> Base.unit
val is_bprop_root : t -> Base.bool
val remove_bprop_root : t -> Base.unit
val with_unchanged_roots : f:(Base.unit -> 'a) -> 'a
val default_value_prec : Arrayjit.Ops.prec Base.ref
val default_grad_prec : Arrayjit.Ops.prec Base.ref
exception Session_error of Base.string * t Base.option
val max_sublabel_length : Base.int Base.ref
val raw_binop : initialize_neutral:Base.bool -> accum:Arrayjit.Ops.binop -> t:t -> lhs_is_grad:Base.bool -> op:Arrayjit.Ops.binop -> t1:t -> rhs1_is_grad:Base.bool -> t2:t -> rhs2_is_grad:Base.bool -> logic:Shape.compose_type -> asgns
val raw_unop : initialize_neutral:Base.bool -> accum:Arrayjit.Ops.binop -> t:t -> lhs_is_grad:Base.bool -> op:Arrayjit.Ops.unop -> t1:t -> rhs_is_grad:Base.bool -> logic:Shape.transpose_type -> asgns
type grad_spec =
  | Require_grad
  | Prohibit_grad
  | If_needed
val is_prohibit_grad : grad_spec -> Base.bool
val op : label:Base.string Base.list -> ?compose_op:Shape.compose_type -> ?transpose_op:Shape.transpose_type -> ?init_op:init_op -> op_asn:(v:tn -> projections:projections Base.Lazy.t -> asgns) -> grad_asn:(v:tn -> g:tn -> projections:projections Base.Lazy.t -> asgns) -> ?grad_spec:grad_spec -> (debug_name:Base.string -> id:Base.int -> Shape.t) -> t Base.list -> t
val binop : label:Base.string Base.list -> ?compose_op:Shape.compose_type -> op_asn:(v:tn -> t1:t -> t2:t -> projections:projections Base.Lazy.t -> asgns) -> grad_asn:(v:tn -> g:tn -> t1:t -> t2:t -> projections:projections Base.Lazy.t -> asgns) -> ?grad_spec:grad_spec -> t -> t -> t
val unop : label:Base.string Base.list -> ?transpose_op:Shape.transpose_type -> op_asn:(v:tn -> t1:t -> projections:projections Base.Lazy.t -> asgns) -> grad_asn:(v:tn -> g:tn -> t1:t -> projections:projections Base.Lazy.t -> asgns) -> ?grad_spec:grad_spec -> t -> t
val term : label:Base.string Base.list -> grad_spec:grad_spec -> ?batch_dims:Base.int Base.list -> ?input_dims:Base.int Base.list -> ?output_dims:Base.int Base.list -> ?batch_axes:(Base.string * Base.int) Base.list -> ?input_axes:(Base.string * Base.int) Base.list -> ?output_axes:(Base.string * Base.int) Base.list -> ?deduced:Shape.deduce_within_shape -> ?init_op:init_op -> ?fetch_op:(v:tn -> fetch_op) -> Base.unit -> t

A terminal: a constant, a parameter, an input of the model. The semantics of shape specification is the same as in Shape.make, and by default the shape will be inferred.

val number : ?label:Base.string Base.list -> ?axis_label:Base.string -> ?grad_spec:grad_spec -> Base.float -> t

A number: a tensor with a single axis of one dimension, initialized to the given value. By default, grad_spec is Prohibit_grad.

val ndarray : ?label:Base.string Base.list -> ?grad_spec:grad_spec -> ?batch_dims:Base.int Base.list -> ?input_dims:Base.int Base.list -> ?output_dims:Base.int Base.list -> ?batch_axes:(Base.string * Base.int) Base.list -> ?input_axes:(Base.string * Base.int) Base.list -> ?output_axes:(Base.string * Base.int) Base.list -> ?strict:Base.bool -> Base.float Base.array -> t

A tensor with an explicit shape, initialized to the given values. Omitted shape rows default to no axes. By default, grad_spec is Prohibit_grad. If strict is true (the default), the given values must fill the tensor's value node exactly; otherwise, the values are looped over (repeated) to populate the value node.

val param : ?input_dims:Base.int Base.list -> ?output_dims:Base.int Base.list -> ?input_axes:(Base.string * Base.int) Base.list -> ?output_axes:(Base.string * Base.int) Base.list -> ?deduced:Shape.deduce_within_shape -> ?strict:Base.bool -> ?values:Base.float Base.array -> Base.string -> t
val iter_embedded_arrays : f:(tn -> Base.unit) -> t -> Base.unit
val consume_forward_code : t -> asgns

A forward root is a tensor that is not (currently) used to compute another tensor. consume_forward_code t ensures t is a forward root, removes it from forward roots, and checks that there are no other forward roots for tensors with children.

val consume_backprop_code : t -> asgns * asgns

A backprop root is a tensor with a gradient that is not (currently) receiving gradients from another tensor. That is, it is not currently used to compute another tensor that has a gradient. consume_backprop_code t ensures t is a backprop root, removes it from backprop roots, and checks that there are no other backprop roots for tensors with children.

Printing.

val header : t -> Base.string

Converts ID, label and the dimensions of a node to a string.

type array_print_style = [
  | `Default
      (** The inner rectangles comprise both an input and an output axis, if available. Similarly, the outer rectangle comprises a second-from-end input axis and a second-from-end output axis, if available. At least one batch axis is output, when available. The axes that couldn't be output are printed at position/dimension 0. *)
  | `N5_layout of Base.string
      (** The string should provide exclusively non-negative integer pseudo-labels. The numbers 0-4 represent the priorities of the axes to be printed out, where the priorities correspond to, from highest to lowest: the horizontal direction of the inner rectangle, the vertical direction of the inner rectangle, the horizontal direction of the outer rectangle, the vertical direction of the outer rectangle, and repetition (see also [Node.pp_print]). The numbers n >= 5 stand for the actual positions n - 5 within the corresponding axes. *)
  | `Label_layout of (Base.string * Base.int) Base.list
      (** The association from axis labels to integers. The negative numbers -5 to -1 represent the priorities of the axes to be printed out, where the priorities correspond to, from highest to lowest: the horizontal direction of the inner rectangle, the vertical direction of the inner rectangle, the horizontal direction of the outer rectangle, the vertical direction of the outer rectangle, and repetition (as above). The numbers n >= 0 stand for the actual positions within the corresponding axes. Unspecified axes are printed at position 0. *)
  | `Inline
      (** The tensors are printed linearly, in a bracketed manner, optionally prefixed with the labels specification. Note that the syntax is ambiguous for 1-dimensional input axes (underscores are used for axes without explicit labels); when there is a 1-dimensional input axis, we output the labels specification even if there are no axis labels, as a way to display the number of axes. The axis nesting is right-to-left (rightmost is innermost). The input axes are innermost and the batch axes outermost. The input axes use [,] as a separator and [()] as axis delimiters, but the delimiter for the outermost (i.e. leftmost) axis is omitted. The output axes use [;] as a separator and [[]] as axis delimiters (obligatory). The batch axes use [;] as a separator and [[||]] as axis delimiters (obligatory). *)
]

We print out up to 5 axes when printing a tensor, as a grid (outer rectangle) of (inner) rectangles, possibly repeated (screens).

val to_printbox : ?single_node:Base.bool -> ?entries_per_axis:Base.int -> ?with_id:Base.bool -> ?with_shape:Base.bool -> ?with_value:Base.bool -> with_grad:Base.bool -> depth:Base.int -> t -> PrintBox.t
val print : with_grad:Base.bool -> with_code:Base.bool -> ?force:Base.bool -> ?with_low_level:Base.bool -> array_print_style -> t -> Base.unit
val print_forward_roots : with_grad:Base.bool -> with_code:Base.bool -> array_print_style -> Base.unit
val print_tree : ?entries_per_axis:Base.int -> ?with_backend_info:Base.bool -> ?with_id:Base.bool -> ?with_shape:Base.bool -> ?with_value:Base.bool -> with_grad:Base.bool -> depth:Base.int -> t -> Base.unit

Accessors.

val value_1d_points : ?from_axis:Base.int -> xdim:Base.int -> t -> Base.float Base.array
val value_2d_points : ?from_axis:Base.int -> xdim:Base.int -> ydim:Base.int -> t -> (Base.float * Base.float) Base.array
val grad_1d_points : ?from_axis:Base.int -> xdim:Base.int -> t -> Base.float Base.array
val grad_2d_points : ?from_axis:Base.int -> xdim:Base.int -> ydim:Base.int -> t -> (Base.float * Base.float) Base.array
val set_value : t -> Base.int Base.array -> Base.float -> Base.unit
val get_value : t -> Base.int Base.array -> Base.float
val set_grad : t -> Base.int Base.array -> Base.float -> Base.unit
val get_grad : t -> Base.int Base.array -> Base.float
val set_values : t -> Base.float Base.array -> Base.unit
val get_values : t -> Base.float Base.array
OCaml

Innovation. Community. Security.