拼接与拆分
拼接:Cat
,Stack
拆分:Split
,Chunk
Cat
torch.cat(tensors,dim)
按维度拼接张量。
参数说明:
tensors
:要拼接的张量序列,可以是一个列表或元组。dim
:指定拼接的维度,默认为 0。
使用示例
1 | a = torch.rand(4,32,8) |
输出示例
1 | torch.Size(9,32,8) |
Stack
torch.stack(tensors,dim)
也能完成拼接,与cat
不同的是,会在指定维度前创建一个新的维度。
使用示例
1 | a = torch.rand(4,32,8) |
输出示例
1 | torch.Size(2,4,32,8) |
要拼接的两个维度必须一模一样。
Split
torch.split(tensors,split_size_or_sections,dim)
是按照长度len
来拆分的。
tensor
:要拆分的输入张量。split_size_or_sections
:指定拆分的大小或拆分的数量,可以是一个整数或一个列表。dim
:可选参数,指定拆分的维度,默认为 0。
使用示例
1 | a = torch.rand(4,32,8) |
输出示例
1 | (torch.Size([2, 32, 8]), torch.Size([2, 32, 8])) |
Chunk
torch.chunk(tensor, chunks, dim=0)
是按照数量num
拆分
tensor
:要拆分的输入张量。chunks
:指定拆分成的子张量的数量。dim
:可选参数,指定拆分的维度,默认为 0。
使用示例
1 | a = torch.rand(4,32,8) |
输出示例
1 | (torch.Size([2, 32, 8]), torch.Size([2, 32, 8])) |
All articles in this blog are licensed under CC BY-NC-SA 4.0 unless stating additionally.