from pathlib import Path
import polars as pl
import rainfallqc
from rainfallqc.qc_frameworks.inbuilt_qc_frameworks import NON_ROWWISE_QC_CHECKS, NON_ROWWISE_QC_CONVERTER
from rainfall_gridder.quality_control.apply_intenseQC_rulebase import apply_intenseQC_rulebase
from rainfall_gridder.quality_control.nearby_rainfall_data_loader import NearbyRainfallDataLoader
from rainfall_gridder.utils import spatial_utils
# Supported data time resolutions mapped to the number of timesteps in one day
# (used to convert `min_n_timesteps` into a minimum overlap in days).
time_res_to_n_time_steps_in_day: dict[str, int] = {"15m": 96, "1h": 24}
[docs]
class QualityController:
    """
    Main quality control running algorithm.

    For each station: loads its rainfall data together with nearby gauges,
    runs the configured QC framework via ``rainfallqc``, applies the IntenseQC
    rulebase to remove flagged rows, and collects per-station summaries.
    Outputs can be saved as parquet and/or returned as polars DataFrames.
    """

    def __init__(
        self,
        rainfall_data: pl.DataFrame,
        rainfall_metadata: pl.DataFrame,
        station_id_col: str,
        station_name_col: str,
        date_time_col: str,
        precipitation_col: str,
        easting_col: str,
        northing_col: str,
        start_date_col: str,
        end_date_col: str,
        input_crs: str,
        min_n_timesteps: int,
        output_dir: str | Path,
        time_res: str,
        smallest_rainfall_amount: int | float,
        min_n_neighbours: int,
        qc_framework: str,
        nearby_rainfall_data_loader_kwargs: dict | None = None,
        verbose: bool = False,
    ):
        """
        Quality control part of gridded workflow.

        Parameters
        ----------
        rainfall_data:
            Rainfall gauge data
        rainfall_metadata:
            Details of rain gauge data
        input_crs:
            Projection of the east_west and north_south cols of the input data
        output_dir:
            Output directory for data files
        min_n_timesteps:
            Minimum number of timesteps needed in rainfall_data to be considered valid
        time_res:
            Resolution of data (i.e. hourly or 15 min denoted: '1h' or '15m')
        smallest_rainfall_amount:
            Smallest measurable rainfall amount
        min_n_neighbours:
            Minimum number of nearby rain gauges allowed for neighbourhood QC checks.
        qc_framework:
            QC framework to run (see rainfallqc.qc_frameworks/inbuilt_qc_frameworks for options or build your own by looking at RainfallQC docs)
        nearby_rainfall_data_loader_kwargs:
            Any additional arguments to override the defaults of the nearby data loader i.e. distance_threshold and n_closest (default is None, treated as {})
        verbose:
            Whether to print progress as algorithm is run (default: False)

        Raises
        ------
        ValueError
            If ``qc_framework``, ``input_crs`` or ``time_res`` is not recognised.
        """
        self.station_id_col = station_id_col
        self.station_name_col = station_name_col
        self.date_time_col = date_time_col
        self.precipitation_col = precipitation_col
        self.easting_col = easting_col
        self.northing_col = northing_col
        self.start_date_col = start_date_col
        self.end_date_col = end_date_col
        self.input_crs = self._validate_input_crs(input_crs)
        self.min_n_timesteps = min_n_timesteps
        self.time_res = self._validate_time_res(time_res)
        self.smallest_rainfall_amount = smallest_rainfall_amount
        self.min_n_neighbours = min_n_neighbours
        self.qc_framework = qc_framework
        # None sentinel instead of a shared mutable `{}` default
        # (mutable-default-argument pitfall); behaviour is unchanged.
        self.nearby_rainfall_data_loader_kwargs = (
            {} if nearby_rainfall_data_loader_kwargs is None else nearby_rainfall_data_loader_kwargs
        )
        # Coerce to Path so the `/` joins in the save_* methods also work
        # when a plain string is passed (the annotation allows `str | Path`).
        self.output_dir = Path(output_dir)
        self.verbose = verbose
        if self.qc_framework == "intenseqc_rulebase_only":
            self.qc_kwargs, self.qc_methods_to_run = self._set_up_intenseqc_framework()
        else:
            raise ValueError(
                f"QC framework: '{self.qc_framework}' not recognised, please select from: 'intenseqc_rulebase_only'"
            )
        # Neighbourhood checks need lat/lon; derive them from easting/northing
        # if the metadata does not already carry them.
        if "latitude" not in rainfall_metadata.columns or "longitude" not in rainfall_metadata.columns:
            self.rainfall_metadata = self._add_latlon_to_rainfall_metadata(rainfall_metadata)
        else:
            self.rainfall_metadata = rainfall_metadata
        self.rainfall_data = rainfall_data
        # Outputs populated by quality_control_data(). Initialised to None so
        # the save_* guards raise their intended RuntimeError rather than an
        # AttributeError when QC has not been run yet.
        self.qcd_data = None
        self.qcd_metadata = None
        self.summary_of_qc = None
        self.qc_rulebase_summary = None

    def _add_latlon_to_rainfall_metadata(self, rainfall_metadata: pl.DataFrame) -> pl.DataFrame:
        """Project easting/northing columns to EPSG:4326 lat/lon columns."""
        return spatial_utils.crs_to_crs(
            rainfall_metadata,
            crs_in=self.input_crs,
            crs_out="EPSG:4326",
            east_west_col_in=self.easting_col,
            north_south_col_in=self.northing_col,
            east_west_col_out="longitude",
            north_south_col_out="latitude",
        )

    def _set_up_intenseqc_framework(self) -> tuple[dict, list]:
        """Build the keyword arguments and method list for the IntenseQC framework."""
        qc_kwargs = {
            "QC2": {"k": 10},
            "shared": {
                "time_res": self.time_res,
                "smallest_measurable_rainfall_amount": self.smallest_rainfall_amount,
                "wet_threshold": 1.0,
                "min_n_neighbours": self.min_n_neighbours,
                "n_neighbours_ignored": 0,
                "accumulation_multiplying_factor": 2.0,
            },
        }
        qc_methods_to_run = ["QC2", "QC10", "QC11", "QC12", "QC13", "QC14", "QC15", "QC17", "QC19", "QC20"]
        return qc_kwargs, qc_methods_to_run

    def _validate_input_crs(self, input_crs: str) -> str:
        """Validate and return the input CRS string (must be 'EPSG:...')."""
        # Raise explicitly rather than `assert`, which is stripped under -O.
        if not input_crs.startswith("EPSG:"):
            raise ValueError(f"Invalid input_crs {input_crs}, needs to begin with 'EPSG:' like 'EPSG:4326'.")
        return input_crs

    def _validate_time_res(self, time_res: str) -> str:
        """Validate and return the time resolution (must be a supported key)."""
        if time_res not in time_res_to_n_time_steps_in_day:
            raise ValueError(
                f"'{time_res}' not in accepted time resolutions for data. Accepted time res: {time_res_to_n_time_steps_in_day.keys()}"
            )
        return time_res

    @classmethod
    def run(
        cls, save_data: bool, return_data: bool, partition_by_columns: list | None = None, **kwargs
    ) -> None | tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame, pl.DataFrame]:
        """
        Run the quality controller and return and/or save the prepared data.

        Parameters
        ----------
        save_data:
            Whether to save data to output directory
        return_data:
            Whether to return dataframes
        partition_by_columns:
            List of columns to partition the parquet files by if saving outputs
        **kwargs:
            Forwarded verbatim to the ``QualityController`` constructor.

        Returns
        -------
        qcd_data:
            Data run through algorithm
        qcd_metadata:
            Metadata of data run through algorithm
        summary_of_qc:
            Per-station summary of QC flags
        qc_rulebase_summary:
            Per-station counts of rows removed by the rulebase
        """
        quality_controller = cls(**kwargs)
        if quality_controller.verbose:
            print("Quality controlling data for gridder")
        quality_controller.quality_control_data()
        if save_data:
            if quality_controller.verbose:
                print(f"Saving data to {quality_controller.output_dir}")
            quality_controller.save_qcd_data(partition_by_columns)
            quality_controller.save_qcd_metadata()
            quality_controller.save_summary_of_qc()
            quality_controller.save_qc_rulebase_summary()
        else:
            if quality_controller.verbose:
                print("Data not saved")
        if return_data:
            return (
                quality_controller.qcd_data,
                quality_controller.qcd_metadata,
                quality_controller.summary_of_qc,
                quality_controller.qc_rulebase_summary,
            )

    def get_nearest_neighbour(self, nearby_rainfall_data_loader, station_id):
        """
        Return the station ID of the closest neighbouring gauge.

        Returns ``False`` (kept for backward compatibility with existing
        callers) when the loader found no neighbours at all.
        """
        if len(nearby_rainfall_data_loader.nearby_rain_gauge_distances) > 0:
            return nearby_rainfall_data_loader.nearby_rain_gauge_distances.sort("distance")[0][
                self.station_id_col
            ].item()
        else:
            if self.verbose:
                print(f"Station ID: {station_id} has no neighbours\n")
            return False

    def quality_control_data(self):
        """
        Run the QC framework and rulebase over every station in the metadata.

        Populates ``self.qcd_data``, ``self.qcd_metadata``,
        ``self.summary_of_qc`` and ``self.qc_rulebase_summary``.
        """
        # preallocate the output lists (filled by position; Nones mark skips)
        unique_station_ids = self.rainfall_metadata[self.station_id_col].unique()
        overall_summary_of_qc = [None] * len(unique_station_ids)
        qcd_data_list = [None] * len(unique_station_ids)
        rulebase_summary = [None] * len(unique_station_ids)
        # begin loop
        for ind, station_id in enumerate(unique_station_ids):
            nearby_rainfall_data_loader = NearbyRainfallDataLoader(
                metadata=self.rainfall_metadata,
                station_id=station_id,
                date_time_col=self.date_time_col,
                precipitation_col=self.precipitation_col,
                station_id_col=self.station_id_col,
                start_date_col=self.start_date_col,
                end_date_col=self.end_date_col,
                min_overlap_days=self.min_n_timesteps / time_res_to_n_time_steps_in_day[self.time_res],
                rainfall_data_source="df",
                rainfall_data_pl=self.rainfall_data,
                time_res=self.time_res,
                **self.nearby_rainfall_data_loader_kwargs,
            )
            # Check if that station actually has any neighbours
            if nearby_rainfall_data_loader.nearest_station_id is None:
                if self.verbose:
                    print(f"Station ID: {nearby_rainfall_data_loader.station_id} has no neighbours\n")
                continue
            nearby_metadata = nearby_rainfall_data_loader.nearby_metadata
            nearby_rainfall_data = nearby_rainfall_data_loader.nearby_rainfall_for_rainfallqc
            # Update shared QC kwargs with latest values from nearby gauge loader
            self.update_shared_qc_kwargs(nearby_rainfall_data_loader)
            # Run QC framework; any failure (too little data, framework error)
            # skips this station rather than aborting the whole run.
            try:
                assert len(nearby_rainfall_data) > self.min_n_timesteps, (
                    f"Data needs at least {self.min_n_timesteps} timesteps"
                )
                qc_result = rainfallqc.apply_qc_framework.run_qc_framework(
                    data=nearby_rainfall_data,
                    qc_framework=self.qc_framework,
                    qc_methods_to_run=self.qc_methods_to_run,
                    qc_kwargs=self.qc_kwargs,
                )
            except Exception as e:
                if self.verbose:
                    print(station_id, e, "\n")
                continue
            # Summarise QC flags into statistics
            qc_summariser = QCSummariser(
                station_id=station_id,
                rainfall_data=nearby_rainfall_data,
                nearby_metadata=nearby_metadata,
                qc_result=qc_result,
                verbose=self.verbose,
            )
            # Apply rulebase
            rule_removed_rows, n_rows_removed = apply_intenseQC_rulebase(
                qc_summariser.all_flags, station_id, time_step=self.time_res
            )
            if self.verbose:
                print(
                    f"Station ID: {station_id}\tA total of {qc_summariser.all_flags['all_flags_by_row'][station_id].count() - rule_removed_rows[station_id].count()} rows were removed"
                )  # some rows may have stayed null
            ## get back into parquet format that fits with Oracle
            rule_removed_rows = rule_removed_rows.select(["time", station_id])  # saves memory
            rule_removed_rows = rule_removed_rows.with_columns(pl.lit(station_id).alias(self.station_id_col))
            rule_removed_rows = rule_removed_rows.rename(
                {station_id: self.precipitation_col, "time": self.date_time_col}
            )
            # Append summaries to lists
            overall_summary_of_qc[ind] = qc_summariser.summary_of_qc
            qcd_data_list[ind] = rule_removed_rows
            rulebase_summary[ind] = n_rows_removed
            if self.verbose:
                print("")
        # Add summaries and qc data to self
        self.qc_rulebase_summary = pl.DataFrame(rulebase_summary)
        self.summary_of_qc = pl.DataFrame(overall_summary_of_qc)
        self.qcd_data = pl.concat([qcd_data for qcd_data in qcd_data_list if qcd_data is not None])
        # Keep only metadata rows for stations that actually produced QC output
        self.qcd_metadata = self.rainfall_metadata.filter(
            pl.col(self.station_id_col)
            .cast(pl.String)
            .is_in(self.summary_of_qc[self.station_id_col].drop_nulls().to_list())
        )

    def save_qcd_data(self, partition_by_columns: list | None = None) -> None:
        """
        Save data that has been quality controlled for gridding.

        Parameters
        ----------
        partition_by_columns:
            Columns that decide the partitioning of the output parquet file structure (default is station_id_col)
        """
        if partition_by_columns is None:
            partition_by_columns = [self.station_id_col]
        if self.qcd_data is None:
            raise RuntimeError("You must call quality_control_data() before save_qcd_data()")
        # Raise explicitly rather than `assert`, which is stripped under -O.
        if len(self.qcd_metadata.filter(pl.col("file_path").is_duplicated())) != 0:
            raise ValueError("Problem with metadata as duplicate filepaths")
        # Save partitioned parquet file
        self.qcd_data.sort(self.date_time_col).write_parquet(
            self.output_dir / "qc_data",
            partition_by=partition_by_columns,
        )
        if self.verbose:
            print(f"QC'd rainfall data available at: {self.output_dir / 'qc_data/'}")

    def save_qcd_metadata(self) -> None:
        """
        Save metadata of the stations that passed quality control.

        NOTE(review): this method was called by ``run()`` but missing from the
        class; added to mirror the other save_* methods.
        """
        if self.qcd_metadata is None:
            raise RuntimeError("You must call quality_control_data() before save_qcd_metadata()")
        self.qcd_metadata.write_parquet(self.output_dir / "qcd_metadata.parquet")
        if self.verbose:
            print(f"QC'd metadata available at: {self.output_dir / 'qcd_metadata.parquet'}")

    def save_summary_of_qc(self) -> None:
        """Save the per-station summary of QC flags as parquet."""
        if self.summary_of_qc is None:
            raise RuntimeError("You must call quality_control_data() before save_summary_of_qc()")
        self.summary_of_qc.write_parquet(self.output_dir / "summary_of_qc.parquet")
        if self.verbose:
            print(f"Summary of QC available at: {self.output_dir / 'summary_of_qc.parquet'}")

    def save_qc_rulebase_summary(self) -> None:
        """Save the per-station rulebase removal counts as parquet."""
        if self.qc_rulebase_summary is None:
            raise RuntimeError("You must call quality_control_data() before save_qc_rulebase_summary()")
        self.qc_rulebase_summary.write_parquet(self.output_dir / "qc_rulebase_summary.parquet")
        if self.verbose:
            print(f"Summary of QC rulebase available at: {self.output_dir / 'qc_rulebase_summary.parquet'}")

    def update_shared_qc_kwargs(self, nearby_rainfall_data_loader: NearbyRainfallDataLoader) -> None:
        """
        Update all the shared keyword arguments.

        Mutates ``self.qc_kwargs["shared"]`` in place with the target gauge,
        its neighbours and its lat/lon, taken from the current loader.
        TODO: Check this updating in the loop properly.
        """
        self.qc_kwargs["shared"]["rain_col"] = nearby_rainfall_data_loader.station_id
        self.qc_kwargs["shared"]["target_gauge_col"] = nearby_rainfall_data_loader.station_id
        self.qc_kwargs["shared"]["nearest_neighbour"] = nearby_rainfall_data_loader.nearest_station_id
        self.qc_kwargs["shared"]["list_of_nearest_stations"] = nearby_rainfall_data_loader.nearby_rain_gauge_distances[
            self.station_id_col
        ].to_list()
        self.qc_kwargs["shared"]["gauge_lat"] = nearby_rainfall_data_loader.nearby_metadata.filter(
            pl.col(self.station_id_col) == nearby_rainfall_data_loader.station_id
        )["latitude"]
        self.qc_kwargs["shared"]["gauge_lon"] = nearby_rainfall_data_loader.nearby_metadata.filter(
            pl.col(self.station_id_col) == nearby_rainfall_data_loader.station_id
        )["longitude"]
class QCSummariser:
    """
    Summariser for QC flags.

    Joins the row-wise QC result DataFrames onto the rainfall data by time,
    derives an ``is_flagged`` indicator per row, and builds a per-station
    summary dict of flag counts and non-row-wise check results.
    """

    def __init__(
        self,
        station_id: str,
        rainfall_data: pl.DataFrame,
        nearby_metadata: pl.DataFrame,
        qc_result: dict,
        verbose: bool,
    ):
        """
        Summarises QC from a given station.

        Parameters
        ----------
        station_id:
            Target rainfall station ID
        rainfall_data:
            Rainfall data from target ID
        nearby_metadata:
            Details about rainfall data from target and a given no. of neighbours
        qc_result:
            Summary of QC output from rainfallqc.apply_qc_framework.run_qc_framework.
        verbose:
            Whether to print progress as algorithm is run (default: False)
        """
        self.all_flags = {}
        self.station_id = station_id
        self.rainfall_data = rainfall_data
        self.nearby_metadata = nearby_metadata
        self.qc_result = qc_result
        self.verbose = verbose
        # Join QC checks into one dict
        self._join_qc_result_into_all_flags()
        # Get count of flagged and not-flagged rows
        flagged_rows, not_flagged_rows = self._count_flagged_and_non_flagged_rows()
        # Create summary
        self.all_flags, self.summary_of_qc = self._create_summary_of_qc(flagged_rows, not_flagged_rows=not_flagged_rows)

    def _join_qc_result_into_all_flags(self):
        """
        Merge every row-wise (DataFrame) QC result onto the rainfall data.

        Row-wise results are joined on "time" into
        ``all_flags["all_flags_by_row"]``; non-DataFrame results are stored
        under their own QC key. Failed joins are skipped (logged if verbose).
        """
        self.all_flags["all_flags_by_row"] = self.rainfall_data["time", self.station_id]
        for qc in self.qc_result:
            if isinstance(self.qc_result[qc], pl.DataFrame):
                try:
                    self.all_flags["all_flags_by_row"] = self.all_flags["all_flags_by_row"].join(
                        self.qc_result[qc], on="time"
                    )
                except Exception as e:
                    if self.verbose:
                        print(e, self.station_id)
            else:
                self.all_flags[qc] = self.qc_result[qc]

    def _count_flagged_and_non_flagged_rows(self) -> tuple[pl.Series, pl.Series]:
        """
        Count number of rows of rainfall data with flags or not.

        A row is flagged when any QC column (nulls treated as 0) is positive.
        Either returned Series may be empty when no row falls in that class.
        """
        # Vectorised comparison replaces the original per-element
        # `map_elements(lambda col: col > 0)` call — equivalent result,
        # without a Python callback per value.
        self.all_flags["all_flags_by_row"] = self.all_flags["all_flags_by_row"].with_columns(
            pl.when(
                pl.any_horizontal(
                    pl.all().exclude(["time", self.station_id]).fill_null(0.0) > 0
                )
            )
            .then(1)
            .otherwise(0)
            .alias("is_flagged")
        )
        # Count flagged / unflagged rows (value_counts computed once).
        flag_counts = self.all_flags["all_flags_by_row"]["is_flagged"].value_counts()
        flagged_rows = flag_counts.filter(pl.col("is_flagged") == 1)["count"]
        not_flagged_rows = flag_counts.filter(pl.col("is_flagged") == 0)["count"]
        return flagged_rows, not_flagged_rows

    def _create_summary_of_qc(self, flagged_rows: pl.Series, not_flagged_rows: pl.Series):
        """
        Create summary of the rows flagged in rainfall data.

        Returns the (possibly updated) ``all_flags`` dict and a summary dict
        with flag rate, per-check row counts and non-row-wise check results.
        """
        total_rows = flagged_rows + not_flagged_rows
        perc_flagged = (flagged_rows / total_rows) * 100
        # An empty series here means no rows fell in one class; report 0%.
        perc_flagged = perc_flagged.item() if perc_flagged.len() == 1 else 0
        if self.verbose:
            print(f"Station ID: {self.station_id}\t\tFlag rate: {perc_flagged: .2f}%")
        # add to overall QC summary
        summary_of_qc = {}
        summary_of_qc["station_id"] = self.station_id
        summary_of_qc["num_nearby_gauges"] = len(self.nearby_metadata) - 1  # do not count the target
        summary_of_qc["perc_flagged"] = round(perc_flagged, 3)
        summary_of_qc["total_flagged_rows"] = flagged_rows[0] if flagged_rows.len() > 0 else 0
        summary_of_qc["total_rows"] = len(self.all_flags["all_flags_by_row"])
        # Non-row-wise checks (e.g. yearly checks) are stored outside the
        # joined table; translate their keys via the converter mapping.
        for qc_key in NON_ROWWISE_QC_CHECKS:
            if qc_key not in self.all_flags:
                continue
            if isinstance(self.all_flags[qc_key], list):
                # sum number of years flagged
                summary_of_qc[NON_ROWWISE_QC_CONVERTER[qc_key]] = sum(item != 0 for item in self.all_flags[qc_key])
            else:
                summary_of_qc[NON_ROWWISE_QC_CONVERTER[qc_key]] = self.all_flags[qc_key]
        # Per-check flagged row counts (skip the leading time + rain columns).
        for col in self.all_flags["all_flags_by_row"].columns[2:]:
            summary_of_qc[col] = len(self.all_flags["all_flags_by_row"].filter(pl.col(col) > 0).drop_nans()[col])
        return self.all_flags, summary_of_qc