Source code for odhin.odhin

"""
@author: raphael.bacher@gipsa-lab.fr
"""

import logging
import multiprocessing
import pathlib
import pickle
import warnings
from datetime import datetime

import numpy as np

from astropy.io import fits
from astropy.table import Column, Table, join, vstack
from mpdaf import CPU
from mpdaf.obj import Cube, Image
from mpdaf.sdetect import Catalog
from mpdaf.tools import MpdafUnitsWarning, progressbar

from .deblend import deblendGroup
from .grouping import doGrouping
from .parameters import Params, load_settings
from .utils import (
    calcMainKernelTransfert,
    check_segmap_catalog,
    cmap,
    extractHST,
    get_fig_ax,
)

# Ignore units warnings from MPDAF
warnings.simplefilter('ignore', MpdafUnitsWarning)


def _worker_deblend(group, outfile, conf, imLabel, timestamp):
    try:
        deblendGroup(group, outfile, conf, imLabel, timestamp)
    except Exception:
        logger = logging.getLogger(__name__)
        logger.error('group %d, failed', group.ID, exc_info=True)


[docs]class ODHIN: """ Main class for the deblending process. Parameters ---------- settings_file : str Settings file. output_dir: str if not None, results of each group is saved in this directory Attributes ---------- table_groups : `astropy.table.Table` Table computed by `ODHIN.grouping`, containing the information about groups. table_sources : `astropy.table.Table` Table computed by `ODHIN.deblend`, containing the information about the deblended sources. """ def __init__(self, settings_file, output_dir, idname='ID', raname='RA', decname='DEC'): self.logger = logging.getLogger(__name__) self.logger.debug('loading settings from %s', settings_file) self.output_dir = pathlib.Path(output_dir) self.groups = None self.settings_file = settings_file self.conf = load_settings(settings_file) # if nothing provided take white image of the cube if 'white' in self.conf: self.imMUSE = Image(self.conf['white']) else: cube = Cube(self.conf['cube']) self.imMUSE = cube.sum(axis=0) self.segmap = extractHST(Image(self.conf['segmap']), self.imMUSE, integer_mode=True) # catalog: check for potential discrepancy between the catalog and the # segmentation map (as in Rafelski15) self.cat = Catalog.read(self.conf['catalog']) self.cat.add_index(idname) self.cat.meta['idname'] = self.idname = idname self.cat.meta['raname'] = self.raname = raname self.cat.meta['decname'] = self.decname = decname self.cat = check_segmap_catalog(self.segmap, self.cat, idname=self.idname) self.params = Params(**self.conf.get('params', {})) # Reference HR image ref_info = self.conf['hr_bands'][self.conf['hr_ref_band']] self.imHST = extractHST(Image(ref_info['file']), self.imMUSE) self.imHST.primary_header['photflam'] = ref_info.get('photflam', 1)
[docs] @staticmethod def set_loglevel(level): """Change the logging level.""" logger = logging.getLogger() logger.setLevel(level) logger.handlers[0].setLevel(level)
[docs] def dump(self, filename): """Dump the ODHIN object to a pickle file.""" with open(filename, 'wb') as f: pickle.dump(self, f)
[docs] @classmethod def load(self, filename): """Load an ODHIN object from a pickle file.""" with open(filename, 'rb') as f: self = pickle.load(f) # recreate the group_id index, otherwise it crashes with the default # index implementation self.table_groups.remove_indices('group_id') self.table_groups.add_index('group_id') return self
[docs] def grouping(self, verbose=True, cut=None): """Segment all sources in a number of connected (at the MUSE resolution) groups and build a table of the groups. Parameters ---------- verbose : bool If True, show a progress bar. cut : (float, float) Threshold on the convolved intensity map, to get the segmentation image. """ if cut is not None: self.params.cut = cut # if nothing provided build transfer kernel from default parameters if 'kernel_transfert' in self.conf: kernel_transfert = fits.getdata(self.conf['kernel_transfert']) else: kernel_transfert = calcMainKernelTransfert(self.params, self.imHST) self.groups, self.imLabel, self.missing_ids = doGrouping( self.imHST, self.segmap, self.imMUSE, self.cat, kernel_transfert, self.params, idname=self.idname, verbose=verbose ) rows = [[group.ID, group.nbSources, tuple(group.listSources), group.region.area, group.step] for i, group in enumerate(self.groups)] self.table_groups = Table( names=('group_id', 'nb_sources', 'list_ids', 'area', 'step'), rows=rows, dtype=(int, int, tuple, float, int) ) self.table_groups.add_index('group_id')
[docs] def deblend(self, listGroupToDeblend=None, njobs=None, verbose=True): """Parallelized deblending on a list of groups Parameters ---------- listGroupToDeblend : list List of group IDs to process. If not provided, all groups are processed, starting with the ones with the highest number of sources. njobs : int Number of process to run in parallel. verbose : bool If True, show a progress bar. """ if self.groups is None: raise ValueError("No groups were defined. Please call the " ".grouping() method before doing a deblend") # if no special groups are listed, do on all groups klist = [] slist = [] if listGroupToDeblend is None: for k,group in enumerate(self.groups): if len(group.listSources) == 1: self.logger.warning('skipping group %d, no sources in group', group.ID) continue klist.append(k) slist.append(len(group.listSources)) klist = np.array(klist) ksort = np.argsort(slist)[::-1] klist = klist[ksort] listGroupToDeblend = klist self.output_dir.mkdir(exist_ok=True) timestamp = datetime.now().isoformat() to_process = [] for i in listGroupToDeblend: group = self.groups[i] if len(group.listSources) == 1: self.logger.warning('skipping group %d, no sources in group', group.ID) continue outfile = str(self.output_dir / f'group_{group.ID:05d}.fits') to_process.append((group, outfile, self.conf, self.imLabel, timestamp)) # Determine the number of processes: # - default: all CPUs except one. # - mdaf.CPU # - cpu_count parameter cpu_count = multiprocessing.cpu_count() - 1 if CPU > 0 and CPU < cpu_count: cpu_count = CPU if njobs is not None and njobs < cpu_count: cpu_count = njobs cpu_count = min(cpu_count, len(listGroupToDeblend)) self.logger.debug('using %d cpus', cpu_count) if cpu_count > 1: pool = multiprocessing.Pool(processes=cpu_count) if verbose: ntasks = len(to_process) # add progress bar pbar = progressbar(total=ntasks) def update(*a): pbar.update() else: update = None for args in to_process: pool.apply_async(_worker_deblend, args=args, callback=update) pool.close() pool.join() else: for args in to_process: deblendGroup(*args) self.build_result_table()
[docs] def build_result_table(self): """Build the result table from the sources. This is called at the end of the `~ODHIN.deblend` method. """ tables = [] for f in self.output_dir.glob('group_*.fits'): t = Table.read(f, hdu='TAB_SOURCES') t['timestamp'] = fits.getval(f, 'ODH_TS') tables.append(t) tables = vstack(tables) cat = Table([[str(x) for x in self.cat[self.idname]], self.cat[self.raname], self.cat[self.decname]], names=('id', 'ra', 'dec')) # join with input catalog (inner join to get only the processed ids, # and without the bg_* rows) self.table_sources = join(tables, cat, keys=['id'], join_type='inner') # cast id column to integer self.table_sources.replace_column( 'id', Column(data=[int(x) for x in self.table_sources['id']]) ) self.table_sources.sort('group_id') return self.table_sources
[docs] def plotGroups(self, ax=None, groups=None, linewidth=1): """Plot the segmentation map and groups. Parameters ---------- ax : matplotlib axis Axis to use for the plot. groups : list List of groups. """ import matplotlib.patches as mpatches ax = get_fig_ax(ax) cm = cmap(self.imLabel.max(), random_state=12345) ax.imshow(self.imLabel, cmap=cm, origin='lower') if groups is None: groups = self.groups for group in groups: minr, maxr = group.region.sy.start, group.region.sy.stop minc, maxc = group.region.sx.start, group.region.sx.stop rect = mpatches.Rectangle((minc, minr), maxc - minc, maxr - minr, fill=False, edgecolor='red', linewidth=linewidth) ax.add_patch(rect)
[docs] def plotAGroup(self, ax=None, group_id=None, cmap='Greys', **kwargs): """Plot a group, with sources positions and contour of the label image. Parameters ---------- ax : matplotlib axis Axis to use for the plot. group_id : int Group id. cmap : str Colormap for the image plot. kwargs : dict Passed to Image.plot. """ assert group_id is not None ax = get_fig_ax(ax) group = self.groups[group_id - 1] reg = group.region subim = self.imMUSE[reg.sy, reg.sx] subim.plot(ax=ax, cmap='Greys', **kwargs) ax.contour(self.imLabel[reg.sy, reg.sx] == group.ID, levels=1, colors='r') src = group.listSources.copy() if 'bg' in src: src.remove('bg') self.cat.plot_symb(ax, subim.wcs, label=True, esize=0.4) cat = self.cat[np.in1d(self.cat[self.idname], src)] y, x = subim.wcs.sky2pix(np.array([cat[self.decname], cat[self.raname]]).T).T ax.scatter(x, y, c="r")
[docs] def plotHistArea(self, ax=None, nbins=None): """Plot histogram of group areas. Parameters ---------- ax : matplotlib axis Axis to use for the plot. nbins : str or int Number of bins for `matplotlib.pyplot.hist`. """ ax = get_fig_ax(ax) if nbins is None: nbins = int(self.table_groups['area'].max()) // 50 ax.hist([group.region.area for group in self.groups], bins=nbins) ax.set_title('Histogram of group areas')
[docs] def plotHistNbS(self, ax=None, nbins=None): """Plot histogram of the number of sources per group. Parameters ---------- ax : matplotlib axis Axis to use for the plot. nbins : str or int Number of bins for `matplotlib.pyplot.hist`. """ ax = get_fig_ax(ax) if nbins is None: nbins = self.table_groups['nb_sources'].max() ax.hist([group.nbSources for group in self.groups], bins=nbins) ax.set_title('Histogram of sources number per group')