numerical_if_statement

NumericalIfStatementLayer

NumericalIfStatementLayer(
    condition_operator,
    value_to_compare=None,
    result_if_true=None,
    result_if_false=None,
    name=None,
    input_dtype=None,
    output_dtype=None,
    **kwargs
)

Bases: BaseLayer

Performs a numerical if statement on the input tensor, returning a tensor of the same shape as the input tensor.

The condition operator can be one of the following:

- "eq": Equal to
- "neq": Not equal to
- "lt": Less than
- "leq": Less than or equal to
- "gt": Greater than
- "geq": Greater than or equal to

The value to compare must be a float. We will cast the input tensor to a float if it is not already a float.

If the condition is true, the result is the result_if_true value. If the condition is false, the result is the result_if_false value.
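
The element-wise behaviour described above can be illustrated with a minimal pure-Python sketch. It is not the layer's implementation: `numerical_if_statement` and `CONDITION_OPERATORS` are hypothetical names introduced here, the operator strings follow the constructor's parameter list, and a flat list stands in for the input tensor.

```python
import operator
from typing import Callable, Dict, List

# Operator names as listed for condition_operator, mapped to Python comparisons.
# This mapping is an illustrative assumption, not the library's own lookup.
CONDITION_OPERATORS: Dict[str, Callable[[float, float], bool]] = {
    "eq": operator.eq,
    "neq": operator.ne,
    "lt": operator.lt,
    "leq": operator.le,
    "gt": operator.gt,
    "geq": operator.ge,
}


def numerical_if_statement(
    values: List[float],
    condition_operator: str,
    value_to_compare: float,
    result_if_true: float,
    result_if_false: float,
) -> List[float]:
    """Hypothetical helper: apply the if statement to each element."""
    compare = CONDITION_OPERATORS[condition_operator]
    # Cast each element to float before comparing, mirroring the cast that
    # the layer description mentions. The output has the same length as the input.
    return [
        result_if_true if compare(float(v), value_to_compare) else result_if_false
        for v in values
    ]


result = numerical_if_statement([1, 5, 10], "gt", 4.0, 1.0, 0.0)
print(result)  # [0.0, 1.0, 1.0]
```

The real layer operates on tensors and may differ in how it broadcasts and casts; the sketch only shows the selection logic.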

If any of [value_to_compare, result_if_true, result_if_false] is None, the missing values are assumed to be passed in as inputs to the layer, in the order listed. If none of them is None, inputs is expected to be a single tensor.
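
One way to picture the input-ordering rule is the sketch below. It rests on an assumption the text does not state explicitly: that the tensor being tested comes first in the input list, followed by the unset arguments in the documented order. `resolve_if_statement_inputs` is a hypothetical helper, not part of the library.

```python
from typing import Any, Optional, Sequence, Tuple, Union


def resolve_if_statement_inputs(
    inputs: Union[Any, Sequence[Any]],
    value_to_compare: Optional[float] = None,
    result_if_true: Optional[float] = None,
    result_if_false: Optional[float] = None,
) -> Tuple[Any, Any, Any, Any]:
    """Hypothetical helper: pair each unset argument with the next layer input."""
    configured = [value_to_compare, result_if_true, result_if_false]
    if all(value is not None for value in configured):
        # Everything was fixed at construction time, so inputs is the tensor itself.
        return (inputs, value_to_compare, result_if_true, result_if_false)

    # Assumption: the tensor under test is first, then the unset arguments
    # in the order value_to_compare, result_if_true, result_if_false.
    remaining = list(inputs)
    tensor = remaining.pop(0)
    resolved = []
    for value in configured:
        if value is None:
            resolved.append(remaining.pop(0))
        else:
            resolved.append(value)
    return (tensor, resolved[0], resolved[1], resolved[2])


# result_if_true is left unset, so it is taken from the second input.
resolved = resolve_if_statement_inputs(
    [[1.0, 2.0], [9.0, 9.0]], value_to_compare=1.5, result_if_false=0.0
)
print(resolved)  # ([1.0, 2.0], 1.5, [9.0, 9.0], 0.0)
```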

Initialises the NumericalIfStatementLayer layer.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| condition_operator | str | Operator to use in the if statement. Can be one of: "eq" (equal to), "neq" (not equal to), "lt" (less than), "leq" (less than or equal to), "gt" (greater than), "geq" (greater than or equal to). | required |
| value_to_compare | Optional[float] | Float value to compare the input tensor to. If None, we assume it is passed in as an input to the layer. | None |
| result_if_true | Optional[float] | Float value to return if the condition is true. If None, we assume it is passed in as an input to the layer. | None |
| result_if_false | Optional[float] | Float value to return if the condition is false. If None, we assume it is passed in as an input to the layer. | None |
| name | Optional[str] | The name of the layer. | None |

Source code in src/kamae/keras/core/layers/numerical_if_statement.py
def __init__(
    self,
    condition_operator: str,
    value_to_compare: Optional[float] = None,
    result_if_true: Optional[float] = None,
    result_if_false: Optional[float] = None,
    name: Optional[str] = None,
    input_dtype: Optional[str] = None,
    output_dtype: Optional[str] = None,
    **kwargs: Any,
) -> None:
    """
    Initialises the NumericalIfStatementLayer layer.

    :param condition_operator: Operator to use in the if statement. Can be one of:
        - "eq": Equal to
        - "neq": Not equal to
        - "lt": Less than
        - "leq": Less than or equal to
        - "gt": Greater than
        - "geq": Greater than or equal to
    :param value_to_compare: Float value to compare the input tensor to. If None, we
    assume it is passed in as an input to the layer.
    :param result_if_true: Float value to return if the condition is true. If None,
    we assume it is passed in as an input to the layer.
    :param result_if_false: Float value to return if the condition is false. If
    None, we assume it is passed in as an input to the layer.
    :param name: The name of the layer. Defaults to `None`.
    """
    super().__init__(
        name=name, input_dtype=input_dtype, output_dtype=output_dtype, **kwargs
    )
    self.condition_operator = condition_operator
    self.value_to_compare = value_to_compare
    self.result_if_true = result_if_true
    self.result_if_false = result_if_false

compatible_dtypes property

compatible_dtypes

Returns the compatible dtypes of the layer.

Returns:

| Type | Description |
|---|---|
| Optional[List[str]] | The compatible dtypes of the layer. |

get_config

get_config()

Gets the configuration of the NumericalIfStatement layer.

Specifically adds the following to the base configuration:

- condition_operator
- value_to_compare
- result_if_true
- result_if_false

Returns:

| Type | Description |
|---|---|
| Dict[str, Any] | Dictionary of the configuration of the layer. |

Source code in src/kamae/keras/core/layers/numerical_if_statement.py
def get_config(self) -> Dict[str, Any]:
    """
    Gets the configuration of the NumericalIfStatement layer.

    Specifically adds the following to the base configuration:
    - condition_operator
    - value_to_compare
    - result_if_true
    - result_if_false

    :returns: Dictionary of the configuration of the layer.
    """
    config = super().get_config()
    config.update(
        {
            "condition_operator": self.condition_operator,
            "value_to_compare": self.value_to_compare,
            "result_if_true": self.result_if_true,
            "result_if_false": self.result_if_false,
        }
    )
    return config
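
The reason for adding the four constructor arguments to the configuration is that Keras typically rebuilds a layer by passing the dictionary from get_config back to the constructor. The sketch below shows that round trip with stand-in classes: `StandInBaseLayer` and `StandInIfStatementLayer` are hypothetical names and do not reproduce the real BaseLayer, which also handles input_dtype and output_dtype.

```python
from typing import Any, Dict, Optional


class StandInBaseLayer:
    """Stand-in for BaseLayer; the real base class is not reproduced here."""

    def __init__(self, name: Optional[str] = None) -> None:
        self.name = name

    def get_config(self) -> Dict[str, Any]:
        return {"name": self.name}


class StandInIfStatementLayer(StandInBaseLayer):
    """Stand-in that mirrors only the config handling shown above."""

    def __init__(
        self,
        condition_operator: str,
        value_to_compare: Optional[float] = None,
        result_if_true: Optional[float] = None,
        result_if_false: Optional[float] = None,
        name: Optional[str] = None,
    ) -> None:
        super().__init__(name=name)
        self.condition_operator = condition_operator
        self.value_to_compare = value_to_compare
        self.result_if_true = result_if_true
        self.result_if_false = result_if_false

    def get_config(self) -> Dict[str, Any]:
        # Extend the base configuration with the layer-specific arguments.
        config = super().get_config()
        config.update(
            {
                "condition_operator": self.condition_operator,
                "value_to_compare": self.value_to_compare,
                "result_if_true": self.result_if_true,
                "result_if_false": self.result_if_false,
            }
        )
        return config


layer = StandInIfStatementLayer(
    "geq", value_to_compare=0.0, result_if_true=1.0, name="non_negative_flag"
)
config = layer.get_config()

# Rebuild an equivalent layer from the dictionary alone.
clone = StandInIfStatementLayer(**config)
print(clone.get_config() == config)  # True
```

Arguments left as None, such as result_if_false here, are stored as None in the configuration, so the rebuilt layer still expects them as inputs.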