Skip to content

utils

Core utility functions for codablellm.

Command = List[str] module-attribute

A CLI command.

CommandErrorHandler = Literal['interactive', 'ignore', 'none'] module-attribute

Defines the strategies for handling errors encountered during the execution of a CLI command.

Supported Error Handlers
  • ignore: The CLI command error is ignored, and execution continues without interruption.
  • none: An exception is raised immediately upon encountering the CLI error.
  • interactive: The user is prompted to resolve the error manually, allowing for interactive handling of the issue.

JSONObject = Dict[str, JSONValue] module-attribute

Represents a JSON object.

JSONValue = Optional[Union[str, int, float, bool, List['JSONValue'], 'JSONObject']] module-attribute

Represents a valid JSON value.

PathLike = Union[Path, str] module-attribute

An object representing a file system path.

REBASED_DIR_ENVIRON_KEY = 'CODABLELLM_REBASED_DIR' module-attribute

Environment variable key used to expose the rebased directory path to subprocesses.

This is especially useful when running custom build or clean commands that need to reference the rebased project root dynamically (e.g., using shell expansion like $CODABLELLM_REBASED_DIR).

Set automatically when using the temp generation mode.

ASTEditor

A Tree-sitter AST editor.

Source code in src/codablellm/core/utils.py
class ASTEditor:
    """
    A Tree-sitter AST editor.

    Holds a source string together with its Tree-sitter parse tree and keeps the
    two in sync while nodes are replaced with new code.
    """

    def __init__(
        self, parser: Parser, source_code: str, ensure_parsable: bool = True
    ) -> None:
        """
        Initializes the AST editor with a parser and source code.

        Parameters:
            parser: The `Parser` object used to parse the source code.
            source_code: The source code to be edited.
            ensure_parsable: If `True`, raises an error if edits result in an invalid AST.
        """
        self.parser = parser
        self.source_code = source_code
        # The tree is built from the UTF-8 encoding of the source, so node byte
        # offsets refer to `source_code.encode()`, not to str indices.
        self.ast = self.parser.parse(source_code.encode())
        self.ensure_parsable = ensure_parsable

    def edit_code(self, node: Node, new_code: str) -> None:
        """
        Edits the source code at the specified AST node and updates the AST.

        The node's byte range is replaced with `new_code`, the edit is recorded on the
        current tree, and the source is re-parsed incrementally.

        Parameters:
            node: The `Node` object representing the code to replace. Its byte offsets
                must be valid for the current `source_code`.
            new_code: The new code to insert in place of the node's source code. May be
                empty, in which case the node's code is deleted.

        Raises:
            TSParsingError: If `ensure_parsable` is `True` and the resulting AST has parsing errors.
        """
        # Tree-sitter offsets and columns count UTF-8 bytes, so splice and measure on
        # bytes; str indices diverge from them as soon as non-ASCII text is present.
        source_bytes = self.source_code.encode()
        new_bytes = new_code.encode()
        num_lines = new_bytes.count(b"\n")
        # Bytes after the last newline of the new code (all of it if there is none).
        # Unlike splitlines()[-1], this works for "" and for code ending in a newline.
        last_line_num_bytes = len(new_bytes) - (new_bytes.rfind(b"\n") + 1)
        if num_lines:
            # The new code ends on a later row, so its end column counts from 0.
            new_end_column = last_line_num_bytes
        else:
            new_end_column = node.start_point.column + last_line_num_bytes
        updated_bytes = (
            source_bytes[: node.start_byte] + new_bytes + source_bytes[node.end_byte :]
        )
        self.source_code = updated_bytes.decode()
        # Perform the AST edit so the old tree can be reused for incremental parsing
        self.ast.edit(
            start_byte=node.start_byte,
            old_end_byte=node.end_byte,
            new_end_byte=node.start_byte + len(new_bytes),
            start_point=node.start_point,
            old_end_point=node.end_point,
            new_end_point=(node.start_point.row + num_lines, new_end_column),
        )
        # Re-parse the updated source code
        self.ast = self.parser.parse(updated_bytes, old_tree=self.ast)
        # Check for parsing errors if required
        if self.ensure_parsable and self.ast.root_node.has_error:
            raise TSParsingError("Parsing error while editing code")

    def match_and_edit(
        self,
        query: str,
        groups_and_replacement: Dict[str, Union[str, Callable[[Node], str]]],
    ) -> None:
        """
        Searches the AST using a Tree-sitter query and applies code edits to matching nodes.

        For each match, the first group of `groups_and_replacement` (in mapping order)
        that captured a not-yet-selected node is used. That node's code is replaced with
        the provided string, or with the result of calling the provided callable on it.

        All matches are collected from the current tree before anything is edited.
        Replacement callables are invoked in source order. The edits are then applied
        from the end of the source towards the start, so the byte offsets of the
        remaining nodes stay valid. When selected nodes overlap, the one that starts
        first (the enclosing node for nested ones) is edited and the others are skipped.

        Parameters:
            query: The Tree-sitter query string to use for finding matching nodes.
            groups_and_replacement: A mapping from query group names to either replacement strings
                                    or callables that take a `Node` and return a replacement string.

        Raises:
            TSParsingError: If an edit introduces parsing errors and `ensure_parsable` is `True`.
        """
        matches = self.ast.language.query(query).matches(self.ast.root_node)
        # (node, replacement) pairs; at most one per match, unique by byte range.
        candidates: List[Any] = []
        seen_ranges: Set[Any] = set()
        for _, captures in matches:
            for group, replacement in groups_and_replacement.items():
                nodes = captures.get(group)
                if not nodes:
                    continue
                node = nodes[-1]
                node_range = (node.start_byte, node.end_byte)
                if node_range in seen_ranges:
                    continue
                seen_ranges.add(node_range)
                candidates.append((node, replacement))
                break
        # Walk in source order; for equal starts the longer (enclosing) node comes first.
        candidates.sort(key=lambda item: (item[0].start_byte, -item[0].end_byte))
        edits: List[Any] = []
        last_end = -1
        for node, replacement in candidates:
            if node.start_byte < last_end:
                # Overlaps a node that is already going to be replaced.
                continue
            new_code = replacement if isinstance(replacement, str) else replacement(node)
            edits.append((node, new_code))
            last_end = node.end_byte
        # Apply back-to-front so edits never shift the offsets of nodes still pending.
        for node, new_code in reversed(edits):
            self.edit_code(node, new_code)

__init__(parser, source_code, ensure_parsable=True)

Initializes the AST editor with a parser and source code.

Parameters:

Name Type Description Default
parser Parser

The Parser object used to parse the source code.

required
source_code str

The source code to be edited.

required
ensure_parsable bool

If True, raises an error if edits result in an invalid AST.

True
Source code in src/codablellm/core/utils.py
def __init__(
    self, parser: Parser, source_code: str, ensure_parsable: bool = True
) -> None:
    """
    Stores the parser and source text, then builds the initial parse tree.

    Parameters:
        parser: Parser that turns the source code into a Tree-sitter tree.
        source_code: Text that subsequent edits will modify.
        ensure_parsable: When `True`, edits that leave the tree with errors are rejected.
    """
    self.parser = parser
    self.source_code = source_code
    # The tree is produced from the UTF-8 bytes of the text.
    encoded_source = source_code.encode()
    self.ast = self.parser.parse(encoded_source)
    self.ensure_parsable = ensure_parsable

edit_code(node, new_code)

Edits the source code at the specified AST node and updates the AST.

Parameters:

Name Type Description Default
node Node

The Node object representing the code to replace.

required
new_code str

The new code to insert in place of the node's source code.

required

Raises:

Type Description
TSParsingError

If ensure_parsable is True and the resulting AST has parsing errors.

Source code in src/codablellm/core/utils.py
def edit_code(self, node: Node, new_code: str) -> None:
    """
    Edits the source code at the specified AST node and updates the AST.

    The node's byte range is replaced with `new_code`, the edit is recorded on the
    current tree, and the source is re-parsed incrementally.

    Parameters:
        node: The `Node` object representing the code to replace. Its byte offsets
            must be valid for the current `source_code`.
        new_code: The new code to insert in place of the node's source code. May be
            empty, in which case the node's code is deleted.

    Raises:
        TSParsingError: If `ensure_parsable` is `True` and the resulting AST has parsing errors.
    """
    # Tree-sitter offsets and columns count UTF-8 bytes, so splice and measure on
    # bytes; str indices diverge from them as soon as non-ASCII text is present.
    source_bytes = self.source_code.encode()
    new_bytes = new_code.encode()
    num_lines = new_bytes.count(b"\n")
    # Bytes after the last newline of the new code (all of it if there is none).
    # Unlike splitlines()[-1], this works for "" and for code ending in a newline.
    last_line_num_bytes = len(new_bytes) - (new_bytes.rfind(b"\n") + 1)
    if num_lines:
        # The new code ends on a later row, so its end column counts from 0.
        new_end_column = last_line_num_bytes
    else:
        new_end_column = node.start_point.column + last_line_num_bytes
    updated_bytes = (
        source_bytes[: node.start_byte] + new_bytes + source_bytes[node.end_byte :]
    )
    self.source_code = updated_bytes.decode()
    # Perform the AST edit so the old tree can be reused for incremental parsing
    self.ast.edit(
        start_byte=node.start_byte,
        old_end_byte=node.end_byte,
        new_end_byte=node.start_byte + len(new_bytes),
        start_point=node.start_point,
        old_end_point=node.end_point,
        new_end_point=(node.start_point.row + num_lines, new_end_column),
    )
    # Re-parse the updated source code
    self.ast = self.parser.parse(updated_bytes, old_tree=self.ast)
    # Check for parsing errors if required
    if self.ensure_parsable and self.ast.root_node.has_error:
        raise TSParsingError("Parsing error while editing code")

match_and_edit(query, groups_and_replacement)

Searches the AST using a Tree-sitter query and applies code edits to matching nodes.

For each match group, replaces the matched node's code with a provided string or the result of a callable that returns the replacement string.

Parameters:

Name Type Description Default
query str

The Tree-sitter query string to use for finding matching nodes.

required
groups_and_replacement Dict[str, Union[str, Callable[[Node], str]]]

A mapping from query group names to either replacement strings or callables that take a Node and return a replacement string.

required

Raises:

Type Description
TSParsingError

If an edit introduces parsing errors and ensure_parsable is True.

Source code in src/codablellm/core/utils.py
def match_and_edit(
    self,
    query: str,
    groups_and_replacement: Dict[str, Union[str, Callable[[Node], str]]],
) -> None:
    """
    Searches the AST using a Tree-sitter query and applies code edits to matching nodes.

    For each match, the first group of `groups_and_replacement` (in mapping order)
    that captured a not-yet-selected node is used. That node's code is replaced with
    the provided string, or with the result of calling the provided callable on it.

    All matches are collected from the current tree before anything is edited.
    Replacement callables are invoked in source order. The edits are then applied
    from the end of the source towards the start, so the byte offsets of the
    remaining nodes stay valid. When selected nodes overlap, the one that starts
    first (the enclosing node for nested ones) is edited and the others are skipped.

    Parameters:
        query: The Tree-sitter query string to use for finding matching nodes.
        groups_and_replacement: A mapping from query group names to either replacement strings
                                or callables that take a `Node` and return a replacement string.

    Raises:
        TSParsingError: If an edit introduces parsing errors and `ensure_parsable` is `True`.
    """
    matches = self.ast.language.query(query).matches(self.ast.root_node)
    # (node, replacement) pairs; at most one per match, unique by byte range.
    candidates: List[Any] = []
    seen_ranges: Set[Any] = set()
    for _, captures in matches:
        for group, replacement in groups_and_replacement.items():
            nodes = captures.get(group)
            if not nodes:
                continue
            node = nodes[-1]
            node_range = (node.start_byte, node.end_byte)
            if node_range in seen_ranges:
                continue
            seen_ranges.add(node_range)
            candidates.append((node, replacement))
            break
    # Walk in source order; for equal starts the longer (enclosing) node comes first.
    candidates.sort(key=lambda item: (item[0].start_byte, -item[0].end_byte))
    edits: List[Any] = []
    last_end = -1
    for node, replacement in candidates:
        if node.start_byte < last_end:
            # Overlaps a node that is already going to be replaced.
            continue
        new_code = replacement if isinstance(replacement, str) else replacement(node)
        edits.append((node, new_code))
        last_end = node.end_byte
    # Apply back-to-front so edits never shift the offsets of nodes still pending.
    for node, new_code in reversed(edits):
        self.edit_code(node, new_code)

SupportsJSON

Bases: Protocol

A class that supports JSON serialization/deserialization.

Source code in src/codablellm/core/utils.py
class SupportsJSON(Protocol):
    """
    Structural interface for objects that can be converted to and from JSON.

    Any class that provides a `from_json` classmethod and a `to_json` method with
    these signatures satisfies this protocol.
    """

    @classmethod
    def from_json(cls: Type[SupportsJSON_T], json_obj: JSONObject_T) -> SupportsJSON_T:  # type: ignore
        """
        Builds an instance of this class from its JSON representation.

        Parameters:
            json_obj: The JSON representation of an instance.

        Returns:
            The instance described by `json_obj`.
        """
        ...

    def to_json(self) -> JSONObject_T:  # type: ignore
        """
        Converts this instance into its JSON representation.

        Returns:
            A JSON object describing this instance.
        """
        ...

from_json(json_obj) classmethod

Deserializes a JSON object to this object.

Parameters:

Name Type Description Default
json_obj JSONObject_T

The JSON representation of this object.

required

Returns:

Type Description
SupportsJSON_T

This object loaded from the JSON object.

Source code in src/codablellm/core/utils.py
@classmethod
def from_json(cls: Type[SupportsJSON_T], json_obj: JSONObject_T) -> SupportsJSON_T:  # type: ignore
    """
    Builds an instance of this class from its JSON representation.

    Parameters:
        json_obj: The JSON representation of an instance.

    Returns:
        The instance described by `json_obj`.
    """
    ...

to_json()

Serializes this object to a JSON object.

Returns:

Type Description
JSONObject_T

A JSON representation of the object.

Source code in src/codablellm/core/utils.py
def to_json(self) -> JSONObject_T:  # type: ignore
    """
    Converts this instance into its JSON representation.

    Returns:
        A JSON object describing this instance.
    """
    ...

add_command_args(command, *args)

Appends additional arguments to a CLI command.

Parameters:

Name Type Description Default
command Command

The CLI command to append the arguments to.

required
args Any

Additional arguments to append to the command.

()

Returns:

Type Description
Command

The updated command with the appended arguments.

Source code in src/codablellm/core/utils.py
def add_command_args(command: Command, *args: Any) -> Command:
    """
    Returns a new CLI command consisting of `command` followed by `args`.

    A plain string command is treated as a single argument (it is not split).
    The input command is never modified.

    Parameters:
        command: The CLI command to append the arguments to.
        args: Additional arguments to append to the command.

    Returns:
        A new list holding the command's arguments followed by `args`.
    """
    if isinstance(command, str):
        return [command, *args]
    return [*command, *args]

execute_command(command, error_handler='none', task=None, ctx=nullcontext(), log_level='info', print_errors=True, cwd=None)

Executes a CLI command with optional interactive error handling.

Parameters:

Name Type Description Default
command Command

The CLI command to be executed.

required
error_handler CommandErrorHandler

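How to handle a command failure: 'none' (raise the error), 'ignore' (continue without interruption), or 'interactive' (prompt the user to retry, ignore, abort, or edit the command).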

'none'
task Optional[str]

Optional description for logging.

None
ctx AbstractContextManager[Any]

Context manager used to wrap the execution.

nullcontext()
log_level Literal['debug', 'info']

Log level for the task description.

'info'
print_errors bool

If True, prints output on error.

True
cwd Optional[PathLike]

Working directory to execute the command in.

None

Returns:

Type Description
str

The output of the command.

Raises:

Type Description
CalledProcessError

If the command fails and error_handler is 'none'.

Source code in src/codablellm/core/utils.py
def execute_command(
    command: Command,
    error_handler: CommandErrorHandler = "none",
    task: Optional[str] = None,
    ctx: AbstractContextManager[Any] = nullcontext(),
    log_level: Literal["debug", "info"] = "info",
    print_errors: bool = True,
    cwd: Optional[PathLike] = None,
) -> str:
    """
    Executes a CLI command with configurable error handling.

    Parameters:
        command: The CLI command to be executed. A string is split on whitespace.
        error_handler: How to handle a failing command. `'none'` re-raises the error.
            `'ignore'` logs it and continues, returning the failed command's output.
            `'interactive'` prompts the user to retry, ignore, abort, or edit the command.
        task: Optional description for logging. Defaults to a description of the command.
        ctx: Context manager used to wrap the execution. It is entered again on every retry.
        log_level: Log level for the task description.
        print_errors: If True, prints output on error.
        cwd: Working directory to execute the command in.

    Returns:
        The output of the command (stdout and stderr combined). For an ignored failure,
        this is the failed command's output.

    Raises:
        CalledProcessError: If the command fails and error_handler is 'none', or the user
            chooses to abort in interactive mode.
    """
    if isinstance(command, str):
        command = command.split()
    log_task = logger.debug if log_level == "debug" else logger.info
    output = ""

    if not task:
        task = f"Executing: {repr(command)}"

    while True:
        log_task(task)

        try:
            with ctx:
                output = subprocess.check_output(
                    command, text=True, cwd=cwd, stderr=subprocess.STDOUT
                )
            log_task(f"Successfully executed {repr(command)}")
            break  # Exit loop on success

        except subprocess.CalledProcessError as e:
            output = e.output
            logger.error(f"Command failed: {repr(command)}")
            if print_errors:
                print(f"[red][b]Command failed: {repr(command)}[/b]\nOutput: {output}")

            if error_handler == "ignore":
                # The documented 'ignore' strategy: keep going without raising.
                break

            if error_handler == "interactive":
                result = Prompt.ask(
                    "A command error occurred. You can manually fix the issue and retry, ignore the error to continue, "
                    "abort the process, or edit the command. How would you like to proceed?",
                    choices=["retry", "ignore", "abort", "edit"],
                    case_sensitive=False,
                    default="retry",
                )

                if result == "retry":
                    continue
                elif result == "ignore":
                    break
                elif result == "abort":
                    raise e
                elif result == "edit":
                    # `command` is always a list here (strings are split above).
                    command_str = " ".join(str(c) for c in command)
                    edited_command = Prompt.ask(
                        "Enter the new command to execute", default=f'"{command_str}"'
                    ).strip("\"'")
                    command = edited_command.split()
                    continue

            # 'none' (or any other handler / unrecognised answer): re-raise.
            raise

    if output:
        logger.debug(f'{repr(command)} output:\n"{output}"')
    return output

get_checkpoint_file(prefix)

Returns the checkpoint file path for the current process based on the given prefix.

The checkpoint file is stored in the system temporary directory and named using the format: {prefix}_{pid}.json.

Parameters:

Name Type Description Default
prefix str

The filename prefix for the checkpoint file.

required

Returns:

Type Description
Path

A Path object pointing to the checkpoint file.

Source code in src/codablellm/core/utils.py
def get_checkpoint_file(prefix: str) -> Path:
    """
    Builds the path of this process's checkpoint file for `prefix`.

    The file lives in the system temporary directory and is named
    `{prefix}_{pid}.json`, where `pid` is the current process ID.

    Parameters:
        prefix: Leading part of the checkpoint file name.

    Returns:
        The `Path` of the checkpoint file (it is not created here).
    """
    file_name = f"{prefix}_{os.getpid()}.json"
    temp_dir = Path(tempfile.gettempdir())
    return temp_dir.joinpath(file_name)

get_checkpoint_files(prefix)

Retrieves all checkpoint files matching the given prefix.

Parameters:

Name Type Description Default
prefix str

The filename prefix used to locate checkpoint files.

required

Returns:

Type Description
List[Path]

A list of Path objects for all matching checkpoint files.

Source code in src/codablellm/core/utils.py
def get_checkpoint_files(prefix: str) -> List[Path]:
    """
    Retrieves all checkpoint files matching the given prefix.

    Only regular files named `{prefix}_{pid}.json` in the system temporary directory
    are returned, where `pid` is a process ID. This is the naming scheme used for
    checkpoint files. Files such as `{prefix}_other_1.json` or `{prefix}_x.txt` are
    ignored. The result is sorted by path.

    Parameters:
        prefix: The filename prefix used to locate checkpoint files.

    Returns:
        A list of `Path` objects for all matching checkpoint files.
    """
    import glob

    temp_dir = Path(tempfile.gettempdir())
    # Escape the prefix so characters such as `*`, `?` or `[` are matched literally.
    candidates = temp_dir.glob(f"{glob.escape(prefix)}_*.json")
    marker_length = len(prefix) + 1  # prefix plus the "_" separator
    return sorted(
        path
        for path in candidates
        # The part after "{prefix}_" must be a PID; this excludes files that
        # belong to a longer prefix sharing this one as a start.
        if path.is_file() and path.stem[marker_length:].isdecimal()
    )

get_readable_file_size(size)

Converts a number of bytes to a human-readable string (bytes, KB, MB, GB, or TB).

Parameters:

Name Type Description Default
size int

The number of bytes.

required

Returns:

Type Description
str

A human readable output of the number of bytes.

Source code in src/codablellm/core/utils.py
def get_readable_file_size(size: int) -> str:
    """
    Formats a byte count as a short human-readable string.

    The largest unit (TB, GB, MB, then KB, each a power of 1024) whose value,
    rounded to three decimals, is at least 1 is used. Smaller counts are
    reported in bytes.

    Parameters:
        size: The number of bytes.

    Returns:
        The formatted size, e.g. `"1.5 KB"` or `"512 bytes"`.
    """
    units = ((40, "TB"), (30, "GB"), (20, "MB"), (10, "KB"))
    for exponent, suffix in units:
        scaled = round(size / 2**exponent, 3)
        if scaled >= 1:
            return f"{scaled} {suffix}"
    return f"{size} bytes"

is_binary(file_path)

Checks if a file is a binary file.

Parameters:

Name Type Description Default
file_path PathLike

Path to a potential binary file.

required

Returns:

Type Description
bool

True if the file appears to be a binary file; False otherwise, including when the path is not a regular file.

Source code in src/codablellm/core/utils.py
def is_binary(file_path: PathLike) -> bool:
    """
    Checks if a file is a binary file.

    The first 1024 bytes of the file are inspected. The file is considered binary
    if they contain a null byte or are not valid UTF-8. UTF-8 text with non-ASCII
    characters (e.g. accented identifiers or comments) is therefore treated as text.

    Parameters:
        file_path: Path to a potential binary file.

    Returns:
        True if the file appears to be binary; False otherwise, including when the
        path is not a regular file.
    """
    import codecs

    file_path = Path(file_path)
    if not file_path.is_file():
        return False
    with open(file_path, "rb") as file:
        chunk = file.read(1024)
    if b"\0" in chunk:
        return True
    # An incremental decoder (final=False) tolerates a multi-byte character that
    # was cut off at the end of the chunk, but still rejects invalid sequences.
    decoder = codecs.getincrementaldecoder("utf-8")()
    try:
        decoder.decode(chunk, final=False)
    except UnicodeDecodeError:
        return True
    return False

iter_queue(queue)

Iterates over all items in a queue until it is empty.

Parameters:

Name Type Description Default
queue Queue[T]

A Queue object containing items to iterate over.

required

Returns:

Type Description
Generator[T, None, None]

A generator that yields each item from the queue.

Source code in src/codablellm/core/utils.py
def iter_queue(queue: Queue[T]) -> Generator[T, None, None]:
    """
    Drains a queue, yielding items until it reports being empty.

    Emptiness is checked with `queue.empty()` before each `queue.get()`.

    Parameters:
        queue: The `Queue` to drain.

    Returns:
        A generator over the queue's items in retrieval order.
    """
    while True:
        if queue.empty():
            return
        yield queue.get()

load_checkpoint_data(prefix, delete_on_load=False)

Loads checkpoint data from all checkpoint files matching the given prefix.

The function reads and aggregates JSON data from each checkpoint file and optionally deletes the checkpoint files after loading.

Parameters:

Name Type Description Default
prefix str

The filename prefix used to locate checkpoint files.

required
delete_on_load bool

If True, deletes the checkpoint files after loading their contents.

False

Returns:

Type Description
List[JSONObject]

A list of JSON objects aggregated from all matching checkpoint files.

Source code in src/codablellm/core/utils.py
def load_checkpoint_data(prefix: str, delete_on_load: bool = False) -> List[JSONObject]:
    """
    Loads checkpoint data from all checkpoint files matching the given prefix.

    Each checkpoint file is expected to hold a JSON array; the arrays of all files
    are concatenated. If `delete_on_load` is set, the files are removed only after
    every one of them has been read and parsed. If reading or parsing fails, the
    files are left untouched and the error propagates, so no checkpoint data is lost.

    Parameters:
        prefix: The filename prefix used to locate checkpoint files.
        delete_on_load: If `True`, deletes the checkpoint files after loading their contents.

    Returns:
        A list of JSON objects aggregated from all matching checkpoint files.
    """
    checkpoint_data: List[JSONObject] = []
    checkpoint_files = get_checkpoint_files(prefix)
    for checkpoint_file in checkpoint_files:
        logger.debug(f'Loading checkpoint data from "{checkpoint_file.name}"')
        checkpoint_data.extend(json.loads(checkpoint_file.read_text()))
    # Deleting as we go would permanently lose already-loaded files' data if a
    # later file failed to load, so removal happens only after all succeeded.
    if delete_on_load:
        for checkpoint_file in checkpoint_files:
            logger.debug(f'Removing checkpoint file "{checkpoint_file.name}"')
            checkpoint_file.unlink(missing_ok=True)
    return checkpoint_data

requires_extra(extra, feature, module)

Decorator that enforces the presence of an optional dependency (extra) before executing a function.

If the required module is not installed, raises an ExtraNotInstalled error with instructions on how to install the missing extra.

Parameters:

Name Type Description Default
extra str

The name of the extra (e.g., "excel") required for the feature.

required
feature str

A description of the feature that requires the extra.

required
module str

The module name to attempt to import.

required

Returns:

Type Description
Callable[[Callable[..., Any]], Callable[..., Any]]

A decorator that checks for the required extra before calling the function.

Raises:

Type Description
ExtraNotInstalled

If the required module is not installed.

Source code in src/codablellm/core/utils.py
def requires_extra(
    extra: str, feature: str, module: str
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """
    Builds a decorator that checks an optional dependency is importable before each call.

    On every call of the decorated function, `module` is imported. If the import
    fails, an `ExtraNotInstalled` error naming the missing extra and its pip
    install command is raised, chained to the original `ImportError`.

    Parameters:
        extra: The name of the extra (e.g., "excel") required for the feature.
        feature: A description of the feature that requires the extra.
        module: The module name to attempt to import.

    Returns:
        A decorator that wraps a function with the import check.

    Raises:
        ExtraNotInstalled: If the required module is not installed (raised by the wrapped function).
    """
    message = (
        f'{feature} requires the "{extra}" extra to be installed. '
        f'Install with "pip install codablellm[{extra}]"'
    )

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(func)
        def guarded(*args: Any, **kwargs: Any) -> Any:
            try:
                importlib.import_module(module)
            except ImportError as exc:
                raise ExtraNotInstalled(extra, message) from exc
            else:
                return func(*args, **kwargs)

        return guarded

    return decorator

resolve_kwargs(**kwargs)

Filters out keyword arguments with None values.

Returns a dictionary containing only key-value pairs where the value is not None.

Parameters:

Name Type Description Default
**kwargs Any

Arbitrary keyword arguments.

{}

Returns:

Type Description
Dict[str, Any]

A dictionary of keyword arguments with None values removed.

Source code in src/codablellm/core/utils.py
def resolve_kwargs(**kwargs: Any) -> Dict[str, Any]:
    """
    Drops every keyword argument whose value is `None`.

    Other falsy values (`0`, `""`, `False`, empty containers) are kept, and the
    original argument order is preserved.

    Parameters:
        **kwargs: Arbitrary keyword arguments.

    Returns:
        A new dictionary with only the non-`None` keyword arguments.
    """
    resolved: Dict[str, Any] = {}
    for name, value in kwargs.items():
        if value is None:
            continue
        resolved[name] = value
    return resolved

save_checkpoint_file(prefix, contents)

Saves checkpoint data to a file based on the given prefix.

The contents are converted to JSON and written to a checkpoint file named {prefix}_{pid}.json in the system temporary directory.

Parameters:

Name Type Description Default
prefix str

The filename prefix for the checkpoint file.

required
contents Iterable[SupportsJSON]

An iterable of objects that support JSON serialization via to_json().

required
Source code in src/codablellm/core/utils.py
def save_checkpoint_file(prefix: str, contents: Iterable[SupportsJSON]) -> None:
    """
    Writes this process's checkpoint file for `prefix`.

    Every item is serialized with its `to_json()` method, and the resulting list
    is written as a JSON array to `{prefix}_{pid}.json` in the system temporary
    directory, replacing any previous contents.

    Parameters:
        prefix: Leading part of the checkpoint file name.
        contents: The objects to store; each must provide `to_json()`.
    """
    target = get_checkpoint_file(prefix)
    serialized = [item.to_json() for item in contents]
    target.write_text(json.dumps(serialized))