Multidimensional Broadcasting in Python / NumPy - or inverse of `numpy.squeeze()`
What is the best way to broadcast two arrays together when a simple call to `np.broadcast_to()` would fail?
Consider the following example:
import numpy as np
arr1 = np.arange(2 * 3 * 4 * 5 * 6).reshape((2, 3, 4, 5, 6))
arr2 = np.arange(3 * 5).reshape((3, 5))
arr1 + arr2
# ValueError: operands could not be broadcast together with shapes (2,3,4,5,6) (3,5)
arr2_ = np.broadcast_to(arr2, arr1.shape)
# ValueError: operands could not be broadcast together with remapped shapes
arr2_ = arr2.reshape((1, 3, 1, 5, 1))
arr1 + arr2_
# now this works because the singletons trigger the automatic broadcast
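This only works if I manually select a shape for which automatic broadcasting is going to work.
What is the most efficient way of doing this automatically?
Is there a way other than reshaping to a cleverly constructed broadcastable shape?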
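Note the relation to `np.squeeze()`: it performs the inverse operation by removing singletons. So what I need is some sort of inverse of `np.squeeze()`.
The official documentation (as of NumPy 1.13.0) suggests that the inverse of `np.squeeze()` is `np.expand_dims()`, but this is not nearly as flexible as I need it to be. In fact, `np.expand_dims()` is roughly equivalent to `np.reshape(array, shape + (1,))` or `array[:, None]`.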
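To make the "roughly equivalent" claim concrete, here is a small sketch (the array `a` and the names `b1`, `b2`, `b3` are only illustrative) comparing the spellings for a single trailing singleton:

```python
import numpy as np

a = np.arange(6).reshape((2, 3))

# Three ways to append one trailing singleton axis.
b1 = np.expand_dims(a, axis=-1)
b2 = np.reshape(a, a.shape + (1,))
b3 = a[..., None]  # `None` is the same object as `np.newaxis`

print(b1.shape, b2.shape, b3.shape)

# `a[:, None]` inserts the new axis at position 1 instead,
# so it matches the forms above only for 1-D input.
print(a[:, None].shape)
```

All of these add exactly one axis per call or per `None`, which is why they do not scale to an arbitrary set of positions known only at run time.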
This issue is also related to the `keepdims` keyword accepted by, e.g., `sum`:
import numpy as np
arr1 = np.arange(2 * 3 * 4 * 5 * 6).reshape((2, 3, 4, 5, 6))
# not using `keepdims`
arr2 = np.sum(arr1, (0, 2, 4))
arr2.shape
# : (3, 5)
arr1 + arr2
# ValueError: operands could not be broadcast together with shapes (2,3,4,5,6) (3,5)
# now using `keepdims`
arr2 = np.sum(arr1, (0, 2, 4), keepdims=True)
arr2.shape
# : (1, 3, 1, 5, 1)
arr1 + arr2
# now this works because it has the correct shape
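As a side note, the `keepdims=True` shape can be rebuilt from the reduction axes alone. The following is only a sketch, assuming the axes are given as non-negative ints; `kept_shape` and `restored` are illustrative names:

```python
import numpy as np

arr1 = np.arange(2 * 3 * 4 * 5 * 6).reshape((2, 3, 4, 5, 6))
axis = (0, 2, 4)

reduced = np.sum(arr1, axis)  # shape (3, 5)

# Put a singleton wherever an axis was reduced, keep the other lengths.
kept_shape = tuple(
    1 if i in axis else n for i, n in enumerate(arr1.shape))
restored = reduced.reshape(kept_shape)

# The reshaped result matches what `keepdims=True` would have produced.
same = np.array_equal(restored, np.sum(arr1, axis, keepdims=True))
print(kept_shape, same)
```

This only helps when the original shape and the reduced axes are both still at hand, which is not always the case.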
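EDIT: Obviously, in cases where `np.newaxis` or the `keepdims` mechanism is an appropriate choice, there is no need for an `unsqueeze()` function.
Yet, there are use cases where neither of these is an option.
For example, consider the weighted average as implemented in `numpy.average()`, taken over an arbitrary number of dimensions specified by `axis`.
Right now the `weights` parameter must have the same shape as the input (or be 1-D, matching the length of a single averaging axis).
However, `weights` should not need to specify the weights along the non-reduced dimensions: they would just be repeated, and NumPy's broadcasting mechanism would take care of them appropriately.
So if we would like to have such functionality, we would need to write something like the following (some consistency checks are omitted for simplicity):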
def weighted_average(arr, weights=None, axis=None):
    if weights is not None and weights.shape != arr.shape:
        weights = unsqueeze(weights, ...)
        weights = np.zeros_like(arr) + weights
    result = np.sum(arr * weights, axis=axis)
    # not in-place, so that integer input is promoted to float
    result = result / np.sum(weights, axis=axis)
    return result
or, equivalently:
def weighted_average(arr, weights=None, axis=None):
    if weights is not None and weights.shape != arr.shape:
        weights = unsqueeze(weights, ...)
        weights = np.zeros_like(arr) + weights
    # keywords are needed: the second positional parameter is `axis`
    return np.average(arr, axis=axis, weights=weights)
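In either of the two, it is not possible to replace `unsqueeze()` with something like `weights[:, np.newaxis]`, because we do not know beforehand where the new axes will be needed. Nor can we rely on the `keepdims` feature of `sum`, because the code would already fail at `arr * weights`.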
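For the special case where the weights cover exactly the reduced axes, the target shape can be built directly from `axis`. This is a minimal sketch under that assumption, not a general solution; `weighted_average_over` is a hypothetical helper name, and `axis` is assumed to be a tuple of non-negative, increasing ints:

```python
import numpy as np


def weighted_average_over(arr, weights, axis):
    """Weighted average of `arr` over `axis` (hypothetical helper).

    Assumes `axis` is a tuple of non-negative, increasing ints and that
    `weights.shape` lists the lengths of `arr` along exactly those axes.
    """
    # Singleton everywhere except on the reduced axes.
    shape = [1] * arr.ndim
    for ax, n in zip(axis, weights.shape):
        shape[ax] = n
    w = weights.reshape(shape)
    # Broadcasting repeats the weights along the non-reduced axes.
    return np.sum(arr * w, axis=axis) / np.sum(w, axis=axis)


arr = np.arange(2 * 3 * 4, dtype=float).reshape((2, 3, 4))
weights = np.arange(1, 9, dtype=float).reshape((2, 4))
result = weighted_average_over(arr, weights, axis=(0, 2))

# Reference: materialize the full-shape weights and call np.average.
full = np.broadcast_to(weights.reshape((2, 1, 4)), arr.shape)
expected = np.average(arr, axis=(0, 2), weights=full)
print(np.allclose(result, expected))
```

The reshape step is the part that a general `unsqueeze()` would take over.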
The weighted-average case could be handled relatively nicely if `np.expand_dims()` supported an iterable of ints for its `axis` parameter, but as of NumPy 1.13.0 it does not.
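For readers on a more recent NumPy: to my knowledge, `np.expand_dims()` accepts a tuple of axes starting from NumPy 1.18, which covers the case where the positions of the singletons are known. A short sketch, assuming such a version is installed:

```python
import numpy as np

arr1 = np.arange(2 * 3 * 4 * 5 * 6).reshape((2, 3, 4, 5, 6))
arr2 = np.arange(3 * 5).reshape((3, 5))

# Requires NumPy >= 1.18: `axis` may be a tuple of positions in the result.
arr2_ = np.expand_dims(arr2, axis=(0, 2, 4))
print(arr2_.shape)  # (1, 3, 1, 5, 1)

total = arr1 + arr2_  # broadcasts without error
```

This does not infer the positions from a target shape, so the shape-matching logic discussed here is still needed when only `arr1.shape` is known.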
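I implemented `unsqueeze()` by first defining the following `unsqueezing()` function, which handles the cases that can be resolved automatically and issues a warning when the input may be ambiguous (e.g. when some elements of the source shape match more than one element of the target shape):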
import warnings


def unsqueezing(
        source_shape,
        target_shape):
    """
    Generate a broadcasting-compatible shape.

    The resulting shape contains *singletons* (i.e. `1`) for non-matching dims.
    Assumes all elements of the source shape are contained in the target shape
    (except for singletons) in the correct order.

    Warning! The generated shape may not be unique if some of the elements
    from the source shape are present multiple times in the target shape.

    Args:
        source_shape (Sequence): The source shape.
        target_shape (Sequence): The target shape.

    Returns:
        shape (tuple): The broadcast-safe shape.

    Raises:
        ValueError: if elements of `source_shape` are not in `target_shape`.

    Examples:
        For non-repeating elements, `unsqueezing()` is always well-defined:

        >>> unsqueezing((2, 3), (2, 3, 4))
        (2, 3, 1)
        >>> unsqueezing((3, 4), (2, 3, 4))
        (1, 3, 4)
        >>> unsqueezing((3, 5), (2, 3, 4, 5, 6))
        (1, 3, 1, 5, 1)
        >>> unsqueezing((1, 3, 5, 1), (2, 3, 4, 5, 6))
        (1, 3, 1, 5, 1)

        If there is nothing to unsqueeze, the `source_shape` is returned:

        >>> unsqueezing((1, 3, 1, 5, 1), (2, 3, 4, 5, 6))
        (1, 3, 1, 5, 1)
        >>> unsqueezing((2, 3), (2, 3))
        (2, 3)

        If some elements in `source_shape` are repeating in `target_shape`,
        a user warning will be issued:

        >>> unsqueezing((2, 2), (2, 2, 2, 2, 2))
        (2, 2, 1, 1, 1)
        >>> unsqueezing((2, 2), (2, 3, 2, 2, 2))
        (2, 1, 2, 1, 1)

        If some elements of `source_shape` are not present in `target_shape`,
        an error is raised.

        >>> unsqueezing((2, 3), (2, 2, 2, 2, 2))
        Traceback (most recent call last):
          ...
        ValueError: Target shape must contain all source shape elements\
 (in correct order). (2, 3) -> (2, 2, 2, 2, 2)
        >>> unsqueezing((5, 3), (2, 3, 4, 5, 6))
        Traceback (most recent call last):
          ...
        ValueError: Target shape must contain all source shape elements\
 (in correct order). (5, 3) -> (2, 3, 4, 5, 6)
    """
    shape = []
    j = 0
    for i, dim in enumerate(target_shape):
        if j < len(source_shape):
            shape.append(dim if dim == source_shape[j] else 1)
            if i + 1 < len(target_shape) and dim == source_shape[j] \
                    and dim != 1 and dim in target_shape[i + 1:]:
                text = ('Multiple positions (e.g. {} and {})'
                        ' for source shape element {}.'.format(
                            i, target_shape[i + 1:].index(dim) + (i + 1), dim))
                warnings.warn(text)
            if dim == source_shape[j] or source_shape[j] == 1:
                j += 1
        else:
            shape.append(1)
    if j < len(source_shape):
        raise ValueError(
            'Target shape must contain all source shape elements'
            ' (in correct order). {} -> {}'.format(source_shape, target_shape))
    return tuple(shape)
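The ambiguity that triggers the warning can be checked independently. The sketch below assumes NumPy 1.20 or later (for `np.broadcast_shapes()`); `is_broadcastable_to` is a hypothetical helper, not part of NumPy:

```python
import numpy as np


def is_broadcastable_to(shape, target_shape):
    """Return True if `shape` broadcasts to exactly `target_shape`."""
    try:
        return np.broadcast_shapes(shape, target_shape) == tuple(target_shape)
    except ValueError:
        # NumPy raises ValueError for incompatible shapes.
        return False


# With repeated lengths in the target, several placements are valid,
# which is why a greedy left-to-right match cannot be unique.
target = (2, 2, 2, 2, 2)
candidates = [(2, 2, 1, 1, 1), (1, 2, 1, 2, 1), (1, 1, 1, 2, 2)]
all_valid = all(is_broadcastable_to(c, target) for c in candidates)
print(all_valid)
```

Each candidate is a legitimate way of placing a `(2, 2)` array inside the target, so the caller has to say which one is meant.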
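The `unsqueezing()` function can then be used to define `unsqueeze()` as a more flexible inverse of `np.squeeze()` than `np.expand_dims()`, which can only add one singleton at a time: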
def unsqueeze(
        arr,
        axis=None,
        shape=None,
        reverse=False):
    """
    Add singletons to the shape of an array to broadcast-match a given shape.

    In some sense, this function implements the inverse of `numpy.squeeze()`.

    Args:
        arr (np.ndarray): The input array.
        axis (int|Iterable|None): Axis or axes in which to operate.
            If None, a valid set of axes is generated from `shape` when this
            is defined and the shape can be matched by `unsqueezing()`.
            If int or Iterable, specifies how singletons are added.
            This depends on the value of `reverse`.
            If `shape` is not None, the `axis` and `shape` parameters must be
            consistent.
            Values must be in the range [-(ndim+1), ndim+1]
            At least one of `axis` and `shape` must be specified.
        shape (int|Iterable|None): The target shape.
            If None, no safety checks are performed.
            If int, this is interpreted as the number of dimensions of the
            output array.
            If Iterable, the result must be broadcastable to an array with the
            specified shape.
            If `axis` is not None, the `axis` and `shape` parameters must be
            consistent.
            At least one of `axis` and `shape` must be specified.
        reverse (bool): Interpret `axis` parameter as its complementary.
            If True, the dims of the input array are placed at the positions
            indicated by `axis`, and singletons are placed everywhere else and
            the `axis` length must be equal to the number of dimensions of the
            input array; the `shape` parameter cannot be `None`.
            If False, the singletons are added at the position(s) specified by
            `axis`.
            If `axis` is None, `reverse` has no effect.

    Returns:
        arr (np.ndarray): The reshaped array.

    Raises:
        ValueError: if the `arr` shape cannot be reshaped correctly.

    Examples:
        Let's define some input array `arr`:

        >>> arr = np.arange(2 * 3 * 4).reshape((2, 3, 4))
        >>> arr.shape
        (2, 3, 4)

        A call to `unsqueeze()` can be reversed by `np.squeeze()`:

        >>> arr_ = unsqueeze(arr, (0, 2, 4))
        >>> arr_.shape
        (1, 2, 1, 3, 1, 4)
        >>> arr = np.squeeze(arr_, (0, 2, 4))
        >>> arr.shape
        (2, 3, 4)

        The order of the axes does not matter:

        >>> arr_ = unsqueeze(arr, (0, 4, 2))
        >>> arr_.shape
        (1, 2, 1, 3, 1, 4)

        If `shape` is an int, `axis` must be consistent with it:

        >>> arr_ = unsqueeze(arr, (0, 2, 4), 6)
        >>> arr_.shape
        (1, 2, 1, 3, 1, 4)
        >>> arr_ = unsqueeze(arr, (0, 2, 4), 7)
        Traceback (most recent call last):
          ...
        ValueError: Incompatible `[0, 2, 4]` axis and `7` shape for array of\
 shape (2, 3, 4)

        It is possible to reverse the meaning of `axis` to add singletons
        everywhere except where specified (but requires `shape` to be defined
        and the length of `axis` must match the array dims):

        >>> arr_ = unsqueeze(arr, (0, 2, 4), 10, True)
        >>> arr_.shape
        (2, 1, 3, 1, 4, 1, 1, 1, 1, 1)
        >>> arr_ = unsqueeze(arr, (0, 2, 4), reverse=True)
        Traceback (most recent call last):
          ...
        ValueError: When `reverse` is True, `shape` cannot be None.
        >>> arr_ = unsqueeze(arr, (0, 2), 10, True)
        Traceback (most recent call last):
          ...
        ValueError: When `reverse` is True, the length of axis (2) must match\
 the num of dims of array (3).

        Axes values must be valid:

        >>> arr_ = unsqueeze(arr, 0)
        >>> arr_.shape
        (1, 2, 3, 4)
        >>> arr_ = unsqueeze(arr, 3)
        >>> arr_.shape
        (2, 3, 4, 1)
        >>> arr_ = unsqueeze(arr, -1)
        >>> arr_.shape
        (2, 3, 4, 1)
        >>> arr_ = unsqueeze(arr, -4)
        >>> arr_.shape
        (1, 2, 3, 4)
        >>> arr_ = unsqueeze(arr, 10)
        Traceback (most recent call last):
          ...
        ValueError: Axis (10,) out of range.

        If `shape` is specified, `axis` can be omitted (USE WITH CARE!) or its
        value is used for additional safety checks:

        >>> arr_ = unsqueeze(arr, shape=(2, 3, 4, 5, 6))
        >>> arr_.shape
        (2, 3, 4, 1, 1)
        >>> arr_ = unsqueeze(
        ...     arr, (3, 6, 8), (2, 5, 3, 2, 7, 2, 3, 2, 4, 5, 6), True)
        >>> arr_.shape
        (1, 1, 1, 2, 1, 1, 3, 1, 4, 1, 1)
        >>> arr_ = unsqueeze(
        ...     arr, (3, 7, 8), (2, 5, 3, 2, 7, 2, 3, 2, 4, 5, 6), True)
        Traceback (most recent call last):
          ...
        ValueError: New shape [1, 1, 1, 2, 1, 1, 1, 3, 4, 1, 1] cannot be\
 broadcasted to shape (2, 5, 3, 2, 7, 2, 3, 2, 4, 5, 6)
        >>> arr = unsqueeze(arr, shape=(2, 5, 3, 7, 2, 4, 5, 6))
        >>> arr.shape
        (2, 1, 3, 1, 1, 4, 1, 1)
        >>> arr = np.squeeze(arr)
        >>> arr.shape
        (2, 3, 4)
        >>> arr = unsqueeze(arr, shape=(5, 3, 7, 2, 4, 5, 6))
        Traceback (most recent call last):
          ...
        ValueError: Target shape must contain all source shape elements\
 (in correct order). (2, 3, 4) -> (5, 3, 7, 2, 4, 5, 6)

        The behavior is consistent with other NumPy functions and the
        `keepdims` mechanism:

        >>> axis = (0, 2, 4)
        >>> arr1 = np.arange(2 * 3 * 4 * 5 * 6).reshape((2, 3, 4, 5, 6))
        >>> arr2 = np.sum(arr1, axis, keepdims=True)
        >>> arr2.shape
        (1, 3, 1, 5, 1)
        >>> arr3 = np.sum(arr1, axis)
        >>> arr3.shape
        (3, 5)
        >>> arr3 = unsqueeze(arr3, axis)
        >>> arr3.shape
        (1, 3, 1, 5, 1)
        >>> np.all(arr2 == arr3)
        True
    """
    # calculate `new_shape`
    if axis is None and shape is None:
        raise ValueError(
            'At least one of `axis` and `shape` parameters must be specified.')
    elif axis is None and shape is not None:
        new_shape = unsqueezing(arr.shape, shape)
    elif axis is not None:
        if isinstance(axis, int):
            axis = (axis,)
        # calculate the dim of the result
        if shape is not None:
            if isinstance(shape, int):
                ndim = shape
            else:  # shape is a sequence
                ndim = len(shape)
        elif not reverse:
            ndim = len(axis) + arr.ndim
        else:
            raise ValueError('When `reverse` is True, `shape` cannot be None.')
        # check that axis is properly constructed
        if any([ax < -ndim - 1 or ax > ndim + 1 for ax in axis]):
            raise ValueError('Axis {} out of range.'.format(axis))
        # normalize axis using `ndim`
        axis = sorted([ax % ndim for ax in axis])
        # manage reverse mode
        if reverse:
            if len(axis) == arr.ndim:
                axis = [i for i in range(ndim) if i not in axis]
            else:
                raise ValueError(
                    'When `reverse` is True, the length of axis ({})'
                    ' must match the num of dims of array ({}).'.format(
                        len(axis), arr.ndim))
        elif len(axis) + arr.ndim != ndim:
            raise ValueError(
                'Incompatible `{}` axis and `{}` shape'
                ' for array of shape {}'.format(axis, shape, arr.shape))
        # generate the new shape from axis, ndim and shape
        new_shape = []
        i, j = 0, 0
        for l in range(ndim):
            if i < len(axis) and l == axis[i] or j >= arr.ndim:
                new_shape.append(1)
                i += 1
            else:
                new_shape.append(arr.shape[j])
                j += 1

    # check that `new_shape` is consistent with `shape`
    if shape is not None:
        if isinstance(shape, int):
            if len(new_shape) != ndim:
                raise ValueError(
                    'Length of new shape {} does not match '
                    'expected length ({}).'.format(len(new_shape), ndim))
        else:
            if not all([new_dim == 1 or new_dim == dim
                        for new_dim, dim in zip(new_shape, shape)]):
                raise ValueError(
                    'New shape {} cannot be broadcasted to shape {}'.format(
                        new_shape, shape))

    return arr.reshape(new_shape)
Using these, one can write:
import numpy as np
arr1 = np.arange(2 * 3 * 4 * 5 * 6).reshape((2, 3, 4, 5, 6))
arr2 = np.arange(3 * 5).reshape((3, 5))
arr3 = unsqueeze(arr2, (0, 2, 4))
arr1 + arr3
# now this works because it has the correct shape
arr3 = unsqueeze(arr2, shape=arr1.shape)
arr1 + arr3
# this also works because the shape can be expanded unambiguously
So dynamic broadcasting is now possible, and it is consistent with the behavior of `keepdims`:
import numpy as np
axis = (0, 2, 4)
arr1 = np.arange(2 * 3 * 4 * 5 * 6).reshape((2, 3, 4, 5, 6))
arr2 = np.sum(arr1, axis, keepdims=True)
arr3 = np.sum(arr1, axis)
arr3 = unsqueeze(arr3, axis)
np.all(arr2 == arr3)
# : True
Effectively, this extends `np.expand_dims()` to handle more complex scenarios.
Improvements to this code are obviously more than welcome.