divisor.layer_dropout
# SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0
# <!-- // /* d a r k s h a p e s */ -->

from typing import Any, Callable

from nnll.console import nfo
from torch.nn import ModuleList


def process_blocks_with_dropout(
    blocks: ModuleList,
    layer_dropouts: list[int] | None,
    start_index: int,
    block_type: str,
    process_block: Callable[[Any, Any], Any],
    state: Any,
) -> Any:
    """Thread a state through a sequence of blocks, bypassing dropped layers\n
    :param blocks: Blocks to apply, in iteration order
    :param layer_dropouts: Global indices of blocks to bypass (None bypasses nothing)
    :param start_index: Global index given to the first block; later blocks count up from it
    :param block_type: Label written into the log message when a block is bypassed (e.g. "double" or "single")
    :param process_block: Callable invoked as process_block(block, state); its return value becomes the new state
    :param state: State handed to the first block that is not bypassed
    :return: State returned by the last applied block, or the initial state if no block was applied
    """
    # Starting enumerate at start_index yields each block's global index directly.
    for layer_number, layer in enumerate(blocks, start_index):
        is_dropped = layer_dropouts is not None and layer_number in layer_dropouts
        if not is_dropped:
            state = process_block(layer, state)
        else:
            nfo(f"Dropping layer {layer_number} ({block_type} block)")

    return state
def
process_blocks_with_dropout( blocks: torch.nn.modules.container.ModuleList, layer_dropouts: list[int] | None, start_index: int, block_type: str, process_block: Callable[[Any, Any], Any], state: Any) -> Any:
11def process_blocks_with_dropout( 12 blocks: ModuleList, 13 layer_dropouts: list[int] | None, 14 start_index: int, 15 block_type: str, 16 process_block: Callable[[Any, Any], Any], 17 state: Any, 18) -> Any: 19 """Process blocks, skipping layer dropout blocks\n 20 :param blocks: List of blocks to process 21 :param layer_dropouts: List of block indices to skip (None to skip none) 22 :param start_index: Starting index for layer dropout checking (for offset) 23 :param block_type: Type of block ("double" or "single") for logging 24 :param process_block: Callable that takes (block, current_state) and returns updated state 25 :param state: Initial state to process (e.g., (img, txt) for double, img for single) 26 :return: The final state after processing all non-dropped blocks 27 """ 28 for block_index, block in enumerate(blocks): 29 global_index = start_index + block_index 30 if layer_dropouts is not None and global_index in layer_dropouts: 31 nfo(f"Dropping layer {global_index} ({block_type} block)") 32 continue 33 34 state = process_block(block, state) 35 36 return state
Process blocks, skipping layer dropout blocks
Parameters
- blocks: List of blocks to process
- layer_dropouts: List of block indices to skip (None to skip none)
- start_index: Starting index for layer dropout checking (for offset)
- block_type: Type of block ("double" or "single") for logging
- process_block: Callable that takes (block, current_state) and returns updated state
- state: Initial state to process (e.g., (img, txt) for double, img for single)
Returns
The final state after processing all non-dropped blocks