导入
这篇博客主要想聊一下python列表切片,python列表切片不就是一个slice对象吗?还有什么好聊的呢?的确,对于一维的list来说,一个简单的slice对象传入就可以进行切片,比如array[1:100:2]
这样,但是对于多维度的list来说,我们不能直接使用多个slice来针对不同维度进行切片,因为list并没没有实现这个功能,多维度在list眼中就是简单的list嵌套。而如果你使用过numpy
库,应该知道,对于numpy中的多维数组,我们可以直接针对不同维度进行切片,比如二维数组中使用array[1:10, 3:20]
选取第1到9行,第3到19列作为输出的切片,非常的方便,但是list中二维列表要实现这个效果,则一般需要使用列表生成式,比如[arr[3:20] for arr in array[1:10]
,的确可以选择将list转换成numpy数组进行切片,然后再tolist
即可,这也是很多场景下的优选方案,但是本篇博客要说明的还是为什么?如何实现的?因此下面将实现一个简单的自定义数组类,不依赖任何第三方库,实现类似numpy数组的多维度直接切片的效果。
正文
切片的本质还是索引元素,因此背后调用的都是__getitem__
方法,区别是索引元素传入的是int
索引,而切片传入的是一个slice
对象,而进一步,我们想要实现的类似array[1:10, 3:20]
这种多维度切片,传入的则是slice
对象组成的tuple
元组。因此下一步就很简单,需要对传入的这个slice元组进行处理,第一步就是对slice的不同情况进行考虑,因为slice也可能不是slice,只是一个int
,就像前面说的,还可能是...
省略号,表示剩下所有维度,如果是int
,比如2
,我们就需要转换成等价的slice形式slice(2,3)
,如果是省略号,则使用slice(None)
替代即可,表示不进行该维度的切片。大致处理代码如下:
def _parse_index(self, key):
if isinstance(key, int):
return slice(key, key+1)
elif isinstance(key, slice):
return key
elif key is Ellipsis:
return slice(None)
else:
raise TypeError(f"Unsupported index type: {type(key)}")
传入__getitem__
的参数全部这样处理完之后,还需要进行维度的补全,比如使用了省略号,或者单纯就是切片维度小于实际维度的情况下,我们默认就是后面没写到的维度不做切片,因此需要全部补全成slice(None)
:
slices = []
for k in key:
slices.append(self._parse_index(k))
while len(slices) < len(self._shape):
slices.append(slice(None))
需要说明的是,我们使用的是原始的list列表,因此只能获取第一维度的len
,不能像numpy的数组那样直接通过shape
获取完整的维度信息,因此上面的self._shape
需要我们自己进行计算,计算的思路很简单,使用递归
即可,先用len
获取第一维度,然后对剩下的维度依次使用len
进行递归即可,递归终止条件就是传入的数据不再是list列表了,而是变成一个具体的数据了,比如说int:
def _get_shape(self, data):
if isinstance(data, list):
return [len(data)] + self._get_shape(data[0])
else:
return []
现在我们有了实际的列表数据,有了所有维度的切片数据,接下来就是针对每一个维度进行实际的切片,上面我们已经计算出列表的维度信息,因此是可以直接循环进行处理的,但是相对复杂一些,我们还可以再次利用递归
进行处理,先获取第一个slice,根据这个slice切片对当前维度的数据进行切片,然后递归处理下级维度,利用剩下的slice对当前切片后的list数据继续递归处理,代码如下:
def _slice_nd_list(self, lst, slices):
# 递归终止条件
if len(slices) == 0:
return lst
current_slice = slices
remaining_slices = slices[1:]
# 处理当前维度切片
dim_size = len(lst)
start, stop, step = current_slice[0].indices(dim_size)
sliced = [lst[i] for i in range(start, stop, step)]
# 递归处理下级维度
if remaining_slices:
return [self._slice_nd_list(item, remaining_slices) for item in sliced]
return sliced
到这里就差不多结束了,已经实现了普通list列表的多维度直接切片的效果了,完整的代码:
class SimpleArray:
def __init__(self, data):
self._data = list(data)
self._shape = self._get_shape(self._data)
@property
def shape(self):
return self._shape
def _get_shape(self, data):
if isinstance(data, list):
return [len(data)] + self._get_shape(data[0])
else:
return []
def _slice_nd_list(self, lst, slices):
# 递归终止条件
if len(slices) == 0:
return lst
current_slice = slices
remaining_slices = slices[1:]
# 处理当前维度切片
dim_size = len(lst)
start, stop, step = current_slice[0].indices(dim_size)
sliced = [lst[i] for i in range(start, stop, step)]
# 递归处理下级维度
if remaining_slices:
return [self._slice_nd_list(item, remaining_slices) for item in sliced]
return sliced
def _parse_index(self, key):
if isinstance(key, int):
return slice(key, key+1)
elif isinstance(key, slice):
return key
elif key is Ellipsis:
return slice(None)
else:
raise TypeError(f"Unsupported index type: {type(key)}")
def __getitem__(self, key):
if not isinstance(key, tuple):
key = (key,)
# 处理切片不足的情况(自动补全剩余维度)
slices = []
for k in key:
slices.append(self._parse_index(k))
while len(slices) < len(self._shape):
slices.append(slice(None))
# 生成视图数据
view_data = self._slice_nd_list(self._data, slices)
return SimpleArray(view_data)
def __repr__(self):
return f"SimpleArray({self._data})"
# 测试用例
if __name__ == "__main__":
arr = SimpleArray([[1,2,3,4], [5,6,7,8], [9,10,11,12]])
print(arr[1:3, 0:2]) # 输出: SimpleArray([[5, 6], [9, 10]])
print(arr[0, ...]) # 输出: SimpleArray([[1, 2, 3, 4]])
print(arr[:, 1:3]) # 输出: SimpleArray([[2, 3], [6, 7], [10, 11]])
通过测试用例,我们可以看到效果,和numpy的切片完全一致了,这其实也就是numpy切片底层实现的大概原理,希望这篇博客能够透过这种底层原理让读者更好的了解多维度切片。