# Source code for odhin.grouping

"""
@author: raphael.bacher@gipsa-lab.fr

Store methods for
- grouping sources to be deblended
- exploring groups
- modifying groups
"""

import logging
from itertools import chain

import numpy as np
from mpdaf.tools import progressbar
from skimage.measure import label, regionprops

from .utils import createIntensityMap

__all__ = ('SourceGroup', 'RegionAttr', 'doGrouping')


class SourceGroup:
    """Container describing one group of sources.

    Attributes
    ----------
    ID
        Identifier of the group.
    listSources
        Sources attached to the group.
    listHST_ID
        Identifiers stored alongside the sources.
    region
        Region of the group (a RegionAttr when built by doGrouping).
    step
        Grouping step that produced the group.
    nbSources : int
        Length of ``listSources``, computed at construction time.
    """

    __slots__ = ('ID', 'listSources', 'listHST_ID', 'region', 'step',
                 'nbSources')

    def __init__(self, ID, listSources, listHST_ID, region, step):
        self.ID = ID
        self.step = step
        self.region = region  # RegionAttr region
        self.listHST_ID = listHST_ID
        self.listSources = listSources
        # Cached once: later changes to listSources are not reflected here.
        self.nbSources = len(listSources)

    def __repr__(self):
        # Singular wording for 0 or 1 source, plural otherwise.
        noun = 'source' if self.nbSources <= 1 else 'sources'
        return f'<SourceGroup({self.ID}, {self.nbSources} {noun})>'
class RegionAttr:
    """Region attributes derived from skimage region properties.

    Stores the area and centroid of a region together with its bounding
    box, kept as two slices: ``sy`` (rows) and ``sx`` (columns).  The
    ``ra`` and ``dec`` attributes are only set once
    `compute_sky_centroid` has been called.
    """

    __slots__ = ('area', 'centroid', 'sx', 'sy', 'ra', 'dec')

    def __init__(self, area, centroid, sy, sx):
        self.area = area
        self.centroid = centroid
        self.sy = sy
        self.sx = sx

    @classmethod
    def from_skimage(cls, reg):
        """Build a RegionAttr from a skimage region.

        ``reg`` must provide ``bbox`` as
        ``(min_row, min_col, max_row, max_col)``, plus ``area`` and
        ``centroid``.
        """
        min_row, min_col, max_row, max_col = reg.bbox
        sy = slice(min_row, max_row)
        sx = slice(min_col, max_col)
        return cls(reg.area, reg.centroid, sy, sx)

    def compute_sky_centroid(self, wcs):
        """Set ``dec`` and ``ra`` from the centroid using ``wcs.pix2sky``."""
        self.dec, self.ra = wcs.pix2sky(self.centroid)[0]

    @property
    def bbox_area(self):
        """Number of pixels covered by the bounding box."""
        return (self.sy.stop - self.sy.start) * (self.sx.stop - self.sx.start)

    def _center_bbox(self, width, ny, nx):
        """Reset the bounding box to a window centred on the centroid.

        The window extends ``width // 2`` pixels on each side of the
        centroid and is clipped to the ``(ny, nx)`` image shape.
        """
        half_width = width // 2
        self.sy = slice(int(max(self.centroid[0] - half_width, 0)),
                        int(min(self.centroid[0] + half_width, ny)))
        self.sx = slice(int(max(self.centroid[1] - half_width, 0)),
                        int(min(self.centroid[1] + half_width, nx)))

    def ensureMinimalBbox(self, min_width, imLabel, min_sky_pixels, margin):
        """Ensure that the region respects a minimal area and contains at
        least `min_sky_pixels` sky pixels.

        Parameters
        ----------
        min_width : int
            Minimal width of the bounding box; the box is rebuilt around
            the centroid when its area is below ``min_width**2``.
        imLabel : 2D array
            Label image; pixels equal to 0 are counted as sky.
        min_sky_pixels : int
            Minimal number of sky pixels wanted inside the bounding box.
        margin : int
            Number of pixels added on each side of the bounding box.

        Notes
        -----
        The bounding box is modified in place.  If the whole image holds
        fewer than `min_sky_pixels` sky pixels, the requirement cannot
        be met: a warning is logged and the bounding box is left covering
        the full image instead of looping forever.
        """
        ny, nx = imLabel.shape

        # First add margin around bounding box, clipped to the image.
        self.sy = slice(int(max(self.sy.start - margin, 0)),
                        int(min(self.sy.stop + margin, ny)))
        self.sx = slice(int(max(self.sx.start - margin, 0)),
                        int(min(self.sx.stop + margin, nx)))

        # Then check minimal area.
        # NOTE(review): this replaces the margin-extended box by a window
        # around the centroid, which may be narrower along one axis for
        # elongated regions -- kept as in the original, confirm intent.
        if self.bbox_area < min_width**2:
            self._center_bbox(min_width, ny, nx)

        # Then check minimal number of sky pixels, widening the centred
        # window one pixel at a time.
        nb_pixels = np.sum(imLabel[self.sy, self.sx] == 0)
        while nb_pixels < min_sky_pixels:
            if self.bbox_area >= ny * nx:
                # The box already spans the whole image: growing further
                # is impossible, so stop rather than spin indefinitely.
                logging.getLogger(__name__).warning(
                    'only %d sky pixels available in the image, '
                    '%d requested', nb_pixels, min_sky_pixels)
                break
            min_width = min_width + 1
            self._center_bbox(min_width, ny, nx)
            nb_pixels = np.sum(imLabel[self.sy, self.sx] == 0)

    def convertToHR(self, imHR, imLR):
        """Convert the bounding box from low resolution (MUSE) to high
        resolution (HST).

        Returns
        -------
        sy, sx : slice
            Bounding box expressed in pixel indices of `imHR`.
        """
        # compute coordinates of bottom left corners
        pos = np.array([[self.sy.start, self.sx.start],
                        [self.sy.stop, self.sx.stop]]) - 0.5
        # get HR pixel indices
        hrpix = imHR.wcs.sky2pix(imLR.wcs.pix2sky(pos), nearest=True)
        sy, sx = (slice(*x) for x in hrpix.T)
        return sy, sx
def doGrouping(imHR, segmap, imMUSE, cat, kernel, params, idname='ID',
               verbose=True):
    """Compute groups of connected (at the MUSE resolution) sources.

    The grouping is done in 2 steps, with 2 thresholds. The first one
    allows to get the groups with bright sources, and the second is needed
    to get all sources, including faint ones.

    Parameters
    ----------
    imHR
        High resolution image, passed to ``createIntensityMap``.
    segmap
        Segmentation map; its ``_data`` array holds the source IDs.
    imMUSE
        Low resolution (MUSE) image, used for its WCS, mask and cutouts.
    cat
        Source catalog; must contain the `idname`, ``RA`` and ``DEC``
        columns.
    kernel
        Passed to ``createIntensityMap``.
    params
        Settings object providing ``cut`` (2 thresholds), ``min_width``,
        ``min_sky_pixels`` and ``margin_bbox``.
    idname : str
        Name of the ID column in `cat`.
    verbose : bool
        If True, show a progress bar while looping on the regions.

    Returns
    -------
    groups : list of SourceGroup
        Groups found in both steps.
    im_label_comb : array
        Label image combining both steps; labels of the second step are
        shifted by the maximum label of the first step.
    missing_ids : list
        Sorted catalog IDs that do not belong to any group.

    Raises
    ------
    ValueError
        If ``params.cut`` does not contain exactly 2 values.
    """
    logger = logging.getLogger(__name__)

    if len(params.cut) != 2:
        raise ValueError('the cut param must contain 2 values')

    groups = []
    im_label_comb = None

    for it in range(2):
        logger.info('Create intensity map')
        intensityMapLRConvol = createIntensityMap(imHR, segmap, imMUSE,
                                                  kernel, params)

        logger.info('Compute label image')
        im_label = label(intensityMapLRConvol > params.cut[it])

        # combine label images from the two steps
        if it == 0:
            offset_label = 0
            im_label_comb = im_label
        else:
            # compute offset before adding the label image
            offset_label = im_label_comb.max()
            im_label_comb += np.where(im_label > 0,
                                      im_label + offset_label, 0)

        logger.info('Compute regions properties')
        regions = regionprops(im_label)
        if verbose:
            regions = progressbar(regions)

        for skreg in regions:
            # Build a RegionAttr object from a skimage region
            region = RegionAttr.from_skimage(skreg)
            region.compute_sky_centroid(imMUSE.wcs)
            region.ensureMinimalBbox(params.min_width, im_label,
                                     params.min_sky_pixels,
                                     params.margin_bbox)

            blob_mask = (im_label == skreg.label)
            sub_blob_mask = blob_mask[region.sy, region.sx]
            subimMUSE = imMUSE[region.sy, region.sx]

            hy, hx = region.convertToHR(segmap, imMUSE)
            subsegmap = segmap._data[hy, hx]
            listHST_ID = np.unique(subsegmap)
            listHST_ID = listHST_ID[listHST_ID > 0]
            if listHST_ID.size == 0:
                # no HST IDs found in the segmap region, which can happen
                # with false detections when the threshold is low.
                continue

            listSources, hstids = getObjsInBlob(idname, cat, sub_blob_mask,
                                                subimMUSE, listHST_ID)
            gid = skreg.label + offset_label

            # listSources always starts with 'bg', so a length of 1 means
            # that no actual source was found in the blob.
            if len(listSources) == 1:
                # FIXME: this should not happen. It seems to happen when
                # a source is close to an edge, and because the HR to LR
                # resampling remove the source flux on the edge spaxels.
                # Should investigate more!
                logger.warning('found no sources in group %d', gid)

            groups.append(SourceGroup(gid, listSources, hstids, region,
                                      it + 1))

        # build the list of all IDs that are included in a group
        ids_in_groups = {
            int(i)
            for i in chain.from_iterable(grp.listSources for grp in groups)
            if i != 'bg'
        }

        # find the IDs that are not in a group
        tbl = cat.select(imMUSE.wcs, margin=0, mask=imMUSE.mask)
        missing_ids = sorted(set(tbl[idname].tolist()) - ids_in_groups)

        logger.info(
            'Step %d: %d groups, %d sources, %d missing sources',
            it + 1, len(groups), len(ids_in_groups), len(missing_ids),
        )
        # min()/max() fail on empty sequences: only report the statistics
        # when at least one group exists.
        if groups:
            area = [grp.region.area for grp in groups]
            # do not count the 'bg' entry
            nbSources = [grp.nbSources - 1 for grp in groups]
            logger.info('Area: min=%d max=%d', min(area), max(area))
            logger.info('Nb sources: min=%d max=%d',
                        min(nbSources), max(nbSources))

        if it == 0:
            logger.info('Mask the HR image to keep only the missing sources')
            # Single pass over the segmap instead of one boolean image per
            # missing ID.
            missing_map = np.isin(segmap._data, missing_ids).astype(int)
            imHR = imHR * missing_map
            # work on a copy so that the caller's segmap is left untouched
            segmap = segmap.copy()
            segmap._data *= missing_map

    return groups, im_label_comb, missing_ids
def getObjsInBlob(idname, cat, sub_blob_mask, subimMUSE, listHST_ID):
    """Return the IDs of the sources in the blob and in the cutout.

    The catalog rows matching `listHST_ID` are looked up through the
    `idname` index, their ``DEC``/``RA`` positions are converted to pixel
    indices of `subimMUSE`, and a source is kept when `sub_blob_mask` is
    set at that pixel.

    Returns
    -------
    listHSTObjInBlob : list
        ``'bg'`` followed by the catalog IDs connected to the blob.
    listHST_ID : list
        ``'bg'`` followed by all catalog IDs in the cutout.

    """
    rows = cat.loc[idname, listHST_ID]
    sky_centers = np.array([rows['DEC'], rows['RA']]).T

    # One (y, x) pixel pair per source, transposed to index the mask.
    pixels = subimMUSE.wcs.sky2pix(sky_centers, nearest=True).T
    in_blob = sub_blob_mask[pixels[0], pixels[1]]

    blob_ids = ['bg', *listHST_ID[in_blob]]
    cutout_ids = ['bg', *listHST_ID]
    return blob_ids, cutout_ids