
shape_utils

Multi-backend shape utility functions for backend-agnostic operations.

reshape_to_equal_rank

reshape_to_equal_rank(inputs)

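Reshapes the input tensors so that each one has the rank of the highest-rank input. Lower-rank tensors gain size-1 axes inserted immediately before their last axis. For example, with a maximum rank of 3, a tensor of shape (batch, 3) becomes (batch, 1, 3). Tensors already at the maximum rank are returned unchanged.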

This is a backend-agnostic version using keras.ops.
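The shape arithmetic can be sketched in plain Python on shape tuples, with no tensors involved. `padded_shape` is a hypothetical helper that mirrors the concatenation in the source code further down (all axes except the last, then `rank_diff` ones, then the last axis). It assumes fully static shapes.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def padded_shape(shape: Shape, max_rank: int) -> Shape:
    """Hypothetical helper: mirror reshape_to_equal_rank's shape logic on a static shape."""
    rank_diff = max_rank - len(shape)
    if rank_diff <= 0:
        # Already at the maximum rank: left unchanged.
        return shape
    # Leading axes, then rank_diff singleton axes, then the last axis.
    return shape[:-1] + (1,) * rank_diff + shape[-1:]


shapes = [(32, 3), (32, 10, 3), (32,)]
max_rank = max(len(s) for s in shapes)
print([padded_shape(s, max_rank) for s in shapes])
# [(32, 1, 3), (32, 10, 3), (1, 1, 32)]
```

A rank-1 input is treated as consisting only of a last axis, so `(32,)` becomes `(1, 1, 32)` rather than `(32, 1, 1)`.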

Parameters:

Name     Type                    Description                                          Default
inputs   Iterable[KerasTensor]   The input tensors to reshape. Must be non-empty.     required
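`inputs` is typed as `Iterable[KerasTensor]`, so a caller may pass a one-shot iterator such as a generator. The function needs two passes over its input (one to find the maximum rank, one to reshape), and a generator is exhausted after the first pass. The pure-Python sketch below shows the pitfall on shape tuples; `max_rank_then_collect`, `max_rank_then_collect_safe` and `shape_stream` are hypothetical names used only for illustration.

```python
from typing import Iterable, Iterator, List, Tuple

Shape = Tuple[int, ...]


def max_rank_then_collect(shapes: Iterable[Shape]) -> Tuple[int, List[Shape]]:
    """Hypothetical: walk `shapes` twice without materialising it (the pitfall)."""
    max_rank = max(len(s) for s in shapes)  # first pass consumes a generator
    collected = [s for s in shapes]  # second pass: empty for a generator
    return max_rank, collected


def max_rank_then_collect_safe(shapes: Iterable[Shape]) -> Tuple[int, List[Shape]]:
    """Hypothetical: materialise the iterable once, then walk the list freely."""
    materialised = list(shapes)
    max_rank = max(len(s) for s in materialised)
    return max_rank, list(materialised)


def shape_stream() -> Iterator[Shape]:
    """Hypothetical one-shot source of shapes."""
    yield (4, 3)
    yield (4, 2, 3)


print(max_rank_then_collect(shape_stream()))       # (3, [])
print(max_rank_then_collect_safe(shape_stream()))  # (3, [(4, 3), (4, 2, 3)])
```

The unsafe version silently returns an empty list instead of raising, which is why materialising the input (or passing a list or tuple) matters.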

Returns:

Type                 Description
List[KerasTensor]    The input tensors, in their original order, each reshaped to the maximum input rank.

Source code in src/kamae/keras/core/utils/shape_utils.py
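from typing import Iterable, List

from keras import KerasTensor, ops


def reshape_to_equal_rank(inputs: Iterable[KerasTensor]) -> List[KerasTensor]:
    """
    Reshapes the input tensors to match the rank of the highest-rank tensor.

    Size-1 axes are inserted just before the last axis of each lower-rank tensor.
    This is a backend-agnostic version using keras.ops.

    :param inputs: The input tensors to reshape (must be non-empty).
    :return: The reshaped input tensors, in input order.
    """
    # Materialise once: the loop below is a second pass, which would see an
    # exhausted iterator if `inputs` were a generator.
    inputs = list(inputs)
    max_rank = max(len(tensor.shape) for tensor in inputs)
    reshaped_inputs = []
    for x in inputs:
        rank_diff = max_rank - len(x.shape)
        if rank_diff > 0:
            # Get shape as tensor (handles both static and dynamic shapes)
            shape_tensor = ops.convert_to_tensor(ops.shape(x))
            # New shape: all leading axes, rank_diff ones, then the last axis.
            reshape_dim = ops.concatenate(
                [
                    shape_tensor[:-1],
                    ops.ones(rank_diff, dtype="int32"),
                    shape_tensor[-1:],
                ],
                axis=0,
            )
            x = ops.reshape(x, reshape_dim)
        reshaped_inputs.append(x)
    return reshaped_inputs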
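Inserting the singleton axes before the last axis, rather than at the front, matters once the tensors are combined. Assuming NumPy-style broadcasting (which `keras.ops` arithmetic follows), shapes are right-aligned and missing axes are padded on the left, so a `(batch, features)` tensor would line its batch axis up against the steps axis of a `(batch, steps, features)` tensor. The pure-Python sketch below checks this on static shapes; `broadcast_shape` is a hypothetical helper, and reading the first axis as batch and the last as features is an assumption about typical use.

```python
from itertools import zip_longest
from typing import List, Tuple

Shape = Tuple[int, ...]


def broadcast_shape(a: Shape, b: Shape) -> Shape:
    """Hypothetical: NumPy-style broadcast of two static shapes."""
    result: List[int] = []
    # Walk axes from the right; missing leading axes count as size 1.
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and 1 not in (x, y):
            raise ValueError(f"cannot broadcast {a} with {b}")
        # A size-1 axis stretches to the other size.
        result.append(y if x == 1 else x)
    return tuple(reversed(result))


features = (8, 3)     # (batch, features)
sequence = (8, 4, 3)  # (batch, steps, features)

# Left padding treats (8, 3) as (1, 8, 3), so batch (8) meets steps (4).
try:
    broadcast_shape(features, sequence)
except ValueError as err:
    print(err)  # cannot broadcast (8, 3) with (8, 4, 3)

# Padding before the last axis keeps batch aligned with batch.
print(broadcast_shape((8, 1, 3), sequence))  # (8, 4, 3)
```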