divisor.layer_dropout

 1# SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0
 2# <!-- // /*  d a r k s h a p e s */ -->
 3
 4from typing import Any, Callable
 5
 6from nnll.console import nfo
 7from torch.nn import ModuleList
 8
 9
10def process_blocks_with_dropout(
11    blocks: ModuleList,
12    layer_dropouts: list[int] | None,
13    start_index: int,
14    block_type: str,
15    process_block: Callable[[Any, Any], Any],
16    state: Any,
17) -> Any:
18    """Process blocks, skipping layer dropout blocks\n
19    :param blocks: List of blocks to process
20    :param layer_dropouts: List of block indices to skip (None to skip none)
21    :param start_index: Starting index for layer dropout checking (for offset)
22    :param block_type: Type of block ("double" or "single") for logging
23    :param process_block: Callable that takes (block, current_state) and returns updated state
24    :param state: Initial state to process (e.g., (img, txt) for double, img for single)
25    :return: The final state after processing all non-dropped blocks
26    """
27    for block_index, block in enumerate(blocks):
28        global_index = start_index + block_index
29        if layer_dropouts is not None and global_index in layer_dropouts:
30            nfo(f"Dropping layer {global_index} ({block_type} block)")
31            continue
32
33        state = process_block(block, state)
34
35    return state
def process_blocks_with_dropout(blocks: torch.nn.modules.container.ModuleList, layer_dropouts: list[int] | None, start_index: int, block_type: str, process_block: Callable[[Any, Any], Any], state: Any) -> Any:
11def process_blocks_with_dropout(
12    blocks: ModuleList,
13    layer_dropouts: list[int] | None,
14    start_index: int,
15    block_type: str,
16    process_block: Callable[[Any, Any], Any],
17    state: Any,
18) -> Any:
19    """Process blocks, skipping layer dropout blocks\n
20    :param blocks: List of blocks to process
21    :param layer_dropouts: List of block indices to skip (None to skip none)
22    :param start_index: Starting index for layer dropout checking (for offset)
23    :param block_type: Type of block ("double" or "single") for logging
24    :param process_block: Callable that takes (block, current_state) and returns updated state
25    :param state: Initial state to process (e.g., (img, txt) for double, img for single)
26    :return: The final state after processing all non-dropped blocks
27    """
28    for block_index, block in enumerate(blocks):
29        global_index = start_index + block_index
30        if layer_dropouts is not None and global_index in layer_dropouts:
31            nfo(f"Dropping layer {global_index} ({block_type} block)")
32            continue
33
34        state = process_block(block, state)
35
36    return state

Process blocks in order, skipping any block whose global index is listed in layer_dropouts.

Parameters
  • blocks: List of blocks to process
  • layer_dropouts: Global indices (block position plus start_index) of blocks to skip; None skips no blocks
  • start_index: Starting index for layer dropout checking (for offset)
  • block_type: Type of block ("double" or "single") for logging
  • process_block: Callable that takes (block, current_state) and returns updated state
  • state: Initial state to process (e.g., (img, txt) for double, img for single)
Returns

The final state after processing all non-dropped blocks