Source code for processing_components.visibility.gather_scatter

""" Visibility iterators for iterating through a BlockVisibility or Visibility.

A typical use would be to make a sequence of snapshot visibilitys::

    for rows in vis_timeslice_iter(vt, vis_slices=vis_slices):
        visslice = create_visibility_from_rows(vt, rows)
        dirtySnapshot = create_visibility_from_visibility(visslice, npixel=512, cellsize=0.001, npol=1)
        dirtySnapshot, sumwt = invert_2d(visslice, dirtySnapshot)


"""

import logging
from typing import List

import numpy

from data_models.memory_data_models import Visibility, BlockVisibility
from ..visibility.base import create_visibility_from_rows
from ..visibility.iterators import vis_timeslice_iter, vis_wslice_iter

log = logging.getLogger(__name__)


[docs]def visibility_scatter(vis: Visibility, vis_iter, vis_slices=1) -> List[Visibility]: """Scatter a visibility into a list of subvisibilities If vis_iter is over time then the type of the outvisibilities will be the same as inout If vis_iter is over w then the type of the output visibilities will always be Visibility :param vis: Visibility :param vis_iter: visibility iterator :param vis_slices: Number of slices to be made :return: list of subvisibilitys """ assert vis is not None if vis_slices == 1: return [vis] visibility_list = list() for i, rows in enumerate(vis_iter(vis, vis_slices=vis_slices)): subvis = create_visibility_from_rows(vis, rows) visibility_list.append(subvis) return visibility_list
[docs]def visibility_gather(visibility_list: List[Visibility], vis: Visibility, vis_iter, vis_slices=None) -> Visibility: """Gather a list of subvisibilities back into a visibility The iterator setup must be the same as used in the scatter. :param visibility_list: List of subvisibilities :param vis: Output visibility :param vis_iter: visibility iterator :param vis_slices: Number of slices to be gathered (optional) :return: vis """ if vis_slices == 1: return visibility_list[0] if vis_slices is None: vis_slices = len(visibility_list) rowses = [] for i, rows in enumerate(vis_iter(vis, vis_slices=vis_slices)): rowses.append(rows) for i, rows in enumerate(rowses): assert i < len(visibility_list), "Gather not consistent with scatter for slice %d" % i if visibility_list[i] is not None and numpy.sum(rows): assert numpy.sum(rows) == visibility_list[i].nvis, \ "Mismatch in number of rows (%d, %d) in gather for slice %d" % \ (numpy.sum(rows), visibility_list[i].nvis, i) vis.data[rows] = visibility_list[i].data[...] return vis
def visibility_scatter_w(vis: Visibility, vis_slices=1) -> List[Visibility]: assert isinstance(vis, Visibility), vis return visibility_scatter(vis, vis_iter=vis_wslice_iter, vis_slices=vis_slices) def visibility_scatter_time(vis: Visibility, vis_slices=1) -> List[Visibility]: return visibility_scatter(vis, vis_iter=vis_timeslice_iter, vis_slices=vis_slices) def visibility_gather_w(visibility_list: List[Visibility], vis: Visibility, vis_slices=1) -> Visibility: assert isinstance(vis, Visibility), vis return visibility_gather(visibility_list, vis, vis_iter=vis_wslice_iter, vis_slices=vis_slices) def visibility_gather_time(visibility_list: List[Visibility], vis: Visibility, vis_slices=1) -> Visibility: return visibility_gather(visibility_list, vis, vis_iter=vis_timeslice_iter, vis_slices=vis_slices)
[docs]def visibility_scatter_channel(vis: BlockVisibility) -> List[Visibility]: """ Scatter channels to separate visibilities :param vis: :return: """ assert isinstance(vis, BlockVisibility), vis def extract_channel(v, chan): vis_shape = numpy.array(v.data['vis'].shape) vis_shape[3] = 1 vis = BlockVisibility(data=None, frequency=numpy.array([v.frequency[chan]]), channel_bandwidth=numpy.array([v.channel_bandwidth[chan]]), phasecentre=v.phasecentre, configuration=v.configuration, uvw=v.uvw, time=v.time, vis=v.vis[..., chan, :][..., numpy.newaxis, :], weight=v.weight[..., chan, :][..., numpy.newaxis, :], imaging_weight=v.imaging_weight[..., chan, :][..., numpy.newaxis, :], integration_time=v.integration_time, polarisation_frame=v.polarisation_frame, source=v.source, meta=v.meta) return vis return [extract_channel(vis, channel) for channel, _ in enumerate(vis.frequency)]
[docs]def visibility_gather_channel(vis_list: List[BlockVisibility], vis: BlockVisibility = None): """ Gather a visibility by channel :param vis_list: :param vis: :return: """ cols = ['vis', 'weight'] if vis is None: vis_shape = numpy.array(vis_list[0].vis.shape) vis_shape[-2] = len(vis_list) for v in vis_list: assert len(v.frequency) == 1 assert len(v.channel_bandwidth) == 1 vis = BlockVisibility(data=None, frequency=numpy.array([v.frequency[0] for v in vis_list]), channel_bandwidth=numpy.array([v.channel_bandwidth[0] for v in vis_list]), phasecentre=vis_list[0].phasecentre, configuration=vis_list[0].configuration, uvw=vis_list[0].uvw, time=vis_list[0].time, vis=numpy.zeros(vis_shape, dtype=vis_list[0].vis.dtype), weight=numpy.ones(vis_shape, dtype=vis_list[0].weight.dtype), imaging_weight=numpy.ones(vis_shape, dtype=vis_list[0].weight.dtype), integration_time=vis_list[0].integration_time, polarisation_frame=vis_list[0].polarisation_frame, source=vis_list[0].source, meta=vis_list[0].meta) assert len(vis.frequency) == len(vis_list) for chan, _ in enumerate(vis_list): subvis = vis_list[chan] assert abs(subvis.frequency[0] - vis.frequency[chan]) < 1e-15 for col in cols: vis.data[col][..., chan, :] = subvis.data[col][..., 0, :] vis.frequency[chan] = subvis.frequency[0] nchan = vis.vis.shape[-2] assert nchan == len(vis.frequency) return vis