TensorFlow Plot (tfplot)

2022-09-04 21:43:42

Original article: https://tensorflow-plot.readthedocs.io/en/latest/api/index.html


1. Showcases of tfplot

This guide is a quick tour of the tfplot library. Feel free to skip the setup section of this document.

import tfplot
tfplot.__version__

'0.3.0.dev0'

1. Setup: Utilities and Data

In order to see the images generated from the plot ops, we introduce a simple utility function which takes a Tensor as input, executes it in a TensorFlow session, and returns the resulting image (or list of images), which the notebook then displays.

You may skip this section and go straight to the showcase.

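import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import tensorflow as tf  # these examples use the TensorFlow 1.x API (tf.InteractiveSession, tf.Summary)

sess = tf.InteractiveSession()

def execute_op_as_image(op):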
    """
    Evaluate the given `op` and return the content PNG image as `PIL.Image`.

    - If op is a plot op (e.g. RGBA Tensor) the image or
      a list of images will be returned
    - If op is summary proto (i.e. `op` was a summary op),
      the image content will be extracted from the proto object.
    """
    print("Executing: " + str(op))
    ret = sess.run(op)
    plt.close()

    if isinstance(ret, np.ndarray):
        if len(ret.shape) == 3:
            # single image
            return Image.fromarray(ret)
        elif len(ret.shape) == 4:
            return [Image.fromarray(r) for r in ret]
        else:
            raise ValueError("Invalid rank : %d" % len(ret.shape))

    elif isinstance(ret, (str, bytes)):
        from io import BytesIO
        s = tf.Summary()
        s.ParseFromString(ret)
        ims = []
        for i in range(len(s.value)):
            png_string = s.value[i].image.encoded_image_string
            im = Image.open(BytesIO(png_string))
            ims.append(im)
        plt.close()
        if len(ims) == 1: return ims[0]
        else: return ims

    else:
        raise TypeError("Unknown type: " + str(ret))

And some data:
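import scipy.misc
import scipy.ndimage

def fake_attention():
    # two point sources blurred by a Gaussian, mimicking an attention heatmap
    attention = np.zeros([16, 16], dtype=np.float32)
    attention[(11, 8)] = 1.0
    attention[(9, 9)] = 1.0
    attention = scipy.ndimage.gaussian_filter(attention, sigma=1.5)
    return attention

# scipy.misc.face() is deprecated in newer SciPy releases; scipy.datasets.face() replaces it there
sample_image = scipy.misc.face()
attention_map = fake_attention()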

# display the data
fig, axs = plt.subplots(1, 2, figsize=(10, 4))
axs[0].imshow(sample_image); axs[0].set_title('image')
axs[1].imshow(attention_map, cmap='jet'); axs[1].set_title('attention')
plt.show()

Finally, we wrap these numpy values into TensorFlow tensors:

# the input to plot_op
image_tensor = tf.constant(sample_image, name='image')
attention_tensor = tf.constant(attention_map, name='attention')
print(image_tensor)
print(attention_tensor)

Tensor("image:0", shape=(768, 1024, 3), dtype=uint8)
Tensor("attention:0", shape=(16, 16), dtype=float32)

2. tfplot.autowrap: The Main End-User API

Use tfplot.autowrap to design a custom plot function of your own.

1. Decorator to define a TF op that draws a plot

With tfplot.autowrap, you can wrap a Python function that returns a matplotlib.Figure (or AxesSubplot) into a TensorFlow op, similar to tf.py_func.

@tfplot.autowrap
def plot_scatter(x, y):
    # NEVER use plt.XXX, or matplotlib.pyplot.
    # Use tfplot.subplots() instead of plt.subplots() to avoid thread-safety issues.
    fig, ax = tfplot.subplots(figsize=(3, 3))
    ax.scatter(x, y, color='green')
    return fig

x = tf.constant(np.arange(10), dtype=tf.float32)
y = tf.constant(np.arange(10) ** 2, dtype=tf.float32)
execute_op_as_image(plot_scatter(x, y))

Executing: Tensor("plot_scatter:0", shape=(?, ?, 4), dtype=uint8)

We can create subplots as well. Note also that keyword arguments (kwargs) can be passed in addition to the Tensor (positional) arguments.

@tfplot.autowrap
def plot_image_and_attention(im, att, cmap=None):
    fig, axes = tfplot.subplots(1, 2, figsize=(7, 4))
    fig.suptitle('Image and Heatmap')
    axes[0].imshow(im)
    axes[1].imshow(att, cmap=cmap)
    return fig

op = plot_image_and_attention(sample_image, attention_map, cmap='jet')
execute_op_as_image(op)

Executing: Tensor("plot_image_and_attention:0", shape=(?, ?, 4), dtype=uint8)

Sometimes, it can be cumbersome to create instances of fig and ax. If you want to have them automatically created and injected, use a keyword argument named fig and/or ax:

@tfplot.autowrap(figsize=(2, 2))
def plot_scatter(x, y, *, ax, color='red'):
    ax.set_title('x^2')
    ax.scatter(x, y, color=color)

x = tf.constant(np.arange(10), dtype=tf.float32)
y = tf.constant(np.arange(10) ** 2, dtype=tf.float32)
execute_op_as_image(plot_scatter(x, y))

Executing: Tensor("plot_scatter_1:0", shape=(?, ?, 4), dtype=uint8)
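One way such injection can work is to inspect the wrapped function's signature and pass freshly created objects only for parameters named fig or ax. The plain-Python sketch below illustrates that idea with hypothetical stand-in classes (FakeFigure, FakeAxes) and a hypothetical helper call_with_injected_fig_ax; it is not tfplot's actual implementation, and no real drawing happens.

```python
import inspect


class FakeFigure:
    """Hypothetical stand-in for a matplotlib Figure (no real drawing)."""


class FakeAxes:
    """Hypothetical stand-in for a matplotlib Axes that belongs to a figure."""

    def __init__(self, figure):
        self.figure = figure


def call_with_injected_fig_ax(plot_func, *args, **kwargs):
    """Call plot_func, creating and passing fig/ax only if it declares them."""
    params = inspect.signature(plot_func).parameters
    fig = FakeFigure()
    if "fig" in params and "fig" not in kwargs:
        kwargs["fig"] = fig
    if "ax" in params and "ax" not in kwargs:
        kwargs["ax"] = FakeAxes(fig)  # the axes is attached to the new figure
    return plot_func(*args, **kwargs)


def plot_scatter(x, y, *, ax, color="red"):
    # records what it received instead of drawing anything
    return {"x": x, "y": y, "ax": ax, "color": color}


result = call_with_injected_fig_ax(plot_scatter, [0, 1], [0, 1])
```

Functions that declare neither parameter are simply called with the arguments they were given.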

2. Wrapping Matplotlib’s AxesPlot or Seaborn Plot

You can use tfplot.autowrap (or raw APIs such as tfplot.plot) to plot anything by writing your own plotting function, but sometimes you may want to wrap existing plot functions from common libraries such as matplotlib and seaborn.

To do this, you can still use tfplot.autowrap.

1. Matplotlib

Matplotlib provides a variety of plot methods defined in the Axes class (usually available as ax).

rs = np.random.RandomState(42)
x = rs.randn(100)
y = 2 * x + rs.randn(100)

fig, ax = plt.subplots()
ax.scatter(x, y)
ax.set_title("Created from matplotlib API")
plt.show()

We can wrap the Axes.scatter() method as a TensorFlow op as follows:

from matplotlib.axes import Axes
tf_scatter = tfplot.autowrap(Axes.scatter, figsize=(4, 4))

plot_op = tf_scatter(x, y)
execute_op_as_image(plot_op)

Executing: Tensor("scatter:0", shape=(?, ?, 4), dtype=uint8)
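Wrapping Axes.scatter works because, in Python, a method looked up on the class is a plain function whose first parameter is the instance: Axes.scatter(ax, x, y) is equivalent to ax.scatter(x, y). The matplotlib-only sketch below demonstrates that equivalence; how tfplot internally creates and passes the ax argument is not shown here and is an assumption of this illustration.

```python
from matplotlib.axes import Axes
from matplotlib.figure import Figure

fig = Figure(figsize=(4, 4))
ax = fig.add_subplot(1, 1, 1)

# Axes.scatter taken from the class is a plain function; the Axes instance
# is passed explicitly as its first argument, just like ax.scatter(...)
collection = Axes.scatter(ax, [0.0, 1.0, 2.0], [0.0, 1.0, 4.0])
```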

2. Seaborn

Seaborn provides many useful axes-level plot functions that can be used out of the box. Most of its functions that draw onto an AxesPlot accept an ax=... parameter.

See seaborn’s example gallery for interesting features seaborn provides.

import seaborn as sns

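# compare version numbers numerically: as strings, '0.13' < '0.8'
assert tuple(int(p) for p in sns.__version__.split('.')[:2]) >= (0, 8), \
    'Use seaborn >= v0.8.0, otherwise `import seaborn as sns` will affect the default matplotlib style.'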
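Comparing version strings with >= is lexicographic, so a release such as 0.13.2 compares as older than 0.8. A small stdlib-only sketch of a numeric comparison follows; the helper name major_minor is hypothetical, it assumes the first two components are plain integers, and the packaging library's Version class is a more robust alternative.

```python
def major_minor(version):
    """Return (major, minor) as ints from a version string like '0.13.2'."""
    # assumes the first two dot-separated components are plain integers
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


# string comparison goes character by character: '1' < '8'
string_check = "0.13.2" >= "0.8"               # False, although 0.13 is newer
tuple_check = major_minor("0.13.2") >= (0, 8)  # True, the intended result
```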

barplot: (Discrete) Probability Distribution

# https://seaborn.pydata.org/generated/seaborn.barplot.html

y = np.random.RandomState(42).normal(size=[18])
y = np.exp(y) / np.exp(y).sum() # softmax
y = tf.constant(y, dtype=tf.float32)

ATARI_ACTIONS = [
    '⠀', '●', '↑', '→', '←', '↓', '↗', '↖', '↘', '↙',
    '⇑', '⇒', '⇐', '⇓', '⇗', '⇖', '⇘', '⇙' ]
x = tf.constant(ATARI_ACTIONS)

op = tfplot.autowrap(sns.barplot, palette='Blues_d')(x, y)
execute_op_as_image(op)

Executing: Tensor("barplot:0", shape=(?, ?, 4), dtype=uint8)

from IPython.display import display  # already available in Jupyter notebooks; needed in plain scripts

y = np.random.RandomState(42).normal(size=[3, 18])

y = np.exp(y) / np.exp(y).sum(axis=1).reshape([-1, 1]) # softmax example-wise
y = tf.constant(y, dtype=tf.float32)

ATARI_ACTIONS = [
    '⠀', '●', '↑', '→', '←', '↓', '↗', '↖', '↘', '↙',
    '⇑', '⇒', '⇐', '⇓', '⇗', '⇖', '⇘', '⇙' ]
x = tf.broadcast_to(tf.constant(ATARI_ACTIONS), y.shape)

op = tfplot.autowrap(sns.barplot, palette='Blues_d', batch=True)(x, y)
for im in execute_op_as_image(op):
    display(im)

Executing: Tensor("barplot_1/PlotImages:0", shape=(3, ?, ?, 4), dtype=uint8)
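The snippets above turn random logits into probability vectors with np.exp(y) / np.exp(y).sum(). Subtracting the maximum before exponentiating gives mathematically the same result while avoiding overflow for large logits. A minimal pure-Python sketch (the numpy equivalent would subtract y.max(axis=1, keepdims=True) before np.exp):

```python
import math


def softmax(logits):
    """Numerically stable softmax of a list of floats."""
    m = max(logits)  # shift by the max so exp() cannot overflow
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_rows(matrix):
    """Apply softmax independently to each row (example-wise)."""
    return [softmax(row) for row in matrix]


probs = softmax_rows([[1.0, 2.0, 3.0], [1000.0, 1000.0, 1000.0]])
```

The naive formula would overflow on the second row, because math.exp(1000.0) raises OverflowError.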

Heatmap

Let’s wrap seaborn’s heatmap function as a TensorFlow operation, with some additional default kwargs. This is very useful for visualization.

# @seealso https://seaborn.pydata.org/examples/heatmap_annotation.html
tf_heatmap = tfplot.autowrap(sns.heatmap, figsize=(9, 6))

op = tf_heatmap(attention_map, cbar=True, annot=True, fmt=".2f")
execute_op_as_image(op)

Executing: Tensor("heatmap:0", shape=(?, ?, 4), dtype=uint8)

What if we don’t want axes and colorbars, but only the map itself? Compare to plain tf.summary.image, which just gives a grayscale image.

# print only heatmap figures other than axis, colorbar, etc.
tf_heatmap = tfplot.autowrap(sns.heatmap, figsize=(4, 4), tight_layout=True,
                             cmap='jet', cbar=False, xticklabels=False, yticklabels=False)

op = tf_heatmap(attention_map, name='HeatmapImage')
execute_op_as_image(op)

Executing: Tensor("HeatmapImage:0", shape=(?, ?, 4), dtype=uint8)

3. And Many More!

This document has covered the basic usage of tfplot, but there is more:

  • tfplot.contrib: contains some off-the-shelf functions for creating plot operations that can be useful in practice, in a few lines (without the hassle of writing a function body). See [contrib.ipynb] for a tour of the available APIs.
  • tfplot.plot(), tfplot.plot_many(), etc.: Low-level APIs.
  • tfplot.summary: One-liner APIs for creating TF summary operations.

import tfplot.contrib

For example, probmap and probmap_simple create an image Tensor that visualizes a probability map:

op = tfplot.contrib.probmap(attention_map, figsize=(4, 3))
execute_op_as_image(op)

Executing: Tensor("probmap:0", shape=(?, ?, 4), dtype=uint8)

op = tfplot.contrib.probmap_simple(attention_map, figsize=(3, 3), vmin=0, vmax=1)
execute_op_as_image(op)

Executing: Tensor("probmap_1:0", shape=(?, ?, 4), dtype=uint8)

That’s all! Please take a look at the API documentation and more examples if you are interested.

2. tfplot.contrib: Some pre-defined plot ops

The tfplot.contrib package contains off-the-shelf functions for defining plotting operations that are useful across many typical use cases. However, it may not offer flexible, fine-grained customization beyond its current parameters. If it does not fit your needs, consider designing your own plotting functions using tfplot.autowrap.

import tfplot.contrib

for fn in sorted(tfplot.contrib.__all__):
    print("%-20s" % fn, tfplot.contrib.__dict__[fn].__doc__.split('\n')[1].strip())

batch                Make an autowrapped plot function (... -> RGBA tf.Tensor) work in a batch
probmap              Display a heatmap in color. The resulting op will be a RGBA image Tensor.
probmap_simple       Display a heatmap in color, but only displays the image content.

1. probmap

For example, probmap and probmap_simple create an image Tensor that visualizes a probability map:

attention_op = tf.constant(attention_map, name="attention_op")
print(attention_op)

op = tfplot.contrib.probmap(attention_map, figsize=(4, 3))
execute_op_as_image(op)

Tensor("attention_op:0", shape=(16, 16), dtype=float32)
Executing: Tensor("probmap:0", shape=(?, ?, 4), dtype=uint8)

op = tfplot.contrib.probmap_simple(attention_map, figsize=(3, 3),
                                   vmin=0, vmax=1)
execute_op_as_image(op)

Executing: Tensor("probmap_1:0", shape=(?, ?, 4), dtype=uint8)
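In matplotlib, vmin and vmax fix the data range that a colormap spans: matplotlib.colors.Normalize maps vmin to 0 and vmax to 1, and out-of-range values are drawn with the colormap's end colors by default. That probmap_simple forwards vmin/vmax to matplotlib this way is an assumption here; if it does, fixing vmin=0, vmax=1 keeps colors comparable across probability maps. The arithmetic, sketched in plain Python with a hypothetical normalize helper (clipping stands in for the end-color behavior):

```python
def normalize(values, vmin, vmax):
    """Linearly map values so vmin -> 0.0 and vmax -> 1.0, clipping outside."""
    span = vmax - vmin
    return [min(1.0, max(0.0, (v - vmin) / span)) for v in values]


scaled = normalize([-0.5, 0.0, 0.25, 1.0, 2.0], vmin=0, vmax=1)
```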

2. Auto-batch mode (tfplot.contrib.batch)

In many cases, we may want to make plotting operations behave in a batch manner. You can use tfplot.contrib.batch to make those functions work in a batch mode:

# batch version
N = 5
p = np.zeros([N, N, N])
for i in range(N):
    p[i, i, i] = 1.0

p = tf.constant(p, name="batch_tensor"); print(p)                      # (batch_size, 5, 5)
op = tfplot.contrib.batch(tfplot.contrib.probmap)(p, figsize=(3, 2))   # (batch_size, H, W, 4)

results = execute_op_as_image(op)      # list of N images
Image.fromarray(np.hstack([np.asarray(im) for im in results]))

Tensor("batch_tensor:0", shape=(5, 5, 5), dtype=float64)
Executing: Tensor("probmap_2/PlotImages:0", shape=(5, ?, ?, 4), dtype=uint8)

3. More APIs

1. Low-level APIs: tfplot.plot()

The following examples show the usage of the most general form of the API, tfplot.plot(). Its usage is very similar to that of tf.py_func().

Conceptually, we can draw any matplotlib plot as a TensorFlow op. One thing to remember is that the plot_func function (passed to tfplot.plot()) should be implemented using object-oriented APIs of matplotlib, not pyplot.XXX APIs (or matplotlib.pyplot.XXX) in order to avoid thread-safety issues.
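The key idea behind tfplot.plot(), as with tf.py_func(), is deferred execution: building the op only records the Python function, its inputs and its kwargs, and the function runs each time the op is executed in a session. The plain-Python sketch below mimics that contract with a hypothetical DeferredPlotOp class; no TensorFlow is involved, and real inputs would be evaluated tensors rather than fixed values.

```python
class DeferredPlotOp:
    """Hypothetical stand-in for a graph op: stores the call, runs it later."""

    def __init__(self, plot_func, in_values, **kwargs):
        self.plot_func = plot_func
        self.in_values = list(in_values)
        self.kwargs = kwargs
        self.calls = 0

    def run(self):
        # The Python function only executes when the op is run, receiving the
        # (evaluated) inputs positionally and the stored kwargs by keyword.
        self.calls += 1
        return self.plot_func(*self.in_values, **self.kwargs)


def describe(attention, image, alpha=0.5, cmap="jet"):
    # placeholder for a real plot function such as overlay_attention
    return (attention, image, alpha, cmap)


op = DeferredPlotOp(describe, ["att", "img"], cmap="gray", alpha=0.8)
```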

1. A basic example

def test_figure():
    fig, ax = tfplot.subplots(figsize=(3, 3))
    ax.text(0.5, 0.5, "Hello World!",
            ha='center', va='center', size=24)
    return fig

plot_op = tfplot.plot(test_figure, [])
execute_op_as_image(plot_op)

Executing: Tensor("Plot:0", shape=(?, ?, 4), dtype=uint8)

2. With arguments

def figure_attention(attention):
    fig, ax = tfplot.subplots(figsize=(4, 3))
    im = ax.imshow(attention, cmap='jet')
    fig.colorbar(im)
    return fig

plot_op = tfplot.plot(figure_attention, [attention_tensor])
execute_op_as_image(plot_op)

Executing: Tensor("Plot_1:0", shape=(?, ?, 4), dtype=uint8)

3. Examples of using kwargs

# the plot function can have additional kwargs for providing configuration points
def overlay_attention(attention, image,
                      alpha=0.5, cmap='jet'):
    fig = tfplot.Figure(figsize=(4, 4))
    ax = fig.add_subplot(1, 1, 1)
    ax.axis('off')
    fig.subplots_adjust(0, 0, 1, 1)  # get rid of margins

    H, W = attention.shape
    # extent is [left, right, bottom, top]: width along x, height along y
    ax.imshow(image, extent=[0, W, 0, H])
    ax.imshow(attention, cmap=cmap,
              alpha=alpha, extent=[0, W, 0, H])
    return fig

plot_op = tfplot.plot(overlay_attention, [attention_tensor, image_tensor])
execute_op_as_image(plot_op)

Executing: Tensor("Plot_2:0", shape=(?, ?, 4), dtype=uint8)

# the kwargs to `tfplot.plot()` are passed to the plot function (i.e. `overlay_attention`)
# during the execution of the plot operation.
plot_op = tfplot.plot(overlay_attention, [attention_tensor, image_tensor],
                      cmap='gray', alpha=0.8)
execute_op_as_image(plot_op)

Executing: Tensor("Plot_3:0", shape=(?, ?, 4), dtype=uint8)

4. plot_many() – the batch version

# make a fake batch
batch_size = 3
attention_batch = tf.random_gamma([batch_size, 7, 7], alpha=0.3, seed=42)
image_batch = tf.tile(tf.expand_dims(image_tensor, 0),
                      [batch_size, 1, 1, 1], name='image_batch')
print(attention_batch)
print(image_batch)

# plot_many()
plot_op = tfplot.plot_many(overlay_attention, [attention_batch, image_batch])
images = execute_op_as_image(plot_op)

Tensor("random_gamma/Maximum:0", shape=(3, 7, 7), dtype=float32)
Tensor("image_batch:0", shape=(3, 768, 1024, 3), dtype=uint8)
Executing: Tensor("PlotMany/PlotImages:0", shape=(3, ?, ?, 4), dtype=uint8)

# just see the three images
_, axes = plt.subplots(1, 3, figsize=(10, 3))
for i in range(3):
    axes[i].set_title("%d : [%dx%d]" % (i, images[i].height, images[i].width))
    axes[i].imshow(images[i])
plt.show()

5. Wrap once, use it as a factory – tfplot.autowrap() or tfplot.wrap()

Let’s wrap the function overlay_attention, which

  • takes a heatmap (attention) and an RGB image (image)
  • and plots the heatmap on top of the image

so that it accepts Tensors as input:

plot_op = tfplot.autowrap(overlay_attention)(attention_tensor, image_tensor)
execute_op_as_image(plot_op)

Executing: Tensor("overlay_attention:0", shape=(?, ?, 4), dtype=uint8)

A cleaner style, in a functional way!

6. Batch example

tf_plot_attention = tfplot.wrap(overlay_attention, name='PlotAttention', batch=True)
print(tf_plot_attention)

<function wrap[__main__.overlay_attention] at 0x127f26f28>

Then we can call the resulting tf_plot_attention function to build new TensorFlow ops:

plot_op = tf_plot_attention(attention_batch, image_batch)
images = execute_op_as_image(plot_op)
images

Executing: Tensor("PlotAttention/PlotImages:0", shape=(3, ?, ?, 4), dtype=uint8)

[<PIL.Image.Image image mode=RGBA size=288x288 at 0x12A896470>,
 <PIL.Image.Image image mode=RGBA size=288x288 at 0x12A896390>,
 <PIL.Image.Image image mode=RGBA size=288x288 at 0x12A8962E8>]

# just see the three images
_, axes = plt.subplots(1, 3, figsize=(10, 3))
for i in range(3):
    axes[i].set_title("%d : [%dx%d]" % (i, images[i].height, images[i].width))
    axes[i].imshow(images[i])
plt.show()

2. tfplot.summary (deprecated)

Finally, we can directly create a TensorFlow summary op from input tensors. This gives an API similar to tf.summary.image(), and serves as a shortcut for creating plot ops and then creating image summaries from them.

import tfplot.summary

1. tfplot.summary.plot()

# Just directly add a single plot result into a summary
summary_op = tfplot.summary.plot("plot_summary", test_figure, [])
print(summary_op)
execute_op_as_image(summary_op)

Tensor("plot_summary/ImageSummary:0", shape=(), dtype=string)
Executing: Tensor("plot_summary/ImageSummary:0", shape=(), dtype=string)

2. tfplot.summary.plot_many() – the batch version

# batch of attention maps --> image summary
batch_size, H, W = 4, 4, 4
batch_attentions = np.zeros((batch_size, H, W), dtype=np.float32)
for b in range(batch_size):
    batch_attentions[b, b, b] = 1.0

# Note that tfplot.summary.plot_many() takes an input in a batch form
def figure_attention_demo2(attention):
    fig, ax = tfplot.subplots(figsize=(4, 3))
    im = ax.imshow(attention, cmap='jet')
    fig.colorbar(im)
    return fig
summary_op = tfplot.summary.plot_many("batch_attentions_summary", figure_attention_demo2,
                                      [batch_attentions], max_outputs=4)
print(summary_op)
images = execute_op_as_image(summary_op)

Tensor("batch_attentions_summary/ImageSummary:0", shape=(), dtype=string)
Executing: Tensor("batch_attentions_summary/ImageSummary:0", shape=(), dtype=string)

# just see the 4 images in the summary
_, axes = plt.subplots(2, 2, figsize=(8, 6))
for i in range(batch_size):
    axes[i//2, i%2].set_title("%d : [%dx%d]" % (i, images[i].height, images[i].width))
    axes[i//2, i%2].imshow(images[i])
plt.show()

3. API Reference

1. tfplot

1. Wrapper functions

tfplot.autowrap(*args, **kwargs)

Wrap a function as a TensorFlow operation, similar to tfplot.wrap() (either as a decorator or with a normal function call), but with additional features such as auto-creating matplotlib figures.

  • (fig, ax) matplotlib objects are automatically created and injected, provided that plot_func has a keyword argument named fig and/or ax. In such cases, we do not need to manually call tfplot.subplots() to create matplotlib figure/axes objects. If you need to create fig and ax manually, consider using tfplot.wrap() instead.
  • It automatically handles the return value of the provided plot_func function: if the function returns nothing (None) and fig was automatically injected, the injected figure is drawn; if it returns an Axes, the associated Figure is used.
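The return-value rules in the last bullet can be expressed as a small dispatch function. This is a sketch with hypothetical stand-in classes (FigureStub, AxesStub) and a hypothetical helper resolve_figure, not tfplot's code; in particular, raising TypeError for any other return value is an assumption.

```python
class FigureStub:
    """Hypothetical stand-in for matplotlib.figure.Figure."""


class AxesStub:
    """Hypothetical stand-in for matplotlib.axes.Axes; knows its figure."""

    def __init__(self, figure):
        self.figure = figure


def resolve_figure(returned, injected_fig=None):
    """Pick the figure to render from a plot function's return value."""
    if isinstance(returned, FigureStub):
        return returned
    if isinstance(returned, AxesStub):
        return returned.figure  # use the Figure that owns the returned Axes
    if returned is None and injected_fig is not None:
        return injected_fig     # the function drew on the injected fig/ax
    raise TypeError("expected a Figure, an Axes, or None with an injected fig")
```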

Example

>>> @tfplot.autowrap(figsize=(3, 3))
... def plot_imshow(img, *, fig, ax):
...     ax.imshow(img)
...
>>> plot_imshow(an_image_tensor)
Tensor("plot_imshow:0", shape=(?, ?, 4), dtype=uint8)

Parameters:

  • plot_func – A python function or callable to wrap. See the documentation of tfplot.plot() for details. Additionally, if this function has a parameter named fig and/or ax, new instances of Figure and/or AxesSubplot will be created and passed.
  • batch – If True, all the tensors passed as argument will be assumed to be batched. Default value is False.
  • name – A default name for the operation (optional). If not given, the name of plot_func will be used.
  • figsize – The figure size for the figure to be created.
  • tight_layout – If True, the resulting figure will have no margins around the axes. Equivalent to calling fig.subplots_adjust(0, 0, 1, 1).
  • kwargs_default – Optional kwargs that will be passed by default to plot_func when executed inside a TensorFlow graph.
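The tight_layout option is documented above as equivalent to fig.subplots_adjust(0, 0, 1, 1). A minimal sketch of what that call does to a figure's subplot parameters, using matplotlib directly rather than tfplot:

```python
from matplotlib.figure import Figure

fig = Figure(figsize=(4, 4))
ax = fig.add_subplot(1, 1, 1)

# subplots_adjust(left, bottom, right, top) in figure-fraction units:
# 0, 0, 1, 1 lets the subplot area fill the whole canvas (no margins)
fig.subplots_adjust(0, 0, 1, 1)
pars = fig.subplotpars
```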

tfplot.wrap(*args, **kwargs)

Wrap a plot function as a TensorFlow operation. It returns a Python function that creates a TensorFlow plot operation applying the arguments as input. It can also be used as a decorator.

For example:

>>> @tfplot.wrap
... def plot_imshow(img):
...     fig, ax = tfplot.subplots()
...     ax.imshow(img)
...     return fig
...
>>> plot_imshow(an_image_tensor)
Tensor("plot_imshow:0", shape=(?, ?, 4), dtype=uint8)

Or, if plot_func is a Python function that takes numpy arrays as input and draws a plot by returning a matplotlib Figure, we can wrap this function as a Tensor factory, such as:

>>> tf_plot = tfplot.wrap(plot_func, name="MyPlot", batch=True)
>>> # x, y = get_batch_inputs(batch_size=4, ...)
>>> plot_x = tf_plot(x)
Tensor("MyPlot:0", shape=(4, ?, ?, 4), dtype=uint8)
>>> plot_y = tf_plot(y)
Tensor("MyPlot_1:0", shape=(4, ?, ?, 4), dtype=uint8)

Parameters:

  • plot_func – A python function or callable to wrap. See the documentation of tfplot.plot() for details.
  • batch – If True, all the tensors passed as argument will be assumed to be batched. Default value is False.
  • name – A default name for the operation (optional). If not given, the name of plot_func will be used.
  • kwargs – Optional keyword arguments that will be passed by default to plot_func when executed inside a TensorFlow graph.

Returns: A python function that will create a TensorFlow plot operation, passing the provided arguments.
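Each call to the wrapped factory creates a new op, and the graph makes its name unique by appending a counter (MyPlot, then MyPlot_1 in the example above; the same pattern appears earlier as plot_scatter_1 and Plot_1). A simplified sketch of that naming scheme with a hypothetical NameScope class; TensorFlow's own logic handles more cases, such as explicitly chosen names that already carry a suffix.

```python
import collections


class NameScope:
    """Hypothetical helper mimicking how graph op names are made unique."""

    def __init__(self):
        self._counts = collections.Counter()

    def unique_name(self, base):
        count = self._counts[base]
        self._counts[base] += 1
        # the first use keeps the bare name; later uses get _1, _2, ...
        return base if count == 0 else "%s_%d" % (base, count)


scope = NameScope()
names = [scope.unique_name("MyPlot") for _ in range(3)]
```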

tfplot.wrap_axesplot(axesplot_func, _sentinel=None, batch=False, name=None, figsize=None, tight_layout=False, **kwargs)

DEPRECATED: Use tfplot.autowrap() instead. Will be removed in the next version.

Wrap an axesplot function as a TensorFlow operation. It will return a python function that creates a TensorFlow plot operation applying the arguments as input.

An axesplot function axesplot_func can be either:

  • an unbound method of the matplotlib Axes (or AxesSubplot) class, such as Axes.scatter() and Axes.text(), or
  • a simple Python function that takes the named argument ax, of type Axes or AxesSubplot, on which the plot will be drawn. Good examples of this family include seaborn.heatmap(ax=...).

The resulting function can be used as a Tensor factory. When the created tensorflow plot op is being executed, a new matplotlib figure which consists of a single AxesSubplot will be created, and the axes plot will be used as an argument for axesplot_func. For example,

>>> import seaborn as sns
>>> tf_heatmap = tfplot.wrap_axesplot(sns.heatmap, name="HeatmapPlot", figsize=(4, 4), cmap='jet')
>>> plot_op = tf_heatmap(attention_map)
>>> plot_op
Tensor("HeatmapPlot:0", shape=(?, ?, 4), dtype=uint8)

Parameters:

  • axesplot_func – An unbound method of matplotlib Axes or AxesSubplot, or a Python function or callable which has the ax parameter for specifying the axis to draw on.
  • batch – If True, all the tensors passed as argument will be assumed to be batched. Default value is False.
  • name – A default name for the operation (optional). If not given, the name of axesplot_func will be used.
  • figsize – The figure size for the figure to be created.
  • tight_layout – If True, the resulting figure will have no margins for axis. Equivalent to calling fig.subplots_adjust(0, 0, 1, 1).
  • kwargs – An optional kwargs that will be passed by default to axesplot_func.

Returns: A python function that will create a TensorFlow plot operation, passing the provided arguments and a new instance of AxesSubplot into axesplot_func.

2. Raw Plot Ops

tfplot.plot(plot_func, in_tensors, name='Plot', **kwargs)

Create a TensorFlow op which draws a plot in an image. The resulting image is a 3-D uint8 tensor.

Given a Python function plot_func, which takes numpy arrays as its inputs (the evaluations of in_tensors) and returns a matplotlib Figure object as its output, wrap this function as a TensorFlow op. The returned figure will be rendered as an RGBA image upon execution.

Parameters:

  • plot_func – A Python function or callable which accepts numpy ndarray objects as arguments that match the corresponding tf.Tensor objects in in_tensors. It should return a new instance of matplotlib.figure.Figure, which contains the resulting plot image.
  • in_tensors – A list of tf.Tensor objects.
  • name – A name for the operation (optional).
  • kwargs – Additional keyword arguments passed to plot_func (optional).

Returns: A single uint8 Tensor of shape (?, ?, 4), containing the plot image that plot_func computes.

tfplot.plot_many(plot_func, in_tensors, name='PlotMany', max_outputs=None, **kwargs)

A batch version of plot. Create a TensorFlow op which draws a plot for each batch element. The resulting images are given in a 4-D uint8 Tensor of shape [batch_size, height, width, 4].

Parameters:

  • plot_func – A Python function or callable, which accepts numpy ndarray objects as arguments that match the corresponding tf.Tensor objects in in_tensors. It should return a new instance of matplotlib.figure.Figure, which contains the resulting plot image. The generated figure must have the same (height, width) for every plot.
  • in_tensors – A list of tf.Tensor objects.
  • name – A name for the operation (optional).
  • max_outputs – Max number of batch elements to generate plots for (optional).
  • kwargs – Additional keyword arguments passed to plot_func (optional).

Returns: A single uint8 Tensor of shape (B, ?, ?, 4), containing the B plot images, each of which is computed by plot_func, where B is the smaller of batch_size (the number of batch elements in each tensor from in_tensors) and max_outputs.
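The batching rules above (one plot per batch element, at most max_outputs of them, all with the same height and width so they can be stacked into one 4-D tensor) can be sketched in plain Python. plot_many_sketch is hypothetical, images are represented as nested lists of rows, and the real op runs inside a TensorFlow graph.

```python
def plot_many_sketch(plot_func, batched_inputs, max_outputs=None):
    """Apply plot_func to each batch element; check all images match in size.

    batched_inputs: list of sequences sharing the same leading (batch) length.
    plot_func: returns an image as a list of rows (height x width).
    """
    batch_sizes = {len(x) for x in batched_inputs}
    if len(batch_sizes) != 1:
        raise ValueError("all inputs must have the same batch size")
    batch_size = batch_sizes.pop()
    count = batch_size if max_outputs is None else min(batch_size, max_outputs)

    # element i of every input is passed to one call of plot_func
    images = [plot_func(*(x[i] for x in batched_inputs)) for i in range(count)]
    shapes = {(len(im), len(im[0])) for im in images}
    if len(shapes) > 1:
        raise ValueError("every plot must produce the same (height, width)")
    return images


def fake_plot(a, b):
    return [[a + b] * 3 for _ in range(2)]  # a 2 x 3 "image" filled with a + b


imgs = plot_many_sketch(fake_plot, [[1, 2, 3, 4], [10, 20, 30, 40]], max_outputs=3)
```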


2. tfplot.figure

Figure utilities.

tfplot.figure.to_array(fig)

Convert a matplotlib figure fig into a 3D numpy array.

Example

>>> fig, ax = tfplot.subplots(figsize=(4, 4))
>>> # draw whatever, e.g. ax.text(0.5, 0.5, "text")

>>> im = to_array(fig)   # ndarray [288, 288, 4]

Parameters:

fig – A matplotlib.figure.Figure object.

Returns:

A numpy ndarray of shape (?, ?, 4), containing an RGB-A image of the figure.
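A minimal sketch of such a conversion with matplotlib's Agg backend, without pyplot. The helper name figure_to_rgba is hypothetical and this is not necessarily how tfplot implements to_array; dpi=72 is chosen here so that a 4x4-inch figure yields the 288x288 array mentioned in the example above.

```python
import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure


def figure_to_rgba(fig):
    """Render a Figure off-screen with Agg and return an (H, W, 4) uint8 array."""
    canvas = FigureCanvasAgg(fig)  # attach a non-GUI canvas to the figure
    canvas.draw()
    # buffer_rgba() exposes the rendered pixels; copy so the array owns its data
    return np.asarray(canvas.buffer_rgba()).copy()


fig = Figure(figsize=(4, 4), dpi=72)  # 4 in x 72 dpi = 288 px per side
ax = fig.add_subplot(1, 1, 1)
ax.text(0.5, 0.5, "text", ha="center", va="center")
im = figure_to_rgba(fig)
```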

tfplot.figure.to_summary(fig, tag)

Convert a matplotlib figure fig into a TensorFlow Summary object that can be passed directly to tf.summary.FileWriter.add_summary().

Example

>>> fig, ax = ...    # (as above)
>>> summary = to_summary(fig, tag='MyFigure/image')

>>> type(summary)
tensorflow.core.framework.summary_pb2.Summary
>>> summary_writer.add_summary(summary, global_step=global_step)

Parameters:

  • fig – A matplotlib.figure.Figure object.
  • tag (string) – The tag name of the created summary.

Returns: A TensorFlow Summary protobuf object containing the plot image as an image summary.
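An image summary carries the picture as encoded PNG bytes (the encoded_image_string field that execute_op_as_image reads back at the top of this page). A sketch of producing such bytes from a Figure with matplotlib alone; the helper figure_to_png_bytes is hypothetical, and wrapping the bytes into a Summary protobuf is left to TensorFlow.

```python
import io

from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure


def figure_to_png_bytes(fig):
    """Encode a Figure as PNG bytes, the payload an image summary stores."""
    FigureCanvasAgg(fig)  # attach an off-screen canvas (no pyplot involved)
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    return buf.getvalue()


fig = Figure(figsize=(2, 2), dpi=50)
fig.add_subplot(1, 1, 1).plot([0, 1], [1, 0])
png = figure_to_png_bytes(fig)
```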

tfplot.figure.subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True, subplot_kw=None, gridspec_kw=None, **fig_kw)

Create a figure and a set of subplots, as in pyplot.subplots().

It works almost the same as pyplot.subplots(), but differs in that it has none of pyplot's side effects (e.g. modifying shared state such as the current figure or current subplot).

(docstrings inherited from matplotlib.pyplot.subplots)

Parameters:

  • nrows, ncols (int, optional, default: 1) – Number of rows/columns of the subplot grid.
  • sharex, sharey (bool or {‘none’, ‘all’, ‘row’, ‘col’}, default: False) – Controls sharing of properties among x (sharex) or y (sharey) axes:
    • True or ‘all’: x- or y-axis will be shared among all subplots.
    • False or ‘none’: each subplot x- or y-axis will be independent.
    • ‘row’: each subplot row will share an x- or y-axis.
    • ‘col’: each subplot column will share an x- or y-axis.

    When subplots have a shared x-axis along a column, only the x tick labels of the bottom subplot are created. Similarly, when subplots have a shared y-axis along a row, only the y tick labels of the first column subplot are created. To later turn other subplots’ ticklabels on, use tick_params().

  • squeeze (bool, optional, default: True) –
    • If True, extra dimensions are squeezed out from the returned array of Axes:
      • if only one subplot is constructed (nrows=ncols=1), the resulting single Axes object is returned as a scalar.
      • for Nx1 or 1xM subplots, the returned object is a 1D numpy object array of Axes objects.
      • for NxM, subplots with N>1 and M>1 are returned as a 2D array.
    • If False, no squeezing at all is done: the returned Axes object is always a 2D array containing Axes instances, even if it ends up being 1x1.
  • subplot_kw (dict, optional) – Dict with keywords passed to the add_subplot() call used to create each subplot.
  • gridspec_kw (dict, optional) – Dict with keywords passed to the GridSpec constructor used to create the grid the subplots are placed on.
  • **fig_kw – All additional keyword arguments are passed to the figure() call.
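The sharing modes and the tick-label rule above can be modelled in a few lines of plain Python. The helpers below (share_key, shows_x_ticklabels, shows_y_ticklabels) are hypothetical and not part of tfplot or matplotlib; they assume, as matplotlib does, that every subplot shares its axis with one "root" subplot: the top-left one for 'all', the first subplot of its row for 'row', and the top subplot of its column for 'col'.

```python
def share_key(row, col, mode):
    """Return the grid position whose axis the subplot at (row, col) shares.

    Hypothetical helper modelling the sharex/sharey modes; True/False are
    aliases for 'all'/'none'.
    """
    if mode is True:
        mode = "all"
    elif mode is False:
        mode = "none"
    if mode == "all":
        return (0, 0)          # every subplot shares with the top-left one
    if mode == "row":
        return (row, 0)        # share with the first subplot of the same row
    if mode == "col":
        return (0, col)        # share with the top subplot of the same column
    if mode == "none":
        return (row, col)      # each subplot is its own group
    raise ValueError("invalid share mode: %r" % (mode,))


def shows_x_ticklabels(row, nrows, sharex):
    """With x shared along columns, only the bottom row keeps x tick labels."""
    if sharex in (True, "all", "col"):
        return row == nrows - 1
    return True


def shows_y_ticklabels(col, sharey):
    """With y shared along rows, only the first column keeps y tick labels."""
    if sharey in (True, "all", "row"):
        return col == 0
    return True
```

For a 2x2 grid with sharex='col', both subplots in column 1 map to (0, 1), and only the bottom row keeps its x tick labels.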

Returns:

  • fig (matplotlib.figure.Figure object)
  • ax (Axes object or array of Axes objects) – ax can be either a single matplotlib.axes.Axes object or an array of Axes objects if more than one subplot was created. The dimensions of the resulting array can be controlled with the squeeze keyword, see above.
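The squeeze rule is easiest to see on a grid of placeholder objects. This is a minimal sketch: nested Python lists stand in for the numpy object arrays that are actually returned, and the names squeeze_axes and make_grid are hypothetical.

```python
def make_grid(nrows, ncols):
    """Build an nrows x ncols grid of placeholder 'Axes' labelled by position."""
    return [["ax%d%d" % (i, j) for j in range(ncols)] for i in range(nrows)]


def squeeze_axes(grid, squeeze=True):
    """Apply the squeeze rule to a grid given as a list of rows."""
    nrows, ncols = len(grid), len(grid[0])
    if not squeeze:
        return grid                      # always 2-D, even for a 1x1 grid
    if nrows == 1 and ncols == 1:
        return grid[0][0]                # a single Axes, returned as a scalar
    if nrows == 1:
        return list(grid[0])             # 1xM grid -> 1-D sequence
    if ncols == 1:
        return [row[0] for row in grid]  # Nx1 grid -> 1-D sequence
    return grid                          # NxM with N>1 and M>1 -> 2-D
```

This is why `fig, ax = tfplot.subplots()` yields a single Axes, while `tfplot.subplots(2, 2)` has to be indexed as `axes[0, 0]`.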

Examples

First create some toy data:

>>> import numpy as np
>>> import tfplot
>>> x = np.linspace(0, 2*np.pi, 400)
>>> y = np.sin(x**2)

Creates just a figure and only one subplot

>>> fig, ax = tfplot.subplots()
>>> ax.plot(x, y)
>>> ax.set_title('Simple plot')

Creates two subplots and unpacks the output array immediately

>>> f, (ax1, ax2) = tfplot.subplots(1, 2, sharey=True)
>>> ax1.plot(x, y)
>>> ax1.set_title('Sharing Y axis')
>>> ax2.scatter(x, y)

Creates four polar axes, and accesses them through the returned array

>>> fig, axes = tfplot.subplots(2, 2, subplot_kw=dict(polar=True))
>>> axes[0, 0].plot(x, y)
>>> axes[1, 1].scatter(x, y)

Share an X axis with each column of subplots

>>> tfplot.subplots(2, 2, sharex='col')

Share a Y axis with each row of subplots

>>> tfplot.subplots(2, 2, sharey='row')

Share both X and Y axes with all subplots

>>> tfplot.subplots(2, 2, sharex='all', sharey='all')

Note that this is the same as

>>> tfplot.subplots(2, 2, sharex=True, sharey=True)

3. tfplot.contrib

Some predefined plot functions.

tfplot.contrib.probmap(*args, **kwargs_call)

Display a heatmap in color. The resulting op will be an RGBA image Tensor.

Parameters:

  • x – A 2-D image-like tensor to draw.
  • cmap – Matplotlib colormap. Defaults to ‘jet’.
  • axis – If True (default), x-axis and y-axis will appear.
  • colorbar – If True (default), a colorbar will be placed on the right.
  • vmin – A scalar. Minimum value of the range. See matplotlib.axes.Axes.imshow.
  • vmax – A scalar. Maximum value of the range. See matplotlib.axes.Axes.imshow.

Returns: A uint8 Tensor of shape (?, ?, 4) containing the resulting plot.
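For vmin and vmax the docstring defers to Axes.imshow, which linearly maps the interval [vmin, vmax] onto the colormap and falls back to the data minimum and maximum when they are omitted. The sketch below shows only that normalization step on a 2-D list; normalize is a hypothetical helper, not a tfplot function, and it clips out-of-range values, which is roughly what the default colormap does with them (they receive the end colors).

```python
def normalize(values, vmin=None, vmax=None):
    """Linearly map a 2-D list of numbers onto [0, 1], clipping out-of-range values."""
    flat = [v for row in values for v in row]
    lo = min(flat) if vmin is None else vmin  # autoscale to the data if omitted
    hi = max(flat) if vmax is None else vmax
    span = hi - lo

    def scale(v):
        if span == 0:
            return 0.0                          # degenerate range: everything maps to 0
        return min(1.0, max(0.0, (v - lo) / span))

    return [[scale(v) for v in row] for row in values]
```

With the defaults, the largest value in x always gets the top color of cmap; fixing vmin and vmax keeps colors comparable across heatmaps.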

tfplot.contrib.probmap_simple(x, **kwargs)

Display a heatmap in color, showing only the image content. The resulting op will be an RGBA image Tensor.

It is equivalent to probmap with colorbar and axis turned off. See the documentation of probmap for available arguments.
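One way to express that reduction is a thin wrapper that fixes two keyword arguments and forwards the rest. The probmap below is a hypothetical stand-in that only records its options (the real one builds a plot op), and its keyword defaults are taken from the parameter list above, with vmin and vmax assumed to default to None.

```python
def probmap(x, cmap="jet", axis=True, colorbar=True, vmin=None, vmax=None):
    """Hypothetical stand-in for tfplot.contrib.probmap that just records its options."""
    return {"x": x, "cmap": cmap, "axis": axis, "colorbar": colorbar,
            "vmin": vmin, "vmax": vmax}


def probmap_simple(x, **kwargs):
    """probmap with the axes and colorbar switched off; other kwargs pass through."""
    return probmap(x, axis=False, colorbar=False, **kwargs)
```

With this design, passing axis or colorbar to probmap_simple raises a TypeError (duplicate keyword argument), a reasonable signal that probmap should be used instead.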

tfplot.contrib.batch(func)

Make an autowrapped plot function (… -> RGBA tf.Tensor) work on batched inputs.

Example

>>> p
Tensor("p:0", shape=(batch_size, 16, 16, 4), dtype=uint8)
>>> tfplot.contrib.batch(tfplot.contrib.probmap)(p)
Tensor("probmap/PlotImages:0", shape=(batch_size, ?, ?, 4), dtype=uint8)
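Conceptually, batch calls the wrapped function once per element along the first (batch) dimension and stacks the results, which is why the output above gains a leading batch_size dimension. The sketch below shows only that mapping logic, with Python lists standing in for tensors; it is an illustration under that assumption, not tfplot's implementation, which works on TensorFlow ops.

```python
def batch(func):
    """Make a per-example function operate on batched inputs (lists here)."""
    def batched(*args, **kwargs):
        sizes = {len(a) for a in args}
        if len(sizes) != 1:
            raise ValueError("expected at least one input, all with the same batch size")
        # zip(*args) yields the i-th element of every input together.
        return [func(*elems, **kwargs) for elems in zip(*args)]
    return batched
```

Keyword arguments are forwarded unchanged to every per-element call, so options such as a colormap apply to the whole batch.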

4. tfplot.summary

Summary Op utilities.

tfplot.summary.wrap(plot_func, _sentinel=None, batch=False, name=None, **kwargs)

Wrap a plot function as a TensorFlow summary builder. It returns a Python function that creates a TensorFlow op which evaluates to a Summary protocol buffer containing an image.

The resulting function (say summary_wrapped) will have the following signature:

summary_wrapped(name, tensor, # [more input tensors ...],
                max_outputs=3, collections=None)

Examples

Given a plot function which returns a matplotlib Figure,

>>> def figure_heatmap(data, cmap='jet'):
...     fig, ax = tfplot.subplots()
...     ax.imshow(data, cmap=cmap)
...     return fig

we can wrap it as a summary builder function:

>>> summary_heatmap = tfplot.summary.wrap(figure_heatmap, batch=True)

Now, when building your computation graph, call it to build summary ops like tf.summary.image:

>>> heatmap_tensor
<tf.Tensor 'heatmap_tensor:0' shape=(16, 128, 128) dtype=float32>
>>>
>>> summary_heatmap("heatmap/original", heatmap_tensor)
>>> summary_heatmap("heatmap/cmap_gray", heatmap_tensor, cmap='gray')
>>> summary_heatmap("heatmap/no_default_collections", heatmap_tensor, collections=[])

Parameters:

  • plot_func – A python function or callable to wrap. See the documentation of tfplot.plot() for details.
  • batch – If True, all the tensors passed as argument will be assumed to be batched. Default value is False.
  • name – A default name for the plot op (optional). If not given, the name of plot_func will be used.
  • kwargs – Optional keyword arguments that will be passed by default to plot().

Returns: A Python function that will create a TensorFlow summary operation, passing the provided arguments into the plot op.
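The wrapping pattern itself is a closure that stores default keyword arguments and applies them on every call. The sketch below is hypothetical: instead of building a TensorFlow summary op it returns a dictionary describing what would be built, and it assumes that keyword arguments given at call time (such as cmap='gray' in the example above) override the defaults given to wrap.

```python
def wrap(plot_func, batch=False, name=None, **default_kwargs):
    """Hypothetical sketch of a summary-builder factory (not tfplot's code)."""
    op_name = name or plot_func.__name__  # default plot-op name

    def summary_wrapped(summary_name, *tensors, max_outputs=3, collections=None,
                        **kwargs):
        merged = dict(default_kwargs)
        merged.update(kwargs)  # call-time keywords take precedence (assumption)
        # A real builder would create a TensorFlow summary op here; this sketch
        # only describes the op that would be built.
        return {
            "summary_name": summary_name,
            "op_name": op_name,
            "plot_func": plot_func,
            "tensors": tensors,
            "batch": batch,
            "max_outputs": max_outputs,
            "collections": collections,
            "kwargs": merged,
        }

    return summary_wrapped
```

Keeping the defaults in the closure means a builder such as summary_heatmap can be created once and reused for many tensors.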

tfplot.summary.plot(name, plot_func, in_tensors, collections=None, **kwargs)

Create a TensorFlow op that outputs a Summary protocol buffer containing the image produced by a single plot operation (i.e. an image summary).

Essentially, it is a one-line wrapper around tfplot.ops.plot() and tf.summary.image().

The generated Summary object contains a single image summary value holding the image of the drawn plot.

Parameters:

  • name – The name of scope for the generated ops and the summary op. Will also serve as a series name prefix in TensorBoard.
  • plot_func – A python function or callable, specifying the plot operation as in tfplot.plot(). See the documentation at tfplot.plot().
  • in_tensors – A list of Tensor objects, as in plot().
  • collections – Optional list of ops.GraphKeys. The collections to add the summary to. Defaults to [_ops.GraphKeys.SUMMARIES].
  • kwargs – Optional keyword arguments passed to plot().

Returns: A scalar Tensor of type string. The serialized Summary protocol buffer (tensorflow operation).
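One detail behind the one-line wrapper: the plot op yields a single (height, width, 4) RGBA image, whereas tf.summary.image expects a 4-D tensor of shape [batch_size, height, width, channels] with 1, 3 or 4 channels. Presumably the image is therefore given a batch dimension of size 1 before being summarized. The hypothetical helper below checks and computes that shape using plain tuples; it does not call TensorFlow.

```python
def image_summary_shape(plot_shape):
    """Shape a single plot image must take before it is passed to tf.summary.image."""
    if len(plot_shape) != 3:
        raise ValueError("expected (height, width, channels), got %r" % (plot_shape,))
    height, width, channels = plot_shape
    if channels not in (1, 3, 4):
        raise ValueError("tf.summary.image accepts 1, 3 or 4 channels")
    return (1, height, width, channels)  # add a batch dimension of size 1
```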

tfplot.summary.plot_many(name, plot_func, in_tensors, max_outputs=3, collections=None, **kwargs)

Create a TensorFlow op that outputs a Summary protocol buffer, where plots can be drawn in a batched manner. This is a batch version of tfplot.summary.plot().

Specifically, all the input tensors in_tensors to plot_func are assumed to have the same batch size. The tensors corresponding to a single batch element are passed to plot_func together as its input.

The resulting Summary contains multiple (up to max_outputs) image summary values, each of which contains a plot rendered by plot_func.

Parameters:

  • name – The name of scope for the generated ops and the summary op. Will also serve as a series name prefix in TensorBoard.
  • plot_func – A python function or callable, specifying the plot operation as in tfplot.plot(). See the documentation at tfplot.plot().
  • in_tensors – A list of Tensor objects, the input to plot_func but each in a batch.
  • max_outputs – Max number of batch elements to generate plots for.
  • collections – Optional list of ops.GraphKeys. The collections to add the summary to. Defaults to [_ops.GraphKeys.SUMMARIES].
  • kwargs – Optional keyword arguments passed to plot().

Returns: A scalar Tensor of type string. The serialized Summary protocol buffer (tensorflow operation).
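The batching and max_outputs behavior described for plot_many can be sketched in plain Python, with lists standing in for batched tensors. plot_many_sketch is a hypothetical name; the real function builds TensorFlow ops and packs the rendered images into a Summary protocol buffer.

```python
def plot_many_sketch(plot_func, in_tensors, max_outputs=3):
    """Render one plot per batch element, for at most max_outputs elements."""
    batch_sizes = {len(t) for t in in_tensors}
    if len(batch_sizes) != 1:
        raise ValueError("all in_tensors must share the same batch size")
    n = min(batch_sizes.pop(), max_outputs)
    # Element i of every input tensor is passed together to one plot_func call.
    return [plot_func(*(t[i] for t in in_tensors)) for i in range(n)]
```

Only the first max_outputs elements are plotted, so the rendering cost stays bounded even for large batches.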
