diff --git a/lib/datasets.py b/lib/datasets.py index efdcf1c..a937693 100644 --- a/lib/datasets.py +++ b/lib/datasets.py @@ -36,6 +36,9 @@ def filter_reject_large(mats): def filter_reject_small(mats): return [mat for mat in mats if not mat_is_small(mat)] +def filter_keep_square(mats): + return [mat for mat in mats if mat.rows == mat.cols] + ## all real-valued matrices REAL_MATS = Dataset( name = "reals", @@ -80,6 +83,15 @@ REAL_SMALL_MATS = Dataset ( mats = filter_reject_large(REAL_MATS.mats) ) +REGULAR_SQUARE_REAL_SMALL_MATS = Dataset ( + name = "regular_square_reals_small", + mats = filter_keep_square(REGULAR_REAL_SMALL_MATS.mats) +) +SQUARE_REAL_SMALL_MATS = Dataset ( + name = "square_reals_small", + mats = filter_keep_square(REAL_SMALL_MATS.mats) +) + ## keep "medium" matrices REGULAR_REAL_MED_MATS = Dataset ( name = "regular_reals_med", @@ -97,7 +109,9 @@ DATASETS = [ REAL_MED_MATS, # REGULAR_REAL_MATS, REGULAR_REAL_SMALL_MATS, - REGULAR_REAL_MED_MATS + REGULAR_REAL_MED_MATS, + REGULAR_SQUARE_REAL_SMALL_MATS, + SQUARE_REAL_SMALL_MATS, ] def get_kinds():