快速开始
备注
阅读本篇前,请确保已按照 安装教程 准备好昇腾环境及 PyG
本文档帮助昇腾开发者快速使用 PyG × 昇腾 进行 GNN 训练。你可以访问 这篇官方论文 获取更多信息。
概览
PyG 包含了针对图 (Graph) 及其他不规则结构的多种深度学习方法,这些方法来自众多已发表的论文。
训练示例
示例训练实现了针对引用图中论文分类的图神经网络的训练。首先加载 Cora 数据集,并使用预定义的 GCNConv 创建了一个简单的两层 GCN 模型,然后开始训练。
GCN (Graph Convolutional Network) 是一种经典的图神经网络架构,适用于处理图结构数据。GCN 通过在图上进行卷积操作来捕捉节点之间的关系,从而实现节点分类、图分类等任务。
1import torch
2from torch import Tensor
3from torch_geometric.nn import GCNConv
4from torch_geometric.datasets import Planetoid
5
6dataset = Planetoid(root='.', name='Cora')
7
8class GCN(torch.nn.Module):
9 def __init__(self, in_channels, hidden_channels, out_channels):
10 super().__init__()
11 self.conv1 = GCNConv(in_channels, hidden_channels)
12 self.conv2 = GCNConv(hidden_channels, out_channels)
13
14 def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
15 # x: Node feature matrix of shape [num_nodes, in_channels]
16 # edge_index: Graph connectivity matrix of shape [2, num_edges]
17 x = self.conv1(x, edge_index).relu()
18 x = self.conv2(x, edge_index)
19 return x
20
21model = GCN(dataset.num_features, 16, dataset.num_classes)
该示例会自动下载 Cora 数据集,并使用 GCN 模型以备训练。你可以根据需要修改模型结构、训练参数等,以适应不同的任务和数据集。
也可以使用 PyG 提供的示例代码进行训练,例如使用官方提供的 GCN 示例进行训练:
1python ./examples/gcn.py
看到类似如下输出,loss 存在明显下降趋势说明训练成功:
1Epoch: 001, Loss: 1.9458, Train: 0.2286, Val: 0.2580, Test: 0.2510
2Epoch: 002, Loss: 1.9432, Train: 0.3857, Val: 0.2300, Test: 0.2510
3Epoch: 003, Loss: 1.9374, Train: 0.7643, Val: 0.4500, Test: 0.4500
4Epoch: 004, Loss: 1.9290, Train: 0.7286, Val: 0.3860, Test: 0.4500
5Epoch: 005, Loss: 1.9223, Train: 0.7643, Val: 0.4080, Test: 0.4500
6Epoch: 006, Loss: 1.9178, Train: 0.7286, Val: 0.4140, Test: 0.4500
7Epoch: 007, Loss: 1.9078, Train: 0.7786, Val: 0.4380, Test: 0.4500
8Epoch: 008, Loss: 1.8984, Train: 0.7429, Val: 0.4320, Test: 0.4500
9Epoch: 009, Loss: 1.8855, Train: 0.7786, Val: 0.4040, Test: 0.4500
10Epoch: 010, Loss: 1.8776, Train: 0.7786, Val: 0.4220, Test: 0.4500
11Epoch: 011, Loss: 1.8649, Train: 0.7714, Val: 0.4360, Test: 0.4500
12Epoch: 012, Loss: 1.8561, Train: 0.8143, Val: 0.4900, Test: 0.4920
13Epoch: 013, Loss: 1.8508, Train: 0.8429, Val: 0.5400, Test: 0.5440
14Epoch: 014, Loss: 1.8339, Train: 0.8786, Val: 0.5680, Test: 0.5970
15Epoch: 015, Loss: 1.8220, Train: 0.9071, Val: 0.6140, Test: 0.6470
16Epoch: 016, Loss: 1.8028, Train: 0.9214, Val: 0.6540, Test: 0.6830
17Epoch: 017, Loss: 1.7940, Train: 0.9214, Val: 0.6800, Test: 0.7130
18Epoch: 018, Loss: 1.7733, Train: 0.9286, Val: 0.6900, Test: 0.7150
19Epoch: 019, Loss: 1.7586, Train: 0.9357, Val: 0.6920, Test: 0.7290
20Epoch: 020, Loss: 1.7426, Train: 0.9357, Val: 0.6980, Test: 0.7380
21Epoch: 021, Loss: 1.7214, Train: 0.9429, Val: 0.7040, Test: 0.7430
22Epoch: 022, Loss: 1.7060, Train: 0.9429, Val: 0.7080, Test: 0.7460
23Epoch: 023, Loss: 1.6939, Train: 0.9429, Val: 0.7200, Test: 0.7500
24Epoch: 024, Loss: 1.6736, Train: 0.9429, Val: 0.7160, Test: 0.7500
25Epoch: 025, Loss: 1.6517, Train: 0.9429, Val: 0.7180, Test: 0.7500
26Epoch: 026, Loss: 1.6458, Train: 0.9429, Val: 0.7220, Test: 0.7370
27Epoch: 027, Loss: 1.6297, Train: 0.9429, Val: 0.7240, Test: 0.7380
28Epoch: 028, Loss: 1.5822, Train: 0.9429, Val: 0.7140, Test: 0.7380
29Epoch: 029, Loss: 1.5706, Train: 0.9429, Val: 0.7120, Test: 0.7380
30Epoch: 030, Loss: 1.5858, Train: 0.9429, Val: 0.7220, Test: 0.7380
31Epoch: 031, Loss: 1.5373, Train: 0.9429, Val: 0.7300, Test: 0.7500
32Epoch: 032, Loss: 1.5358, Train: 0.9429, Val: 0.7260, Test: 0.7500
33Epoch: 033, Loss: 1.5177, Train: 0.9429, Val: 0.7300, Test: 0.7500
34Epoch: 034, Loss: 1.4543, Train: 0.9429, Val: 0.7420, Test: 0.7660
35Epoch: 035, Loss: 1.4536, Train: 0.9429, Val: 0.7520, Test: 0.7740
36Epoch: 036, Loss: 1.4642, Train: 0.9429, Val: 0.7560, Test: 0.7740
37Epoch: 037, Loss: 1.4009, Train: 0.9500, Val: 0.7620, Test: 0.7780
38Epoch: 038, Loss: 1.3986, Train: 0.9500, Val: 0.7560, Test: 0.7780
39Epoch: 039, Loss: 1.3620, Train: 0.9500, Val: 0.7520, Test: 0.7780
40Epoch: 040, Loss: 1.3841, Train: 0.9500, Val: 0.7580, Test: 0.7780
41Epoch: 041, Loss: 1.3488, Train: 0.9500, Val: 0.7700, Test: 0.7800
42Epoch: 042, Loss: 1.3262, Train: 0.9571, Val: 0.7680, Test: 0.7800
43Epoch: 043, Loss: 1.2861, Train: 0.9571, Val: 0.7760, Test: 0.7850
44Epoch: 044, Loss: 1.2833, Train: 0.9571, Val: 0.7800, Test: 0.7880
45Epoch: 045, Loss: 1.2255, Train: 0.9571, Val: 0.7660, Test: 0.7880
46Epoch: 046, Loss: 1.2127, Train: 0.9500, Val: 0.7620, Test: 0.7880
47Epoch: 047, Loss: 1.2455, Train: 0.9571, Val: 0.7660, Test: 0.7880
48Epoch: 048, Loss: 1.1698, Train: 0.9571, Val: 0.7660, Test: 0.7880
49Epoch: 049, Loss: 1.1380, Train: 0.9500, Val: 0.7680, Test: 0.7880
50Epoch: 050, Loss: 1.1567, Train: 0.9500, Val: 0.7680, Test: 0.7880
51Epoch: 051, Loss: 1.1356, Train: 0.9500, Val: 0.7680, Test: 0.7880
52Epoch: 052, Loss: 1.1302, Train: 0.9571, Val: 0.7680, Test: 0.7880
53Epoch: 053, Loss: 1.0982, Train: 0.9571, Val: 0.7640, Test: 0.7880
54Epoch: 054, Loss: 1.0880, Train: 0.9571, Val: 0.7620, Test: 0.7880
55Epoch: 055, Loss: 1.0617, Train: 0.9571, Val: 0.7580, Test: 0.7880
56Epoch: 056, Loss: 1.0410, Train: 0.9643, Val: 0.7600, Test: 0.7880
57Epoch: 057, Loss: 1.0352, Train: 0.9643, Val: 0.7620, Test: 0.7880
58Epoch: 058, Loss: 1.0271, Train: 0.9643, Val: 0.7680, Test: 0.7880
59Epoch: 059, Loss: 0.9928, Train: 0.9643, Val: 0.7680, Test: 0.7880
60Epoch: 060, Loss: 1.0205, Train: 0.9643, Val: 0.7720, Test: 0.7880
61Epoch: 061, Loss: 1.0038, Train: 0.9643, Val: 0.7740, Test: 0.7880
62Epoch: 062, Loss: 0.9809, Train: 0.9643, Val: 0.7740, Test: 0.7880
63Epoch: 063, Loss: 0.9509, Train: 0.9643, Val: 0.7740, Test: 0.7880
64Epoch: 064, Loss: 0.9133, Train: 0.9643, Val: 0.7720, Test: 0.7880
65Epoch: 065, Loss: 0.9303, Train: 0.9643, Val: 0.7740, Test: 0.7880
66Epoch: 066, Loss: 0.9378, Train: 0.9643, Val: 0.7780, Test: 0.7880
67Epoch: 067, Loss: 0.8676, Train: 0.9643, Val: 0.7840, Test: 0.8110
68Epoch: 068, Loss: 0.8609, Train: 0.9714, Val: 0.7840, Test: 0.8110
69Epoch: 069, Loss: 0.8127, Train: 0.9643, Val: 0.7880, Test: 0.8200
70Epoch: 070, Loss: 0.8994, Train: 0.9714, Val: 0.7880, Test: 0.8200
71Epoch: 071, Loss: 0.7771, Train: 0.9714, Val: 0.7920, Test: 0.8180
72Epoch: 072, Loss: 0.8375, Train: 0.9714, Val: 0.7880, Test: 0.8180
73Epoch: 073, Loss: 0.8174, Train: 0.9714, Val: 0.7900, Test: 0.8180
74Epoch: 074, Loss: 0.7833, Train: 0.9714, Val: 0.7920, Test: 0.8180
75Epoch: 075, Loss: 0.7510, Train: 0.9714, Val: 0.7900, Test: 0.8180
76Epoch: 076, Loss: 0.7898, Train: 0.9714, Val: 0.7880, Test: 0.8180
77Epoch: 077, Loss: 0.7931, Train: 0.9786, Val: 0.7840, Test: 0.8180
78Epoch: 078, Loss: 0.7608, Train: 0.9786, Val: 0.7860, Test: 0.8180
79Epoch: 079, Loss: 0.7193, Train: 0.9786, Val: 0.7840, Test: 0.8180
80Epoch: 080, Loss: 0.6972, Train: 0.9786, Val: 0.7900, Test: 0.8180
81Epoch: 081, Loss: 0.7126, Train: 0.9857, Val: 0.7860, Test: 0.8180
82Epoch: 082, Loss: 0.7176, Train: 0.9857, Val: 0.7840, Test: 0.8180
83Epoch: 083, Loss: 0.7042, Train: 0.9786, Val: 0.7800, Test: 0.8180
84Epoch: 084, Loss: 0.6833, Train: 0.9786, Val: 0.7820, Test: 0.8180
85Epoch: 085, Loss: 0.6981, Train: 0.9786, Val: 0.7880, Test: 0.8180
86Epoch: 086, Loss: 0.6565, Train: 0.9786, Val: 0.7880, Test: 0.8180
87Epoch: 087, Loss: 0.6837, Train: 0.9786, Val: 0.7860, Test: 0.8180
88Epoch: 088, Loss: 0.7371, Train: 0.9786, Val: 0.7900, Test: 0.8180
89Epoch: 089, Loss: 0.6373, Train: 0.9786, Val: 0.7940, Test: 0.8240
90Epoch: 090, Loss: 0.6574, Train: 0.9786, Val: 0.7980, Test: 0.8250
91Epoch: 091, Loss: 0.6248, Train: 0.9786, Val: 0.7980, Test: 0.8250
92Epoch: 092, Loss: 0.6330, Train: 0.9786, Val: 0.8020, Test: 0.8180
93Epoch: 093, Loss: 0.7066, Train: 0.9786, Val: 0.8000, Test: 0.8180
94Epoch: 094, Loss: 0.5868, Train: 0.9786, Val: 0.8060, Test: 0.8230
95Epoch: 095, Loss: 0.6133, Train: 0.9786, Val: 0.8040, Test: 0.8230
96Epoch: 096, Loss: 0.5794, Train: 0.9786, Val: 0.7960, Test: 0.8230
97Epoch: 097, Loss: 0.5593, Train: 0.9786, Val: 0.7880, Test: 0.8230
98Epoch: 098, Loss: 0.5757, Train: 0.9786, Val: 0.7840, Test: 0.8230
99Epoch: 099, Loss: 0.6419, Train: 0.9857, Val: 0.7820, Test: 0.8230
100Epoch: 100, Loss: 0.5809, Train: 0.9857, Val: 0.7780, Test: 0.8230