Source code for rainfall_gridder.prepare_data.gauge_grid_correlator

from pathlib import Path
import polars as pl
import scipy.stats
import xarray as xr

from rainfall_gridder.prepare_data.data_combiner import GaugeVsGriddedRainfallMatcher
from rainfall_gridder.utils import spatial_utils, xarray_utils


class GaugeVsGriddedCorrelator:
    def __init__(
        self,
        gauge_data: pl.DataFrame,
        gauge_metadata: pl.DataFrame,
        nearest_gridded_daily: xr.Dataset,
        station_id: str,
        precipitation_col: str,
        gridded_rainfall_col: str,
        date_time_col: str,
        start_date_col: str,
        end_date_col: str,
        station_id_col: str,
        easting_col: str,
        northing_col: str,
        rainfall_offset_hours: int,
        aggregate_gauge_to_daily: bool = True,
    ):
        """
        Correlate rain gauge data with the nearest gridded data.

        TODO: make sure the combining of gauge name is done in order i.e. 1-2 not 2-1

        Parameters
        ----------
        gauge_data:
            Rainfall gauge data
        gauge_metadata:
            Details of rain gauge data
        """
        # Filter to the single station ID
        self.gauge_data = gauge_data.filter(pl.col(station_id_col) == station_id).sort(by=date_time_col)
        self.gauge_metadata = gauge_metadata.filter(pl.col(station_id_col) == station_id)
        self.station_id = station_id
        self.precipitation_col = precipitation_col
        self.gridded_rainfall_col = gridded_rainfall_col
        self.date_time_col = date_time_col
        self.start_date_col = start_date_col
        self.end_date_col = end_date_col
        self.station_id_col = station_id_col
        self.easting_col = easting_col
        self.northing_col = northing_col
        self.rainfall_offset_hours = rainfall_offset_hours
        self.nearest_gridded_daily = self._load_daily_nearest_gridded_daily(nearest_gridded_daily)
        if aggregate_gauge_to_daily:
            self.gauge_data = self._aggregate_gauge_subdaily_to_daily()
        self.combined_data = self._join_gauge_to_grid()

    def _aggregate_gauge_subdaily_to_daily(self) -> pl.DataFrame:
        return (
            self.gauge_data.drop_nulls()
            .group_by_dynamic(
                self.date_time_col,
                every="1d",
                offset=f"{self.rainfall_offset_hours}h",
                label="left",
            )
            .agg(pl.col(self.precipitation_col).sum())
        )

    def _load_daily_nearest_gridded_daily(self, nearest_gridded_daily: xr.Dataset) -> xr.Dataset:
        """
        Load daily gridded data.

        TODO: This method needs some work, because what if the day doesn't start at 00:00.
        TODO: also need to assert the data is actually daily
        """
        nearest_gridded_daily = xarray_utils.replace_daily_time_step_hour_with_zero(
            nearest_gridded_daily, time_col="time"
        )
        nearest_gridded_daily = self._subset_gridded_data_to_start_and_end_of_gauge(nearest_gridded_daily)
        return spatial_utils.get_nearest_grid_cell(
            nearest_gridded_daily,
            easting=self.gauge_metadata[self.easting_col][0],
            northing=self.gauge_metadata[self.northing_col][0],
        )

    def _join_gauge_to_grid(self):
        s_date = self.gauge_metadata[self.start_date_col][0]
        e_date = self.gauge_metadata[self.end_date_col][0]
        gauge_gridded_matcher = GaugeVsGriddedRainfallMatcher(
            [self.precipitation_col],
            output_col_name="",
            rainfall_offset_hours=self.rainfall_offset_hours,
            date_time_col=self.date_time_col,
        )
        nearest_gridded_daily_cell_df = gauge_gridded_matcher.prepare_gridded_daily(
            self.nearest_gridded_daily,
            s_date=s_date,
            e_date=e_date,
            rain_col=self.gridded_rainfall_col,
        )
        combined_gauge_gridded = gauge_gridded_matcher.join_daily_gauge_and_gridded(
            self.gauge_data, nearest_gridded_daily_cell_df
        )
        return combined_gauge_gridded

    def _subset_gridded_data_to_start_and_end_of_gauge(self, nearest_gridded_daily: xr.Dataset) -> xr.Dataset:
        """
        Clip gridded data so it only extends between the start and end dates of the gauge data.

        Parameters
        ----------
        nearest_gridded_daily:
            Nearest grid cell of daily rainfall

        Returns
        -------
        nearest_gridded_daily_clipped:
            Nearest grid cell of daily rainfall clipped to the start and end datetime of the gauge

        Raises
        ------
        ValueError:
            If there is no overlap between gauge data and gridded daily rainfall
        """
        start_datetime = self.gauge_metadata[self.start_date_col][0]
        end_datetime = self.gauge_metadata[self.end_date_col][0]
        nearest_gridded_daily_clipped = nearest_gridded_daily.sel(time=slice(start_datetime, end_datetime))
        if nearest_gridded_daily_clipped["time"].size == 0:
            raise ValueError(
                "No overlap between the daily gridded data and the inputted gauge data. "
                f"Gauge data runs from {start_datetime} to {end_datetime}, "
                f"whereas the gridded data runs from "
                f"{nearest_gridded_daily['time'].min().data} "
                f"to {nearest_gridded_daily['time'].max().data}."
            )
        return nearest_gridded_daily_clipped

    def get_corr(self):
        """Return the Pearson r and Spearman rho between gauge and nearest gridded daily rainfall."""
        r_result = scipy.stats.pearsonr(
            self.combined_data[self.precipitation_col],
            self.combined_data[self.gridded_rainfall_col],
        ).statistic
        rho_result = scipy.stats.spearmanr(
            self.combined_data[self.precipitation_col],
            self.combined_data[self.gridded_rainfall_col],
        ).statistic
        return r_result, rho_result
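
# --- Usage sketch (not part of the library) ----------------------------------
# A minimal illustration of how GaugeVsGriddedCorrelator might be driven for a
# single station. The file paths, station ID and column names ("rain_mm",
# "gridded_rain", "datetime", "start_date", "end_date", "station_id",
# "easting", "northing") are assumptions made for this example, not names
# defined by this module.
def _example_single_station_correlation():
    gauge_data = pl.read_parquet("gauge_data.parquet")          # hypothetical input file
    gauge_metadata = pl.read_parquet("gauge_metadata.parquet")  # hypothetical input file
    gridded_daily = xr.open_dataset("gridded_daily.nc")         # hypothetical input file

    correlator = GaugeVsGriddedCorrelator(
        gauge_data=gauge_data,
        gauge_metadata=gauge_metadata,
        nearest_gridded_daily=gridded_daily,
        station_id="STN_0001",               # hypothetical station ID
        precipitation_col="rain_mm",         # hypothetical column names below
        gridded_rainfall_col="gridded_rain",
        date_time_col="datetime",
        start_date_col="start_date",
        end_date_col="end_date",
        station_id_col="station_id",
        easting_col="easting",
        northing_col="northing",
        rainfall_offset_hours=9,             # e.g. a 09:00-to-09:00 rain day
        aggregate_gauge_to_daily=True,
    )
    # Pearson r and Spearman rho between daily gauge and nearest grid-cell rainfall.
    return correlator.get_corr()
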
class BatchGaugeVsGriddedCorrelator(GaugeVsGriddedCorrelator):
    """
    Run the correlator on a list of station IDs and return and/or save the correlation-filtered metadata.
    """

    def __init__(
        self,
        gauge_data: pl.DataFrame,
        gauge_metadata: pl.DataFrame,
        gridded_rainfall_data: xr.Dataset,
        gridded_rainfall_col: str,
        station_ids_to_correlate: list,
        station_id_col: str,
        precipitation_col: str,
        date_time_col: str,
        start_date_col: str,
        end_date_col: str,
        easting_col: str,
        northing_col: str,
        rainfall_offset_hours: int,
        output_dir: Path,
        verbose: bool,
        correlation_threshold: float,
        aggregate_gauge_to_daily: bool = True,
    ):
        """
        Correlate rain gauge data with the nearest gridded data for multiple stations.

        TODO: make sure the combining of gauge name is done in order i.e. 1-2 not 2-1

        Parameters
        ----------
        gauge_data:
            Rainfall gauge data
        gauge_metadata:
            Details of rain gauge data
        station_ids_to_correlate:
            List of ids to run through the correlator.
        output_dir:
            Output directory for data files
        verbose:
            Whether to print progress as the algorithm is run
        """
        self.gauge_data = gauge_data
        self.gauge_metadata = gauge_metadata
        self.station_ids_to_correlate = station_ids_to_correlate
        self.gridded_rainfall_data = gridded_rainfall_data
        self.station_id_col = station_id_col
        self.precipitation_col = precipitation_col
        self.gridded_rainfall_col = gridded_rainfall_col
        self.date_time_col = date_time_col
        self.start_date_col = start_date_col
        self.end_date_col = end_date_col
        self.easting_col = easting_col
        self.northing_col = northing_col
        self.rainfall_offset_hours = rainfall_offset_hours
        self.correlation_threshold = correlation_threshold
        self.output_dir = output_dir
        self.verbose = verbose
        self.aggregate_gauge_to_daily = aggregate_gauge_to_daily
        # Set by run_correlator(); checked in save_corrd_metadata().
        self.corrd_metadata: pl.DataFrame | None = None

    def run_correlator(self):
        station_ids_to_remove = []
        for station_id in self.station_ids_to_correlate:
            # Subset the metadata and data to the current station
            metadata_one_station = self.gauge_metadata.filter(pl.col(self.station_id_col) == station_id)
            if metadata_one_station.is_empty():
                if self.verbose:
                    print(f"Station ID: {station_id} is not included in the metadata")
                continue
            data_one_station = self.gauge_data.filter(
                pl.col(self.station_id_col).is_in(metadata_one_station[self.station_id_col].unique().to_list())
            )
            if data_one_station.is_empty():
                if self.verbose:
                    print(f"Station ID: {station_id} is not included in the data")
                continue
            gauge_grid_correlator = GaugeVsGriddedCorrelator(
                gauge_data=data_one_station,
                gauge_metadata=metadata_one_station,
                nearest_gridded_daily=self.gridded_rainfall_data,
                station_id=station_id,
                precipitation_col=self.precipitation_col,
                gridded_rainfall_col=self.gridded_rainfall_col,
                date_time_col=self.date_time_col,
                start_date_col=self.start_date_col,
                end_date_col=self.end_date_col,
                station_id_col=self.station_id_col,
                easting_col=self.easting_col,
                northing_col=self.northing_col,
                rainfall_offset_hours=self.rainfall_offset_hours,
                aggregate_gauge_to_daily=self.aggregate_gauge_to_daily,
            )
            try:
                r_result, rho_result = gauge_grid_correlator.get_corr()
            except ValueError as ve:
                station_ids_to_remove.append(station_id)
                if self.verbose:
                    print(station_id, "failed, probably all NaN", ve)
                    print(station_id, "flagged for removal")
                continue
            if self.verbose:
                print(station_id, r_result, rho_result)
            # Keep the station if either correlation exceeds the threshold; otherwise flag it.
            if not (r_result > self.correlation_threshold or rho_result > self.correlation_threshold):
                station_ids_to_remove.append(station_id)
                if self.verbose:
                    print(station_id, "flagged for removal")
        self.corrd_metadata = (
            self.gauge_metadata.filter(~pl.col(self.station_id_col).is_in(station_ids_to_remove))
        )

    @classmethod
    def run(cls, save_metadata: bool, return_metadata: bool, **kwargs) -> None | pl.DataFrame:
        """
        Run the correlator on a list of station IDs and return and/or save the correlation-filtered metadata.

        Parameters
        ----------
        save_metadata:
            Whether to save the filtered metadata to the output directory
        return_metadata:
            Whether to return the filtered metadata dataframe

        Returns
        -------
        metadata:
            Gauge metadata filtered by the correlation check
        """
        batch_correlator = cls(**kwargs)
        if batch_correlator.verbose:
            print("Starting Gauge vs Gridded Correlator")
        batch_correlator.run_correlator()
        if save_metadata:
            if batch_correlator.verbose:
                print(f"Saving data to {batch_correlator.output_dir}")
            batch_correlator.save_corrd_metadata()
        else:
            if batch_correlator.verbose:
                print("Data not saved")
        if return_metadata:
            return batch_correlator.corrd_metadata

    def save_corrd_metadata(self) -> None:
        if self.corrd_metadata is None:
            raise RuntimeError("You must call run_correlator() before save_corrd_metadata()")
        assert len(self.corrd_metadata.filter(pl.col("file_path").is_duplicated())) == 0, (
            "Problem with metadata as duplicate filepaths"
        )
        self.corrd_metadata.write_parquet(self.output_dir / "corrd_metadata.parquet")
        if self.verbose:
            print(
                "Gauge metadata filtered by correlation to nearest grid cell available "
                f"at: {self.output_dir / 'corrd_metadata.parquet'}"
            )
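
# --- Usage sketch (not part of the library) ----------------------------------
# A minimal sketch of driving the batch correlator via the `run` classmethod.
# All file paths, column names and the threshold value below are assumptions
# made for illustration only; substitute the names used in your own data.
if __name__ == "__main__":
    example_metadata = pl.read_parquet("gauge_metadata.parquet")  # hypothetical input file
    corrd_metadata = BatchGaugeVsGriddedCorrelator.run(
        save_metadata=True,
        return_metadata=True,
        gauge_data=pl.read_parquet("gauge_data.parquet"),            # hypothetical input file
        gauge_metadata=example_metadata,
        gridded_rainfall_data=xr.open_dataset("gridded_daily.nc"),   # hypothetical input file
        gridded_rainfall_col="gridded_rain",  # hypothetical column names below
        station_ids_to_correlate=example_metadata["station_id"].unique().to_list(),
        station_id_col="station_id",
        precipitation_col="rain_mm",
        date_time_col="datetime",
        start_date_col="start_date",
        end_date_col="end_date",
        easting_col="easting",
        northing_col="northing",
        rainfall_offset_hours=9,
        output_dir=Path("output"),
        verbose=True,
        correlation_threshold=0.8,  # example threshold, not a library default
        aggregate_gauge_to_daily=True,
    )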