"""
Copyright 2024 Entropica Labs Pte Ltd
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from .pauli_frame import PauliFrame
from .classicalreg import ClassicalRegister
[docs]
class DataStore:
"""
An Object that records relevant information users might want recorded from
a particular computation.
"""
def __init__(self):
self.measurements = {"time_step": []}
self.pf_records = {"forward": {"time_step": []}, "backward": {"time_step": []}}
self.cr_records = {"time_step": []}
def __str__(self):
return str(self.measurements)
[docs]
def set_time_step(self, time_step: int) -> None:
"""
Sets the time step of the current Moment being processed.
"""
self.time_step = time_step # pylint: disable=attribute-defined-outside-init
[docs]
def record_measurements(
self,
measurement_id: str,
measurement_result: int,
is_result_random: int,
) -> None:
"""
Records the measurement result in the DataStore. Raises a ValueError if
no time step was found in the Moment.
"""
if self.time_step is not None:
meas_result = {
"measurement_result": measurement_result,
"is_random": is_result_random,
}
# If the measurements of a particular time step has not been
# recorded. Record the current measurement results in that
# time step.
if self.time_step not in self.measurements["time_step"]:
self.measurements["time_step"].append(self.time_step)
self.measurements.update({str(self.time_step): {}})
self.measurements[str(self.time_step)].update({measurement_id: meas_result})
else:
raise ValueError("The time step of the Moment cannot be None.")
[docs]
def record_pauli_frame(
self,
init_pauliframe: PauliFrame,
recorded_pauliframe: PauliFrame,
) -> None:
"""
Records the PauliFrame in the DataStore. Raises a ValueError if
no time step was found in the Moment.
"""
if self.time_step is not None:
direction = init_pauliframe.direction
pf_id = init_pauliframe.id
record_pf = {
"initial_pauli_frame": init_pauliframe,
"recorded_pauli_frame": recorded_pauliframe,
}
if self.time_step not in self.pf_records[direction]["time_step"]:
self.pf_records[direction]["time_step"].append(self.time_step)
self.pf_records[direction].update({str(self.time_step): {}})
self.pf_records[direction][str(self.time_step)].update({pf_id: record_pf})
else:
raise ValueError("The time step of the Moment cannot be None.")
[docs]
def record_classical_register(self, input_classical_reg: ClassicalRegister) -> None:
"""
Records the Snapshot of the Classical Register into the DataStore.
"""
if self.time_step is not None:
cr_snapshot = {
input_classical_reg.name: input_classical_reg.create_snapshot()
}
# Records the state of a classical register at a particular time step.
if self.time_step not in self.cr_records["time_step"]:
self.cr_records["time_step"].append(self.time_step)
self.cr_records.update({str(self.time_step): {}})
self.cr_records[str(self.time_step)].update(cr_snapshot)
else:
raise ValueError("The time step of the Moment cannot be None.")
[docs]
def record_measurement_from_pauliframes(
self,
measurement_id: str,
results: list[int],
init_pauliframes: list[PauliFrame],
) -> None:
"""
Records measurement flips that result from the interaction of
Pauli frames and measurement operations.
Measurement flips is passed as a list as there could be multiple Pframes.
"""
if self.time_step is not None:
# first we check whether "time_step" is already written into datastore
# if not, then we update the dict.
if self.time_step not in self.measurements["time_step"]:
self.measurements["time_step"].append(self.time_step)
self.measurements.update({str(self.time_step): {}})
self.measurements[str(self.time_step)].update(
{
measurement_id: {
"flip_results": {},
}
}
)
# if "time_step" already is present, then we open a new dict keyed by
# "flip_results" for particular measurement_id. If "measurement_id" has not
# been previously processed update with key.
elif measurement_id not in self.measurements[str(self.time_step)]:
self.measurements[str(self.time_step)].update(
{
measurement_id: {
"flip_results": {},
}
}
)
elif (
"flip_results"
not in self.measurements[str(self.time_step)][measurement_id]
):
self.measurements[str(self.time_step)][measurement_id].update(
{"flip_results": {}}
)
for pframe, result in zip(init_pauliframes, results, strict=True):
self.measurements[str(self.time_step)][measurement_id][
"flip_results"
].update({pframe.id: result})
else:
raise ValueError("The time step of the Moment cannot be None.")
[docs]
def get_pframes(
self,
temporal_direction: str = "forward",
) -> tuple[dict]:
"""
A helper function to return a tuple of all Pframes (dict)
`temporal_direction` can be either "forward" or "backward".
returns:
tuple: A tuple of dictionaries, where each dictionary
corresponds to a PauliFrame that was created in the circuit,
and the computed output --- for frames propagated forwards.
"""
raw_data = self.pf_records[temporal_direction]
raw_data_keys = raw_data["time_step"]
return tuple(fr for key in raw_data_keys for fr in raw_data[str(key)].values())
[docs]
def get_pframe_measurements(
self,
pauliframe_id: str,
) -> dict:
"""
Helper function to aggregate all measurement results
belonging to the same Pframe. The keys of returned dict
are the uuids of meausurements in the circuit.
returns:
dict: A dictionary where keys are measurement ids and values are
the PauliFrame flip results for the given Pframe id.
"""
result_data = {}
raw_data_timekeys = self.measurements["time_step"]
for ts in raw_data_timekeys:
for m_id in self.measurements[str(ts)].keys():
measurement_entry = self.measurements[str(ts)][m_id]
if "flip_results" in measurement_entry:
for pf_id_sel, m_val in measurement_entry["flip_results"].items():
if pf_id_sel == pauliframe_id:
result_data.update({m_id: m_val})
return result_data