# Source code for holoviews.plotting.bokeh.graphs

import param
import numpy as np

from bokeh.models import Range1d, Circle, MultiLine, HoverTool, ColumnDataSource

try:
    from bokeh.models import (StaticLayoutProvider, GraphRenderer, NodesAndLinkedEdges,
                              EdgesAndLinkedNodes)
except:
    pass

from ...core.options import abbreviated_exception, SkipRendering
from ...core.util import basestring, dimension_sanitizer
from .chart import ColorbarPlot, PointPlot
from .element import CompositeElementPlot, LegendPlot, line_properties, fill_properties, property_prefixes
from .util import mpl_to_bokeh


class GraphPlot(CompositeElementPlot, ColorbarPlot, LegendPlot):
    """
    Plots a graph element through bokeh's ``plot.graph`` renderer.

    Nodes are styled through the 'node' style group (the 'scatter'
    glyph) and edges through the 'edge' style group (the 'multi_line'
    glyph). Node positions are handed to bokeh as a static layout.
    """

    color_index = param.ClassSelector(default=None, class_=(basestring, int),
                                      allow_None=True, doc="""
      Index of the dimension from which the color will be drawn""")

    selection_policy = param.ObjectSelector(default='nodes', objects=['edges', 'nodes', None], doc="""
        Determines policy for selection of graph components, i.e. whether to highlight
        nodes or edges when selecting connected edges and nodes respectively.""")

    inspection_policy = param.ObjectSelector(default='nodes', objects=['edges', 'nodes', None], doc="""
        Determines policy for inspection of graph components, i.e. whether to highlight
        nodes or edges when hovering over connected edges and nodes respectively.""")

    tools = param.List(default=['hover', 'tap'], doc="""
        A list of plugin tools to use on the plot.""")

    # Both axes are continuous, so each uses a numeric Range1d
    _x_range_type = Range1d

    _y_range_type = Range1d

    # Map each glyph to a style group
    _style_groups = {'scatter': 'node', 'multi_line': 'edge'}

    # Edge options cover line properties; node options cover fill and
    # line properties plus the node size and colormap.
    style_opts = (['edge_'+p for p in line_properties] +
                  ['node_'+p for p in fill_properties+line_properties] +
                  ['node_size', 'cmap'])

    def _hover_opts(self, element):
        """
        Return the hover tooltip dimensions and an (empty) options dict.

        Which dimensions are shown depends on ``inspection_policy``:
        node index and node value dimensions for 'nodes', the edge
        start/end key dimensions and value dimensions for 'edges', and
        nothing otherwise.
        """
        if self.inspection_policy == 'nodes':
            dims = element.nodes.dimensions()
            # The index is displayed via the pre-formatted 'index_hover'
            # column populated in get_data; x/y positions are skipped.
            dims = [(dims[2].pprint_label, '@{index_hover}')]+dims[3:]
        elif self.inspection_policy == 'edges':
            # Edge key dimensions are stored under the 'start' and 'end'
            # columns of the edge data source.
            kdims = [(kd.pprint_label, '@{%s}' % ref)
                     for kd, ref in zip(element.kdims, ['start', 'end'])]
            dims = kdims+element.vdims
        else:
            dims = []
        return dims, {}

    def get_extents(self, element, ranges):
        """
        Return the (x0, y0, x1, y1) extents of the plot, looked up in
        ``ranges`` by the names of the first two node key dimensions.
        """
        xdim, ydim = element.nodes.kdims[:2]
        x0, x1 = ranges[xdim.name]
        y0, y1 = ranges[ydim.name]
        return (x0, y0, x1, y1)

    def _get_axis_labels(self, *args, **kwargs):
        """
        Label the x- and y-axis with the first two key dimensions of
        the nodes of the current frame; the third entry is None.
        """
        element = self.current_frame
        xlabel, ylabel = [kd.pprint_label for kd in element.nodes.kdims[:2]]
        return xlabel, ylabel, None

    def get_data(self, element, ranges, style):
        """
        Build the node data, edge data and layout for ``element``.

        Returns a ``(data, mapping, style)`` tuple where ``data`` holds
        the 'scatter_1' (node) and 'multi_line_1' (edge) columns plus
        the 'layout' dictionary of node positions, ``mapping`` holds
        the glyph property mappings for both glyphs and ``style`` is
        the incoming style, minus node fill colors when those are
        supplied by a color mapping.
        """
        # Get node data
        nodes = element.nodes.dimension_values(2)
        node_positions = element.nodes.array([0, 1, 2])

        # Node ids that are not signed integers are replaced by their
        # position in the node array. The decision is made once and
        # reused for the edges below so that both always refer to the
        # same index space (previously float node ids were remapped for
        # the nodes but not for the edges).
        remap_indices = nodes.dtype.kind != 'i'
        if remap_indices:
            node_indices = {v: i for i, v in enumerate(nodes)}
            index = np.array([node_indices[n] for n in nodes], dtype=np.int32)
            layout = {node_indices[z]: (y, x) if self.invert_axes else (x, y)
                      for x, y, z in node_positions}
        else:
            index = nodes.astype(np.int32)
            layout = {z: (y, x) if self.invert_axes else (x, y)
                      for x, y, z in node_positions}
        point_data = {'index': index}
        cdata, cmapping = self._get_color_data(element.nodes, ranges, style, 'node_fill_color')
        point_data.update(cdata)
        point_mapping = cmapping
        if 'node_fill_color' in point_mapping:
            # A color mapping takes precedence over fixed fill colors
            # and is also applied to non-selected nodes.
            style = {k: v for k, v in style.items() if k not in
                     ['node_fill_color', 'node_nonselection_fill_color']}
            point_mapping['node_nonselection_fill_color'] = point_mapping['node_fill_color']

        # Get edge data
        # Edges referencing an unknown node are pointed at an index one
        # past the largest node index; guard against an empty node array
        # where max() is undefined.
        nan_node = index.max()+1 if len(index) else 0
        start, end = (element.dimension_values(i) for i in range(2))
        if remap_indices:
            start = np.array([node_indices.get(x, nan_node) for x in start], dtype=np.int32)
            end = np.array([node_indices.get(y, nan_node) for y in end], dtype=np.int32)
        path_data = dict(start=start, end=end)
        if element._edgepaths and not self.static_source:
            edges = element._split_edgepaths.split(datatype='array', dimensions=element.edgepaths.kdims)
            if len(edges) == len(start):
                path_data['xs'] = [path[:, 0] for path in edges]
                path_data['ys'] = [path[:, 1] for path in edges]
            else:
                self.warning('Graph edge paths do not match the number of abstract edges '
                             'and will be skipped')

        # Get hover data
        if any(isinstance(t, HoverTool) for t in self.state.tools):
            if self.inspection_policy == 'nodes':
                index_dim = element.nodes.get_dimension(2)
                point_data['index_hover'] = [index_dim.pprint_value(v) for v in element.nodes.dimension_values(2)]
                for d in element.nodes.dimensions()[3:]:
                    point_data[dimension_sanitizer(d.name)] = element.nodes.dimension_values(d)
            elif self.inspection_policy == 'edges':
                for d in element.vdims:
                    path_data[dimension_sanitizer(d.name)] = element.dimension_values(d)

        data = {'scatter_1': point_data, 'multi_line_1': path_data, 'layout': layout}
        mapping = {'scatter_1': point_mapping, 'multi_line_1': {}}
        return data, mapping, style


    def _update_datasource(self, source, data):
        """
        Update datasource with data for a new frame.

        A ColumnDataSource has its data updated in place; any other
        source is treated as the layout provider and receives ``data``
        as its new ``graph_layout``.
        """
        if isinstance(source, ColumnDataSource):
            source.data.update(data)
        else:
            source.graph_layout = data


    def _init_glyphs(self, plot, element, ranges, source):
        """
        Create the graph renderer for ``element`` on ``plot`` and
        register the resulting sources, renderers and glyphs in
        ``self.handles``.
        """
        # Get data and initialize data source
        style = self.style[self.cyclic_index]
        data, mapping, style = self.get_data(element, ranges, style)
        self.handles['previous_id'] = element._plot_id
        properties = {}
        mappings = {}
        for key in mapping:
            # One data source per glyph; the incoming ``source`` argument
            # is not used and the name is rebound here.
            source = self._init_datasource(data.get(key, {}))
            self.handles[key+'_source'] = source
            glyph_props = self._glyph_properties(plot, element, source, ranges, style)
            properties.update(glyph_props)
            mappings.update(mapping.get(key, {}))
        # The sources are passed to plot.graph positionally, so 'source'
        # (and 'legend') must not also be passed as keyword properties.
        properties = {p: v for p, v in properties.items() if p not in ('legend', 'source')}
        properties.update(mappings)

        # Define static layout
        layout = StaticLayoutProvider(graph_layout=data['layout'])
        node_source = self.handles['scatter_1_source']
        edge_source = self.handles['multi_line_1_source']
        renderer = plot.graph(node_source, edge_source, layout, **properties)

        # Initialize GraphRenderer
        if self.selection_policy == 'nodes':
            renderer.selection_policy = NodesAndLinkedEdges()
        elif self.selection_policy == 'edges':
            renderer.selection_policy = EdgesAndLinkedNodes()
        else:
            renderer.selection_policy = None

        if self.inspection_policy == 'nodes':
            renderer.inspection_policy = NodesAndLinkedEdges()
        elif self.inspection_policy == 'edges':
            renderer.inspection_policy = EdgesAndLinkedNodes()
        else:
            renderer.inspection_policy = None

        self.handles['layout_source'] = layout
        self.handles['glyph_renderer'] = renderer
        self.handles['scatter_1_glyph_renderer'] = renderer.node_renderer
        self.handles['multi_line_1_glyph_renderer'] = renderer.edge_renderer
        self.handles['scatter_1_glyph'] = renderer.node_renderer.glyph
        self.handles['multi_line_1_glyph'] = renderer.edge_renderer.glyph
        if 'hover' in self.handles:
            # Restrict the hover tool to the graph renderer
            self.handles['hover'].renderers.append(renderer)


class NodePlot(PointPlot):
    """
    Simple subclass of PointPlot which hides x, y position on hover.
    """

    def _hover_opts(self, element):
        """
        Return the hover dimensions and an (empty) options dict,
        skipping the first two dimensions of the element.
        """
        return element.dimensions()[2:], {}