"""
Copyright (c) Entropica Labs Pte Ltd 2025.
Use, distribution and reproduction of this program in its source or compiled
form is prohibited without the express written consent of Entropica Labs Pte
Ltd.
"""
from __future__ import annotations
from functools import cached_property
import textwrap
from uuid import uuid4
from typing import ClassVar
from pydantic import Field, field_validator
from pydantic.dataclasses import dataclass
from .circuit import Circuit
from .channel import Channel, ChannelType
from .utilities import BoolOp
from .utilities.validation_tools import dataclass_config
# pylint: disable=arguments-differ, arguments-renamed, unnecessary-lambda-assignment
[docs]
@dataclass(config=dataclass_config)
class IfElseCircuit(Circuit):
"""
Branching circuit: executes if_circuit or else_circuit depending on some classical
condition circuit.
"""
# Marker to allow other modules to detect IfElseCircuit instances without
# importing the class (prevents circular imports). This is a ClassVar so
# pydantic dataclasses won't treat it as a field.
_loom_ifelse_marker: ClassVar[bool] = True
# Define class-specific fields for IfElseCircuit
if_circuit: Circuit | None = Field(
default_factory=lambda: Circuit("empty_branch"), validate_default=True
)
else_circuit: Circuit | None = Field(
default_factory=lambda: Circuit("empty_branch"), validate_default=True
)
condition_circuit: Circuit | None = Field(
default_factory=lambda: Circuit(
name=BoolOp.MATCH,
channels=[Channel(type=ChannelType.CLASSICAL)],
),
validate_default=True,
)
id: str = Field(default_factory=lambda: str(uuid4()))
# Override fields from Circuit
name: str = Field(default="if-else_circuit", init=False, frozen=True)
circuit: tuple[tuple[Circuit, ...], ...] = Field(default_factory=tuple, init=False)
channels: tuple[Channel, ...] = Field(default_factory=tuple, init=False)
# Validation functions
[docs]
@field_validator("condition_circuit")
@classmethod
def validate_condition_circuit(cls, value: Circuit) -> Circuit:
"""
Provide a default condition circuit if none is provided. Otherwise, check if the
condition circuit is classical. Throw an error if that is not the case.
"""
if value is None:
return Circuit(
BoolOp.MATCH,
channels=[Channel(type=ChannelType.CLASSICAL)],
)
if not all(channel.is_classical() for channel in value.channels):
raise ValueError(
"IfElseCircuit `condition_circuit` must be a circuit with classical "
f"channels only. Found channels: {value.channels}"
)
if value.name in BoolOp.multi_bit_list():
if len(value.channels) < 2:
raise ValueError(
f"Condition circuit with BoolOp '{value.name}' must have at least "
"two classical channels."
)
elif value.name in BoolOp.mono_bit_list():
if len(value.channels) != 1:
raise ValueError(
f"Condition circuit with BoolOp '{value.name}' must have only one "
"classical channel."
)
else:
raise ValueError(
f"Unsupported BoolOp '{value.name}' for condition circuit. Supported "
"BoolOps are: "
f"{', '.join(BoolOp.multi_bit_list() + BoolOp.mono_bit_list())}."
)
return value
[docs]
@field_validator("if_circuit", "else_circuit")
@classmethod
def validate_circuit_branches(cls, circuit: Circuit) -> Circuit:
"""
Assign default empty Circuit if None, and wrap base gates into Circuit if
needed.
"""
if circuit is None or circuit.name == "empty_branch":
return Circuit(name="empty_branch")
if not circuit.circuit:
return Circuit(name=circuit.name, circuit=circuit)
return circuit
def __post_init__(self):
"""
Post-initialization to set derived fields. As Circuit objects are immutable,
we use object.__setattr__ to set these fields. Additionally, we do not perform
validation again here as the fields are derived from already validated fields.
"""
# Format circuit field as ((if_circuit, else_circuit),)
object.__setattr__(self, "circuit", ((self.if_circuit, self.else_circuit),))
# Gather all unique channels from both branches and prepend condition channels
# Note that the order of channels is:
# condition channels, quantum channels, classical channels
branch_channels = set(self.if_circuit.channels) | set(
self.else_circuit.channels
)
typing_order = (
ChannelType.QUANTUM,
ChannelType.CLASSICAL,
)
ordered_branch_channels = tuple(
sorted(branch_channels, key=lambda x: typing_order.index(x.type))
)
all_channels = self.condition_circuit.channels + ordered_branch_channels
object.__setattr__(self, "channels", all_channels)
# Set duration to max of branches
object.__setattr__(
self, "duration", max(self.if_circuit.duration, self.else_circuit.duration)
)
# Inherited parent methods
# def nr_of_qubits_in_circuit(self) -> int:
# """This method is inherited from Circuit"""
# def from_circuits(cls):
# """This method is inherited from Circuit"""
# Override Parent Methods
[docs]
@classmethod
def as_gate(cls):
"""Represent IfElseCircuit as a gate."""
raise NotImplementedError("IfElseCircuit cannot be represented as a gate.")
[docs]
def circuit_seq(self):
"""
Returns the sequence of sub-circuits in the circuit field.
"""
raise NotImplementedError(
"IfElseCircuit cannot be converted into a Circuit sequence."
)
[docs]
def flatten(self) -> IfElseCircuit:
"""
Flatten the IfElseCircuit by flattening its branches and condition circuit.
"""
flat_if = self.if_circuit.flatten()
flat_else = self.else_circuit.flatten()
flat_condition = self.condition_circuit.flatten()
return IfElseCircuit(
if_circuit=flat_if,
else_circuit=flat_else,
condition_circuit=flat_condition,
)
[docs]
@classmethod
def unroll(cls, input_circuit: IfElseCircuit) -> tuple[IfElseCircuit, ...]:
unrolled_if = input_circuit.if_circuit.unroll(input_circuit.if_circuit)
unrolled_else = input_circuit.else_circuit.unroll(input_circuit.else_circuit)
wrapped_unrolled_if = Circuit(
name=input_circuit.if_circuit.name, circuit=unrolled_if
)
wrapped_unrolled_else = Circuit(
name=input_circuit.else_circuit.name, circuit=unrolled_else
)
return (
IfElseCircuit(
if_circuit=wrapped_unrolled_if,
else_circuit=wrapped_unrolled_else,
condition_circuit=input_circuit.condition_circuit,
),
)
@cached_property
def is_condition_single_bit(self) -> bool:
"""Check if the condition circuit is a single-bit condition."""
return self.condition_circuit.name in BoolOp.mono_bit_list()
@cached_property
def is_single_gate_conditioned(self) -> bool:
"""Whether this is just a single gate conditioned by a classical condition."""
return (
len(self.if_circuit.circuit) == 1
and len(self.if_circuit.circuit[0]) == 1
and not self.if_circuit.circuit[0][0].circuit
and self.else_circuit.name == "empty_branch"
)
def __eq__(self, other: IfElseCircuit) -> bool:
"""Check equality between two IfElseCircuit instances."""
if not isinstance(other, IfElseCircuit):
return False
return (
self.if_circuit == other.if_circuit
and self.else_circuit == other.else_circuit
and self.condition_circuit == other.condition_circuit
)
def __repr__(self) -> str:
"""Return a concise string representation of the circuit."""
return (
f"{self.name}\n"
f" if: {self.if_circuit.name}\n"
f" else: {self.else_circuit.name}\n"
f" condition: {self.condition_circuit.name}"
)
[docs]
@staticmethod
def construct_padded_circuit_time_sequence():
"""Construct a padded circuit time sequence."""
raise NotImplementedError(
"IfElseCircuit cannot construct a padded circuit time sequence."
)
[docs]
def detailed_str(self) -> str:
"""Return a detailed string representation of the circuit."""
_skip_firstline = lambda s: s.detailed_str().splitlines()[1:]
if_str = textwrap.indent("\n".join(_skip_firstline(self.if_circuit)), " ")
else_str = textwrap.indent(
"\n".join(_skip_firstline(self.else_circuit)), " "
)
return (
f"{self.name}\n"
f" if: {self.if_circuit.name}\n"
f"{if_str}\n"
f" else: {self.else_circuit.name}\n"
f"{else_str}\n"
f" condition: {self.condition_circuit.name}\n"
)