【Graph Net学习】DeepWalk/Node2Vec实现Graph Embedding

2023-09-21 11:45:13

一、简介

        本文主要通过代码实战介绍基础的两种图嵌入方式DeepWalk、Node2Vec。

        DeepWalk(KDD 2014)首个影响至今的图的Embedding算法,DeepWalk算法是一种用于学习节点表示的方法,常用于网络图中的节点的嵌入表示。

模型目标输入输出
Word2VecWordSentenceWord Embedding
DeepWalkNodeNode SequenceNode Embedding

        算法流程:

        1.首先,DeepWalk算法会从一个随机的起始节点开始,比如说选择朋友A作为起点。然后,算法会从A的邻居节点中随机选择一个节点,比如说选择了B。接着,算法会再从B的邻居节点中随机选择一个节点,比如说选择了C。这样反复进行直到达到事先设定的步数。

        2.一旦完成了一次类似于“走迷宫”的遍历,DeepWalk算法会将这条路径视为一句话,其中包含了A、B和C三个节点。算法会重复这个过程,多次生成不同的句子。

        3.然后,DeepWalk算法会将这些句子作为文本输入给Word2Vec算法。

        Node2Vec:Node2Vec算法是一种能够学习网络节点表示的算法。它通过优化随机游走过程来最大化网络的邻域节点之间的相似度,从而得到每个节点的有效嵌入。

        算法流程:

  1. 首先,从图中的每个节点开始执行固定长度的随机游走。这个步骤旨在生成每个节点的上下文信息。随机游走的方法包括以一种偏好的方式回到先前访问的节点,或者探索之前未访问的节点。这通过参数p和q来调整,其中p控制返回预先访问节点的可能性,而q控制更偏向于访问较远的节点。

  2. 得到随机游走序列之后,使用Skip-gram模型训练节点嵌入。在Skip-gram模型中,我们试图预测节点的邻居节点。

  3. 在训练过程中,使用梯度下降等优化算法最小化预测错误,进而通过迭代更新嵌入向量,使得越相似的节点其嵌入向量越接近。

     4.最后经过这样的训练后,我们就可以获得每个节点的向量表示,这个向量反映了节点在网络中的位置和角色。

```

        直接刚代码,开箱即用。

二、代码

import torch
import numpy as np
import os
import random
import pandas as pd
import scipy.sparse as sp
from torch_geometric.data import Data
from sklearn.preprocessing import LabelEncoder
from node2vec import Node2Vec
import networkx as nx
from gensim.models import Word2Vec

def seed_everything(seed=2023):
    random.seed(seed)
    os.environ['PYTHONHASHSEED']=str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

seed_everything()

def load_cora_data(data_path = './data/cora'):

    content_df = pd.read_csv(os.path.join(data_path,"cora.content"), delimiter="\t", header=None)
    content_df.set_index(0, inplace=True)
    index = content_df.index.tolist()
    features = sp.csr_matrix(content_df.values[:,:-1], dtype=np.float32)

    # 处理标签
    labels = content_df.values[:,-1]
    class_encoder = LabelEncoder()
    labels = class_encoder.fit_transform(labels)

    # 读取引用关系
    cites_df = pd.read_csv(os.path.join(data_path,"cora.cites"), delimiter="\t", header=None)
    cites_df[0] = cites_df[0].astype(str)
    cites_df[1] = cites_df[1].astype(str)
    cites = [tuple(x) for x in cites_df.values]
    edges = [(index.index(int(cite[0])), index.index(int(cite[1]))) for cite in cites]
    edges = np.array(edges).T

    # 构造Data对象
    data = Data(x=torch.from_numpy(np.array(features.todense())),
                edge_index=torch.LongTensor(edges),
                y=torch.from_numpy(labels))

    idx_train = range(140)
    idx_val = range(200, 500)
    idx_test = range(500, 1500)

    # 读取Cora数据集 return geometric Data格式
    def index_to_mask(index, size):
        mask = np.zeros(size, dtype=bool)
        mask[index] = True
        return mask

    data.train_mask = index_to_mask(idx_train, size=labels.shape[0])
    data.val_mask = index_to_mask(idx_val, size=labels.shape[0])
    data.test_mask = index_to_mask(idx_test, size=labels.shape[0])

    def to_networkx(data):
        edge_index = data.edge_index.to(torch.device('cpu')).numpy()
        G = nx.DiGraph()
        for src, tar in edge_index.T:
            G.add_edge(src, tar)
        return G

    networkx_data = to_networkx(data)

    return data,networkx_data

#获取数据:pyg_data:torch_geometric格式;networkx_data:networkx格式
pyg_data,networkx_data = load_cora_data()

#Node2Vec_Embedding方法
def Node2Vec_run(networkx_data, dimensions=128, walk_length=30, num_walks=200):
    # 创建一个Node2Vec对象 #dimensions=64 embedding维度, walk_length=30 游走步长, num_walks=200 游走次数, workers=4 线程数
    node2vec = Node2Vec(networkx_data, dimensions=dimensions, walk_length=walk_length, num_walks=num_walks, workers=4)

    # 训练Node2Vec模型
    model = node2vec.fit(window=10, min_count=1, batch_words=4) #获得node2vec的所有内容
    nodes = model.wv.index_to_key  # 得到所有节点的名字
    embeddings = model.wv[nodes]  # 得到所有节点的嵌入向量
    return model,nodes,embeddings

def DeepWalk_run(networkx_data,dimensions = 128, walk_length = 30, num_walks = 200):
    # 使用deepwalk算法进行graph embedding
    # DeepWalk算法
    def deepwalk(graph, num_walks, walk_length):
        walks = []
        for node in graph.nodes():
            if graph.degree(node) == 0:
                continue
            for _ in range(num_walks):
                walk = [node]
                target = node
                for _ in range(walk_length - 1):
                    if len(list(graph.neighbors(target))) == 0:  # 判断当前节点是否有邻居,如果为空邻居,则跳过当前节点
                        continue
                    target = random.choice(list(graph.neighbors(target)))
                    walk.append(target)
                walks.append(walk)
        return walks
    walks = deepwalk(networkx_data, num_walks = num_walks, walk_length = walk_length)
    # 用Word2Vec训练节点向量
    model = Word2Vec(walks, vector_size=dimensions, window=5, min_count=0, sg=1) #参数sg=1表示选择Skip-Gram模型  window 影响着Word2Vec中词和其上下文词的最大距离
    nodes = model.wv.index_to_key  # 得到所有节点的名字
    embeddings = model.wv[nodes]  # 得到所有节点的嵌入向量
    return model,nodes,embeddings

_,_,node2vec_embeddings = Node2Vec_run(networkx_data,num_walks=1)
print("node2vec_embeddings :",np.array(node2vec_embeddings).shape) # print : "node2vec_embeddings : (2708, 64)"

_,_,DeepWalk_embeddings = DeepWalk_run(networkx_data,num_walks=1)
print("DeepWalk_embeddings :",np.array(DeepWalk_embeddings).shape) # print : "node2vec_embeddings : (2708, 64)"

三、结果及展望

       上述是针对Cora的数据集做的Node Embedding输出,输出为:node2vec_embeddings : (2708, 128);DeepWalk_embeddings : (2708, 128)

        接下来大家就可拿到 (2708, 128)这个Embedding做各种下游了,如聚类、Net Feature等

        P.S.这些都是18年以前,NN不发达的Embedding产物,并未挖掘深层feature的embedding,接下来玩一玩NN的Graph Embedding

更多推荐

LabVIEW开发气动悬浮系统教学平台

LabVIEW开发气动悬浮系统教学平台目前,通过使用可编程逻辑控制器,几乎可以实现任何工业生产过程的自动化。工业自动化可以提高流程效率,提高生产水平并减少损失。在此背景下,介绍了工业自动化教育系统的设计和实现以及气动悬浮过程中的控制应用。该自动化系统基于PLCS7-1500和LabVIEW中设计的人机界面,用于监测气动

游戏开发之路

最近即将大四,面临实习和就业的问题,学校只想尽快把我们推出去,却不管前方是刀山还是火海。如果没有梦想,去哪里都是流浪。如果怀有梦想,你是否迷茫?我不是985也不是211,我不想使用Unity或Unreal,明明什么都没有我却想做出惊艳的3A作品。但现在实现不了梦想没关系,十年后也许可以实现梦想,二十年后也许可以实现梦想

Ae 效果:CC Hair

模拟/CCHairSimulation/CCHairCCHair(CC毛发)可以在源图像上模拟生成毛发、绒线等,并可调整它们的长度、方向、重量等属性,从而创建出非常独特的效果。CCHair本质上是基于Alpha通道来生成毛发,无毛发处将变为透明。比如,对于文本等矢量图层,它会基于Alpha通道的轮廓来生成毛发。◆◆◆效

企业怎么优化固定资产管理

在优化固定资产管理的过程中,不仅要关注硬件设备和设施的维护,还要重视软件系统和数据管理。一些可能的方法:需要建立一套完整的资产管理系统。这个系统应该包括资产的采购、登记、使用、维修、报废等各个环节的管理流程。通过这个系统,可以实时了解每个资产的状态,及时发现并解决潜在的问题。应该对固定资产进行定期的盘点和维护。这不仅可

操作系统权限提升(二十六)之数据库提权-MySQL UDF提权

MySQLUDF提权MySQL介绍MySQL是最流行的开放源码SQL数据库管理系统,相对于Oracle,DB2等大型数据库系统,MySQL由于其开源性、易用性、稳定性等特点,受到个人使用者、中小型企业甚至一些大型企业的广泛欢迎,MySQL具有以下特点:1、MySQL是一种关联数据库管理系统,具有灵活性。2、MySQL软

mysql的判断语句

ifif用于做条件判断,具体的语法结构如下,在if条件判断的结构中,ELSEIF结构可以有多个,也可以没有。ELSE结构可以有,也可以没有。IF条件1THEN.....ELSEIF条件2THEN--可选.....ELSE--可选.....ENDIF;案例createprocedurep3()begindeclaresc

Linux(Centos)查看硬盘大小

Linux查看硬盘大小使用df命令:df命令可以用来显示文件系统的磁盘使用情况,包括每个挂载点的磁盘空间大小和使用情况。要查看硬盘大小,可以运行以下命令:df-h这将以人类可读的方式显示文件系统的磁盘大小,以GB或MB为单位。下面是df-h命令输出的参数说明:Filesystem:文件系统的名称或挂载点。这是磁盘空间的

Delft3D水动力与泥沙运动模拟实践应用

水体中泥沙运动是关系到防洪,调水等方面的重要问题,也是水利和水环境领域科研热点之一。水利数值模型,在环境影响评价、防洪规划等方面也有着广泛的应用。荷兰Delft研究所开发的Delft3D模型是世界上最先进的水动力之一,能够运用于河网、浅水湖泊、深水水库以及近岸海洋等多种水体的水动力和泥沙问题的研究中;同时,Delft3

MYSQL数据库基础

这里写目录标题MYSQL数据库基础一.数据库原理1.数据的时代2.数据库的发展史1)文件管理系统的缺点2)数据库系统发展阶段3)DBMS数据库管理系统4)数据库管理系统的优点5)数据库管理系统的基本功能6)数据库系统的架构7)各种数据库管理系统8)关系型数据库理论二.MYSQL历史关系型数据库和非关系型数据库三.mys

【Vue】路由与Node.js下载安装及环境配置教程

🎉🎉欢迎来到我的CSDN主页!🎉🎉🏅我是Java方文山,一个在CSDN分享笔记的博主。📚📚🌟推荐给大家我的专栏《Vue快速入门》。🎯🎯👉点击这里,就可以查看我的主页啦!👇👇Java方文山的个人主页🎁如果感觉还不错的话请给我点赞吧!🎁🎁💖期待你的加入,一起学习,一起进步!💖💖目录前言

访问学者申请一定要会说英语吗?

访问学者申请一定要会说英语吗?显然,出国做访问学者,外语是出国的关键,这是毋庸置疑,而且必须严格对待的。下面就随知识人网小编一起来深入探讨一下。首先,我们需要明确的是,访问学者申请通常要求申请者具备一定的英语能力。这是因为访问学者在国外学术机构或大学进行研究工作时,需要与导师、同事以及学生进行有效的沟通,而英语通常是国

热文推荐