阿里云-云小站(无限量代金券发放中)
【腾讯云】云服务器、云数据库、COS、CDN、短信等热卖云产品特惠抢购

Pytorch实现List Tensor转Tensor,reshape拼接等操作

35次阅读
没有评论

共计 1444 个字符,预计需要花费 4 分钟才能阅读完成。

导读 这篇文章主要介绍了 Pytorch 实现 List Tensor 转 Tensor,reshape 拼接等操作,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

持续更新一些常用的 Tensor 操作,比如 List,Numpy,Tensor 之间的转换,Tensor 的拼接,维度的变换等操作。

其它 Tensor 操作如 einsum 等见:待更新。

用到两个函数:

  • torch.cat
  • torch.stack
  • 一、List Tensor 转 Tensor (torch.cat)

    Pytorch 实现 List Tensor 转 Tensor,reshape 拼接等操作

    // An highlighted block
    >>> t1 = torch.FloatTensor([[1,2],[5,6]])
    >>> t2 = torch.FloatTensor([[3,4],[7,8]])
    >>> l = []
    >>> l.append(t1)
    >>> l.append(t2)
    >>> ta = torch.cat(l,dim=0)
    >>> ta = torch.cat(l,dim=0).reshape(2,2,2)
    >>> tb = torch.cat(l,dim=1).reshape(2,2,2)
    >>> ta
    tensor([[[1., 2.],
             [5., 6.]],
     
            [[3., 4.],
             [7., 8.]]])
    >>> tb
    tensor([[[1., 2.],
             [3., 4.]],
     
            [[5., 6.],
             [7., 8.]]])
    高维 tensor

    ** 如果理解了 2D to 3DTensor, 以此类推,不难理解 3D to 4D,看下面代码即可明白:**

    >>> t1 = torch.range(1,8).reshape(2,2,2)
    >>> t2 = torch.range(11,18).reshape(2,2,2)
    >>> l = []
    >>> l.append(t1)
    >>> l.append(t2)
    >>> torch.cat(l,dim=2).reshape(2,2,2,2)
    tensor([[[[1.,  2.],
              [11., 12.]],
     
             [[3.,  4.],
              [13., 14.]]],
     
     
            [[[5.,  6.],
              [15., 16.]],
     
             [[7.,  8.],
              [17., 18.]]]])
    >>> torch.cat(l,dim=1).reshape(2,2,2,2)
    tensor([[[[1.,  2.],
              [3.,  4.]],
     
             [[11., 12.],
              [13., 14.]]],
     
     
            [[[5.,  6.],
              [7.,  8.]],
     
             [[15., 16.],
              [17., 18.]]]])
    >>> torch.cat(l,dim=0).reshape(2,2,2,2)
    tensor([[[[1.,  2.],
              [3.,  4.]],
     
             [[5.,  6.],
              [7.,  8.]]],
     
     
            [[[11., 12.],
              [13., 14.]],
     
             [[15., 16.],
              [17., 18.]]]])
    二、List Tensor 转 Tensor (torch.stack)

    Pytorch 实现 List Tensor 转 Tensor,reshape 拼接等操作

    代码:

    import torch
     
    t1 = torch.FloatTensor([[1,2],[5,6]])
    t2 = torch.FloatTensor([[3,4],[7,8]])
    l = [t1, t2]
     
    t3 = torch.stack(l, dim=2)
    print(t3.shape)
    print(t3)
     
    ## output:
    ## torch.Size([2, 2, 2])
    ## tensor([[[1., 3.],
    ##          [2., 4.]],
    ##        [[5., 7.],
    ##         [6., 8.]]])

    阿里云 2 核 2G 服务器 3M 带宽 61 元 1 年,有高配

    腾讯云新客低至 82 元 / 年,老客户 99 元 / 年

    代金券:在阿里云专用满减优惠券

    正文完
    星哥说事-微信公众号
    post-qrcode
     0
    星锅
    版权声明:本站原创文章,由 星锅 于2024-07-24发表,共计1444字。
    转载说明:除特殊说明外本站文章皆由CC-4.0协议发布,转载请注明出处。
    【腾讯云】推广者专属福利,新客户无门槛领取总价值高达2860元代金券,每种代金券限量500张,先到先得。
    阿里云-最新活动爆款每日限量供应
    评论(没有评论)
    验证码
    【腾讯云】云服务器、云数据库、COS、CDN、短信等云产品特惠热卖中