PyTorch笔记之 scatter() 函数

摘要:
Scatter()和Scatter_()的函数相同,只是Scatter()不直接修改原始张量_()在PyTorch中,带下划线的通用函数表示直接在原始张量上修改散射的参数(dim、index、src)。有三个维度:哪个维度被索引。index:用于散点的元素索引。src:用于散点的源元素。它可以是标量或张量

scatter()scatter_() 的作用是一样的,只不过 scatter() 不会直接修改原来的 Tensor,而 scatter_() 会

PyTorch 中,一般函数加下划线代表直接在原来的 Tensor 上修改

scatter(dim, index, src) 的参数有 3 个

  • dim:沿着哪个维度进行索引
  • index:用来 scatter 的元素索引
  • src:用来 scatter 的源元素,可以是一个标量或一个张量

这个 scatter  可以理解成放置元素或者修改元素

简单说就是通过一个张量 src  来修改另一个张量,哪个元素需要修改、用 src 中的哪个元素来修改由 dim 和 index 决定

官方文档给出了 3维张量 的具体操作说明,如下所示

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

exmaple:

x = torch.rand(2, 5)

#tensor([[0.1940, 0.3340, 0.8184, 0.4269, 0.5945],
#        [0.2078, 0.5978, 0.0074, 0.0943, 0.0266]])

torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)

#tensor([[0.1940, 0.5978, 0.0074, 0.4269, 0.5945],
#        [0.0000, 0.3340, 0.0000, 0.0943, 0.0000],
#        [0.2078, 0.0000, 0.8184, 0.0000, 0.0266]])

具体地说,我们的 index 是 torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]),一个二维张量,下面用图简单说明

我们是 2维 张量,一开始进行 $self[index[0][0]][0]$,其中 $index[0][0]$ 的值是0,所以执行 $self[0][0] = x[0][0] = 0.1940$ 

$self[index[i][j]][j] = src[i][j] $

PyTorch笔记之 scatter() 函数第1张再比如$self[index[1][0]][0]$,其中 $index[1][0]$ 的值是2,所以执行 $self[2][0] = x[1][0] = 0.2078$ 

PyTorch笔记之 scatter() 函数第2张

src 除了可以是张量外,也可以是一个标量

example:

torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), 7)

#tensor([[7., 7., 7., 7., 7.],
#        [0., 7., 0., 7., 0.],
#        [7., 0., 7., 0., 7.]]

scatter() 一般可以用来对标签进行 one-hot 编码,这就是一个典型的用标量来修改张量的一个例子

example:

class_num = 10
batch_size = 4
label = torch.LongTensor(batch_size, 1).random_() % class_num
#tensor([[6],
#        [0],
#        [3],
#        [2]])
torch.zeros(batch_size, class_num).scatter_(1, label, 1)
#tensor([[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
#        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
#        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
#        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]])

免责声明:文章转载自《PyTorch笔记之 scatter() 函数》仅用于学习参考。如对内容有疑问,请及时联系本站处理。

上篇Java File类的简单使用ThinkPHP的连贯操作方法中field方法下篇

宿迁高防,2C2G15M,22元/月;香港BGP,2C5G5M,25元/月 雨云优惠码:MjYwNzM=

相关文章

VBA学习_1:数据类型

VBA的数据类型 布尔型Boolean 整数:整数型Integer、字节型Byte、长整数型Long 小数:小数型Decimal、单精度浮点型Single、双精度浮点型Double、货币型Currency 字符串型Sting(定长和不定长) 日期型Date 对象型Object 变体型Variant 用户自定义类型 声明变量   Dim 变量名 As 数...

CNN中的卷积

1、什么是卷积:图像中不同数据窗口的数据和卷积核(一个滤波矩阵)作内积的操作叫做卷积。其计算过程又称为滤波(filter),本质是提取图像不同频段的特征。 2、什么是卷积核:也称为滤波器filter,带着一组固定权重的神经元,通常是n*m二维的矩阵,n和m也是神经元的感受野。n*m 矩阵中存的是对感受野中数据处理的系数。一个卷积核的滤波可以用来提取特定的特...

如何修改已有的ONNX模型

简单来说,我们只需要学习一下把大象如何放进冰箱的就行了: 1、把冰箱门打开 使用onnx的原生接口: onnx_model = onnx.load(onnx_path) graph = onnx_model.graph 这样我们就可以将模型load出来,并且到到graph信息。 2、把大象放进去 这一步相对来说选择就比较多了,比如你可以选择删除一些节点,...

神经网络中的降维和升维方法 (tensorflow & pytorch)

  大名鼎鼎的UNet和我们经常看到的编解码器模型,他们的模型都是先将数据下采样,也称为特征提取,然后再将下采样后的特征恢复回原来的维度。这个特征提取的过程我们称为“下采样”,这个恢复的过程我们称为“上采样”,本文就专注于神经网络中的下采样和上采样来进行一次总结。写的不好勿怪哈。 神经网络中的降维方法 池化层   池化层(平均池化层、最大池化层),卷积...

pytorch学习问题汇总

问题六: 问题五:这里是怎么得到的? 问题四:为什么会是如下结果? torch.bernoulli(a)怎么是这个结果? 问题1:torch各个类型数据格式如何转换?数据类型在官方文档torch.Tensor中,有八种类型。 #尝试一 i32=torch.IntTensor([1,2,3]) i64=torch.LongTensor([1,2,3])...

[vba]excel中求选中数据和为给定数所有的组合

昨天下午开始学习的vba,累死了,肯定有bug,待调试 vba程序如下: 1 Dim aSum As Integer 2 Dim tSum As Integer 3 Dim judge(30) As Integer 4 Dim arrMax As Integer 5 Dim arr 6 Dim location(30) As Integer...