关于python多维列表的切片

导入

这篇博客主要想聊一下python列表切片,python列表切片不就是一个slice对象吗?还有什么好聊的呢?的确,对于一维的list来说,一个简单的slice对象传入就可以进行切片,比如array[1:100:2]这样,但是对于多维度的list来说,我们不能直接使用多个slice来针对不同维度进行切片,因为list并没没有实现这个功能,多维度在list眼中就是简单的list嵌套。而如果你使用过numpy库,应该知道,对于numpy中的多维数组,我们可以直接针对不同维度进行切片,比如二维数组中使用array[1:10, 3:20]选取第1到9行,第3到19列作为输出的切片,非常的方便,但是list中二维列表要实现这个效果,则一般需要使用列表生成式,比如[arr[3:20] for arr in array[1:10],的确可以选择将list转换成numpy数组进行切片,然后再tolist即可,这也是很多场景下的优选方案,但是本篇博客要说明的还是为什么?如何实现的?因此下面将实现一个简单的自定义数组类,不依赖任何第三方库,实现类似numpy数组的多维度直接切片的效果。

正文

切片的本质还是索引元素,因此背后调用的都是__getitem__方法,区别是索引元素传入的是int索引,而切片传入的是一个slice对象,而进一步,我们想要实现的类似array[1:10, 3:20]这种多维度切片,传入的则是slice对象组成的tuple元组。因此下一步就很简单,需要对传入的这个slice元组进行处理,第一步就是对slice的不同情况进行考虑,因为slice也可能不是slice,只是一个int,就像前面说的,还可能是...省略号,表示剩下所有维度,如果是int,比如2,我们就需要转换成等价的slice形式slice(2,3),如果是省略号,则使用slice(None)替代即可,表示不进行该维度的切片。大致处理代码如下:

def _parse_index(self, key):
        if isinstance(key, int):
            return slice(key, key+1)
        elif isinstance(key, slice):
            return key
        elif key is Ellipsis:
            return slice(None)
        else:
            raise TypeError(f"Unsupported index type: {type(key)}")

传入__getitem__的参数全部这样处理完之后,还需要进行维度的补全,比如使用了省略号,或者单纯就是切片维度小于实际维度的情况下,我们默认就是后面没写到的维度不做切片,因此需要全部补全成slice(None):

slices = []
        for k in key:
            slices.append(self._parse_index(k))
        while len(slices) < len(self._shape):
            slices.append(slice(None))

需要说明的是,我们使用的是原始的list列表,因此只能获取第一维度的len,不能像numpy的数组那样直接通过shape获取完整的维度信息,因此上面的self._shape需要我们自己进行计算,计算的思路很简单,使用递归即可,先用len获取第一维度,然后对剩下的维度依次使用len进行递归即可,递归终止条件就是传入的数据不再是list列表了,而是变成一个具体的数据了,比如说int:

def _get_shape(self, data):
        if isinstance(data, list):
            return [len(data)] + self._get_shape(data[0])
        else:
            return []

现在我们有了实际的列表数据,有了所有维度的切片数据,接下来就是针对每一个维度进行实际的切片,上面我们已经计算出列表的维度信息,因此是可以直接循环进行处理的,但是相对复杂一些,我们还可以再次利用递归进行处理,先获取第一个slice,根据这个slice切片对当前维度的数据进行切片,然后递归处理下级维度,利用剩下的slice对当前切片后的list数据继续递归处理,代码如下:

def _slice_nd_list(self, lst, slices):        
        # 递归终止条件
        if len(slices) == 0:
            return lst
        
        current_slice = slices
        remaining_slices = slices[1:]
        
        # 处理当前维度切片
        dim_size = len(lst)
        start, stop, step = current_slice[0].indices(dim_size)
        sliced = [lst[i] for i in range(start, stop, step)]
        
        # 递归处理下级维度
        if remaining_slices:
            return [self._slice_nd_list(item, remaining_slices) for item in sliced]
        return sliced

到这里就差不多结束了,已经实现了普通list列表的多维度直接切片的效果了,完整的代码:


class SimpleArray:
    def __init__(self, data):
        self._data = list(data)
        self._shape = self._get_shape(self._data)
    
    @property
    def shape(self):
        return self._shape

    def _get_shape(self, data):
        if isinstance(data, list):
            return [len(data)] + self._get_shape(data[0])
        else:
            return []

    def _slice_nd_list(self, lst, slices):        
        # 递归终止条件
        if len(slices) == 0:
            return lst
        
        current_slice = slices
        remaining_slices = slices[1:]
        
        # 处理当前维度切片
        dim_size = len(lst)
        start, stop, step = current_slice[0].indices(dim_size)
        sliced = [lst[i] for i in range(start, stop, step)]
        
        # 递归处理下级维度
        if remaining_slices:
            return [self._slice_nd_list(item, remaining_slices) for item in sliced]
        return sliced
        
    def _parse_index(self, key):
        if isinstance(key, int):
            return slice(key, key+1)
        elif isinstance(key, slice):
            return key
        elif key is Ellipsis:
            return slice(None)
        else:
            raise TypeError(f"Unsupported index type: {type(key)}")

    def __getitem__(self, key):
        if not isinstance(key, tuple):
            key = (key,)
        
        # 处理切片不足的情况(自动补全剩余维度)
        slices = []
        for k in key:
            slices.append(self._parse_index(k))
        while len(slices) < len(self._shape):
            slices.append(slice(None))
        
        # 生成视图数据
        view_data = self._slice_nd_list(self._data, slices)
        return SimpleArray(view_data)
    
    def __repr__(self):
        return f"SimpleArray({self._data})"

# 测试用例
if __name__ == "__main__":
    arr = SimpleArray([[1,2,3,4], [5,6,7,8], [9,10,11,12]])
    print(arr[1:3, 0:2])  # 输出: SimpleArray([[5, 6], [9, 10]])
    print(arr[0, ...])    # 输出: SimpleArray([[1, 2, 3, 4]])
    print(arr[:, 1:3])    # 输出: SimpleArray([[2, 3], [6, 7], [10, 11]])

通过测试用例,我们可以看到效果,和numpy的切片完全一致了,这其实也就是numpy切片底层实现的大概原理,希望这篇博客能够透过这种底层原理让读者更好的了解多维度切片。