拼接:Cat,Stack
拆分:Split,Chunk

Cat

torch.cat(tensors,dim)按维度拼接张量。
参数说明:

  • tensors:要拼接的张量序列,可以是一个列表或元组。
  • dim:指定拼接的维度,默认为 0。
    使用示例
1
2
3
a = torch.rand(4,32,8)
b = torch.rand(5,32,8)
torch.cat([a,b],dim = 0).shape

输出示例

1
torch.Size(9,32,8)

Stack

torch.stack(tensors,dim)也能完成拼接,与cat不同的是,会在指定维度前创建一个新的维度。

使用示例

1
2
3
a = torch.rand(4,32,8)
b = torch.rand(4,32,8)
torch.cat([a,b],dim = 0).shape

输出示例

1
torch.Size(2,4,32,8)

要拼接的两个维度必须一模一样。

Split

torch.split(tensors,split_size_or_sections,dim)是按照长度len来拆分的。

  • tensor:要拆分的输入张量。
  • split_size_or_sections:指定拆分的大小或拆分的数量,可以是一个整数或一个列表。
  • dim:可选参数,指定拆分的维度,默认为 0。
    使用示例
1
2
3
a = torch.rand(4,32,8)
aa,bb = torch.split(a,2,dim = 0)
aa.shape,bb.shape

输出示例

1
(torch.Size([2, 32, 8]), torch.Size([2, 32, 8]))

Chunk

torch.chunk(tensor, chunks, dim=0)是按照数量num拆分

  • tensor:要拆分的输入张量。
  • chunks:指定拆分成的子张量的数量。
  • dim:可选参数,指定拆分的维度,默认为 0。
    使用示例
1
2
3
a = torch.rand(4,32,8)
aa,bb = torch.split(a,2,dim = 0)
aa.shape,bb.shape

输出示例

1
(torch.Size([2, 32, 8]), torch.Size([2, 32, 8]))