DGL学习(三): 消息传递教程

摘要:
消息传递和函数转换是用户定义的函数(UDF)。P) G=dgl。DGLGraph(g)src=list(范围(1,dst=[0]*50#使用列表批量添加g.add_edges(src,node_size=50,node_color=[[.5,

在本节中,我们将不同级别的消息传递API与PageRank一起使用。 在DGL中,消息传递和功能转换是用户定义的函数(UDF)。

PageRank 算法:

在PageRank的每次迭代中,每个节点(网页)首先将其PageRank值均匀地分散到其下游节点。 每个节点的新PageRank值是通过汇总从其邻居收到的PageRank值来计算的,然后通过阻尼因子(damping factor)进行调整:

DGL学习(三): 消息传递教程第1张

 生成一个随机图, 两点之间有边的概率为 P:

import networkx as nx
import matplotlib.pyplot as plt
import torch
import dgl

N = 100
P = 0.1
DAMP = 0.8
g = nx.erdos_renyi_graph(N, P) g = dgl.DGLGraph(g)
src = list(range(1,51));dst = [0]*50 # 使用list批量添加
g.add_edges(src, dst)
print(g.number_of_edges()) print(g.number_of_nodes()) nx.draw(g.to_networkx(), node_size=50, node_color=[[.5, .5, .5,]])
plt.show() 

DGL学习(三): 消息传递教程第2张

在pagerank 中, 初始化每个节点初始值为 1/N, 将节点的出度作为节点的特征。

## pv 算法初始值
g.ndata['pv'] = torch.ones(N) / N
g.ndata['deg'] = g.out_degrees(g.nodes()).float()

定义消息函数,该函数将每个节点的PageRank值除以其出度,然后将结果作为消息传递给其邻居。

在DGL中,消息函数是针对边的,表示为Edge UDF。 Edge UDF接受单个参数edges。 它具有三个成员src,dst和data,用于访问源节点特征,目标节点特征和边特征。实现pv算法仅需从src中取特征。

def pagerank_message_func(edges):
    return {'pv': edges.src['pv'] / edges.src['deg']}

定义reduce函数,该函数从其mailbox中聚合消息和删除消息,并计算其新的PageRank值。

reduce函数是针对节点的,表示为 Node UDF。 Node UDF接受单个参数nodes,nodes具有两个成员mailbox和data。 data包含节点特征,mailbox包含所有传入消息特征,这些功能沿第二维堆叠(dim = 1参数)。

可以结合下图进行理解:

DGL学习(三): 消息传递教程第3张

def pagerank_reduce_func(nodes):
    msgs = torch.sum(nodes.mailbox['pv'], dim=1)
    pv = (1 - DAMP) / N + DAMP * msgs
    return {'pv' : pv}

注册消息函数和规约函数, 之后DGL调用它。 pagerank_naive是page_rank的简单实现。

# 注册消息函数和归约函数,稍后DGL将调用它。
g.register_message_func(pagerank_message_func)
g.register_reduce_func(pagerank_reduce_func)

def pagerank_naive(g):
    # Phase #1: send out messages along all edges.
    for u, v in zip(*g.edges()):
        g.send((u, v))
    # Phase #2: receive messages to compute new PageRank values.
    for v in g.nodes():
        g.recv(v)

# 迭代10轮
for k in range(10):
    pagerank_naive(g)

print(g.ndata['pv'])
DGL学习(三): 消息传递教程第4张DGL学习(三): 消息传递教程第5张
tensor([0.0446, 0.0107, 0.0087, 0.0102, 0.0085, 0.0130, 0.0091, 0.0059, 0.0079,
        0.0088, 0.0082, 0.0087, 0.0098, 0.0087, 0.0100, 0.0092, 0.0065, 0.0168,
        0.0064, 0.0106, 0.0098, 0.0117, 0.0077, 0.0113, 0.0111, 0.0100, 0.0077,
        0.0051, 0.0084, 0.0070, 0.0048, 0.0163, 0.0102, 0.0084, 0.0098, 0.0127,
        0.0101, 0.0091, 0.0091, 0.0083, 0.0088, 0.0095, 0.0132, 0.0106, 0.0057,
        0.0099, 0.0068, 0.0106, 0.0098, 0.0068, 0.0140, 0.0087, 0.0083, 0.0120,
        0.0107, 0.0109, 0.0072, 0.0090, 0.0069, 0.0124, 0.0094, 0.0106, 0.0071,
        0.0093, 0.0070, 0.0059, 0.0068, 0.0162, 0.0082, 0.0129, 0.0063, 0.0134,
        0.0116, 0.0095, 0.0107, 0.0147, 0.0085, 0.0099, 0.0084, 0.0069, 0.0112,
        0.0120, 0.0076, 0.0105, 0.0125, 0.0091, 0.0063, 0.0085, 0.0051, 0.0102,
        0.0116, 0.0070, 0.0120, 0.0094, 0.0156, 0.0159, 0.0096, 0.0125, 0.0065,
        0.0107])
View Code

大图的批处理语义

上图中的方法需要遍历所有节点,不适合于大图,DGL通过允许在一个batch的节点或边上进行计算来解决此问题。 例如,以下代码一次性触发所有多个节点的消息函数和规约函数。

def pagerank_batch(g):
    g.send(g.edges())
    g.recv(g.nodes())
for k in range(10):
    #pagerank_naive(g)
    pagerank_batch(g)
print(g.ndata['pv'])

并行性方面:  由于每个节点接受的输出参数是不同的,不同长度的张量没法进行stack。所以DGL按传入消息的数量对节点进行分组,分组调用reduce函数来解决该问题。

使用更高级别的API来提高效率

def pagerank_level2(g):
    g.update_all()

使用内置API

一些常用的消息函数和规约函数DGL都包含了,直接调用即可。

import dgl.function as fn

def pagerank_builtin(g):
    g.ndata['pv'] = g.ndata['pv'] / g.ndata['deg']
    g.update_all(message_func=fn.copy_src(src='pv', out='m'),
                 reduce_func=fn.sum(msg='m',out='m_sum'))
    g.ndata['pv'] = (1 - DAMP) / N + DAMP * g.ndata['m_sum']

免责声明:文章转载自《DGL学习(三): 消息传递教程》仅用于学习参考。如对内容有疑问,请及时联系本站处理。

上篇目录启动CXF启动报告LinkageError异常以及Java的endorsed机制ElasticSearch 7.14安装步骤【windows平台】下篇

宿迁高防,2C2G15M,22元/月;香港BGP,2C5G5M,25元/月 雨云优惠码:MjYwNzM=

相关文章

Hive函数:SUM,AVG,MIN,MAX

转自:http://lxw1234.com/archives/2015/04/176.htm,Hive分析窗口函数(一) SUM,AVG,MIN,MAX 之前看到大数据田地有关于max()over(partition by)的用法,今天恰好工作中用到了它,但是使用中遇到了一个问题:在max(rsrp)over(partition by buildingid...

微信小程序setdata修改数组或对象

1、this.setdata修改数组的固定一项的值 changeItemInArr: function() { this.setData({ 'arr[0].text':'changed data' }) }, 2、动态修改数组某一项的值 changeItemInArr: function(index) { let...

vue.js中v-for的使用及索引获取

2.x版本: v-for="(item,index) in items" index即索引值。  ==========================分割线============================== 1.x版本: 1.v-for   示例一: <!DOCTYPE html> <html> <head>...

poco脚本编写之api

连接设备后使用poco   使用connect_device连接好指定设备后,会返回一个Device对象,将这个对象传入 AndroidUiautomationPoco第一个参数里进行poco的初始化,   接下来使用此poco 实例将会获取所指定的设备的UI和对其进行操作。      from airtest.core.api import connec...

Oracle-11g-R2 于 Linux 上的 RAC 卸载

安装环境: SuSE Linux Enterprise Server 11 SP3 Oracle 11g 11.2.0.3   卸载步骤: 1.卸载 Database 软件(oracle,第一节点) (1).运行 $ORACLE_HOME/deinstall/deinstall 脚本。 (2).按照导航执行如下。 (2-1).Specify th...

Kaggle系列1:手把手教你用tensorflow建立卷积神经网络实现猫狗图像分类

去年研一的时候想做kaggle上的一道题目:猫狗分类,但是苦于对卷积神经网络一直没有很好的认识,现在把这篇文章的内容补上去。(部分代码参考网上的,我改变了卷积神经网络的网络结构,其实主要部分我加了一层1X1的卷积层,至于作用,我会在后文详细介绍) 题目地址:猫狗大战 同时数据集也可以在上面下载到。 既然是手把手,那么就要从前期的导入数据开始: 导入...