Preliminaries
참고가 될 만한 포스팅 (중국어로 작성되어있어, 영어로 번역하여 보는 것을 추천)
지난번 포스팅을 통해 널리 알려진 GNN 라이브러리 중 하나인 PyTorch Geometric (이하 PyG)에 대해 간략하게나마 정리하였다. (링크)
본 포스팅에서는 PyG 활용 시 mini-batch training을 위한 대표적인 method 중 하나인 NeighborLoader에 대해 정리하고자 한다.
작성 편의상 sampling을 수행하는 기준 node, 즉 sampling을 수행하였을 때 최 상단에 위치하게 되는(기준이 되는) node를 anchor node로 작성하였다.
NeighborLoader의 인자
torch_geometric.loader
를 확인해보면 NeighborLoader
외에 다양한 loader class가 존재한다.
torch_geometric.loader.NeighborLoader
는 GraphSAGE에서 제안한 neighbor sampling을 수행하는 DataLoader로, full-batch training이 힘든 (single) large graph에 대해 mini-batch training을 수행할 수 있도록 하는 역할을 수행한다.
class의 인자를 보면 여러개가 있는데, (거의) 주로 사용하는 인자를 살펴보면 다음과 같다.
data
: 현재 sampling을 수행할 대상이 될 graph.num_neighbors
: sampling을 수행 할 이웃 node의 수.num_neighbors = [a, b]
과 같이 지정하는 경우, 총 2-hop 거리 만큼 까지 sampling을 수행할 것을 의미하며, 첫 번째 hop에선 a개의 이웃 node를, 두 번째 hop에선 b개의 이웃 node를 sampling 하게 된다.- 즉,
len(num_neighbors)
가 sampling에 고려할 hop 수를 의미한다. - GCN, GraphSAGE paper에서도 언급하였듯, GNN layer의 개수가 연산 수행 시 고려할 hop 수를 의미하므로,
len(num_neighbors) == number of layers == number of hop
이 된다.
batch_size
: 하나의 batch마다 포함될 anchor node의 개수.batch_size = B
로 설정하는 경우, 하나의 batch마다 B개의 anchor node가 포함되며, 각 anchor node마다num_neighbors
만큼의 node를 sampling하여 computation graph를 생성하게 된다.
input_nodes
: 이를 설정할 경우, batch마다len(input_nodes)
의 수 만큼을 anchor node로 설정해 computation graph를 생성.- 즉 각 batch 마다 최대 B개의 anchor node를 설정하고, 이 anchor node를
input_nodes
에서 선택하게 된다. - 만약
input_nodes
가 지정되지 않았다면 graph에 있는 모든 node에서 ordering 된 순서대로 anchor node를 설정하게 된다.
- 즉 각 batch 마다 최대 B개의 anchor node를 설정하고, 이 anchor node를
disjoint
:True
로 설정하는 경우, batch안의 구성되는 anchor node마다 서로 겹치지 않는(disjoint) node들을 sampling하여 disjoint한 computation graph를 생성한다. (기본은False
)- default인
False
일 때, 만일 한 batch에서 anchor node로 node 1, node 2가 함께 있고, 둘이 서로 연결된 node라면 computation graph를 생성할 때 계산의 중복이 발생할 수 있다. - 즉 node 1을 기준으로 한 computation graph에 node 2가 포함되고, node 2를 기준으로 한 computation graph에서 node 1이 포함될 수 있다. (각 computation graph에서 representation의 계산에 중복이 발생할 수 있다)
- default인
인자 중 batch_size
와 input_nodes
는 서로 상호보완(?)적이기도 하다.
- input_nodes를 지정하지 않는 경우 (
input_nodes=None
): 생성되는 batch의 갯수는total_nodes / batch_size
가 되고, 각 batch마다 설정되는 anchor node의 수는batch_size
가 된다.- 특정 node들을 anchor node로 설정하라는 명시가 없으므로, 전체 graph에 있는 모든 node들을 대상으로 node의 id ordering 순서대로 anchor node를 설정한다.
- 즉 이러한 경우, 전체 batch의 갯수는
total_nodes / batch_size
가 되며, 각 batch마다batch_size
만큼의 anchor node의 수가 설정된다.
- input_nodes를 지정하는 경우 (
input_nodes=[~~]
):batch_size
를 지정하지 않을 시 batch 갯수는len(input_nodes) (== batch_size=1)
가 되고,batch_size
지정 시 batch 갯수는len(input_nodes) / batch_size
가 된다.batch_size
는 (위에서도 동일하게) 한 batch 안에서 고려할 총 anchor node의 개수이다.input_nodes
를 지정했는데batch_size
를 지정하지 않는다면:batch_size = 1
이 되며, 각 batch마다 1개의 anchor node를 설정하게 되고, 이 anchor node는input_nodes
에 있는 node들 만을 가지고 batch를 생성하게 된다.input_nodes
를 지정하고batch_size
도 지정하게 된다면: 각 batch마다batch_size
만큼의 anchor node가 설정되고, 이 anchor node는input_nodes
에 있는 node들만을 가지고 batch를 생성하게 된다.
만일 10개의 node가 있을 때, input_nodes = 2
, batch_size = 2
와 같이 설정하게 된다면batch_size
만큼 input_nodes
에서 anchor node를 선택해 batch 생성하게 되며, 이때는 batch_size==len(input_nodes)
이므로 batch 1개에 2개의 anchor node가 선택되어 batch가 1개만 생성된다.
Example
백문이 불여일RUN, 간단한 예제를 통해 직접 확인해보자.
10 node, 15 edge로 구성된 임의의 그래프를 생성하였다.
import torch
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.loader import NeighborLoader
from torch_geometric.utils import from_networkx, to_networkx
from torch_geometric.seed import seed_everything
seed_everything(62)
## Custom Graph (10 node) make
G = nx.Graph()
G.add_nodes_from([0,1,2,3,4,5,6,7,8,9])
G.add_edges_from([(0,1), (0,3), (1,2), (1,6), (2,3), (2,7), (3,5), (4,2), (4,6), (5,7), (6,9), (8, 9), (8,3), (9, 3), (0,9)])
## Color map for nodes
color_map = ['red', 'red', 'green', 'cyan', 'pink', 'pink', 'red', 'pink', 'pink', 'green']
# ## Made graph check by plotting
# plt.figure(figsize=(12, 8))
# nx.draw(G, pos=nx.kamada_kawai_layout(G), node_size=500, alpha=1.0, with_labels=True, node_color=color_map)
# plt.show()
graph가 준비되었다면, 앞서 기술한 인자들을 토대로 batch를 생성해보자.
## PyG transform
data = from_networkx(G)
data.n_id = torch.arange(data.num_nodes) # original node 접근을 위해, n_id attribute 추가.
NUM_NEIGHBORS = [2,3]
BATCH_SIZE = 2
## Loader 생성
loader = NeighborLoader(
data = data,
num_neighbors = NUM_NEIGHBORS,
batch_size = BATCH_SIZE,
)
## sampling 수행
subgraph_list = []
for batch in loader:
subgraph_list.append(batch)
간단하다. 정말...
이제 생성된 computation graph들을 plot해보자.
def batch_wise_graph_plot(subgraph_index_num):
# 해당 위치의 subgraph를 가져오기
graph_in_batch = to_networkx(subgraph_list[subgraph_index_num])
# 변환작업
# 각 batch로 들어갈때, global node id는 사라지고 batch마다 고유한 id를 가지게 되므로
# 이를 original node와 mapping.
transformed_nodes = graph_in_batch.nodes()
original_nodes = subgraph_list[subgraph_index_num].n_id.tolist()
node_dict = dict(zip(transformed_nodes, original_nodes))
# batch로 가면서 변환된 node id를 original node id와 mapping한 node_dict을 통해 subgraph를 re-label.
transformed_graph = nx.relabel_nodes(graph_in_batch, node_dict)
# original graph에서의 node color를 그대로 get.
colormap_sample = []
for position in transformed_graph.nodes():
colormap_sample.append(color_map[position])
return transformed_graph, colormap_sample
# Plot
plt.figure(figsize=(20,16))
plt.subplot(231).set_title("Original Graph")
nx.draw(G, pos=nx.kamada_kawai_layout(G), node_size=400, alpha=1.0, with_labels=True, node_color=color_map)
for i in range(len(subgraph_list)):
subgraph, colors = batch_wise_graph_plot(i)
plt.subplot(int(f'{23}{i+2}')).set_title(f"Anchor Node : {subgraph_list[i].n_id[:BATCH_SIZE].tolist()} \n # of neighbors = {NUM_NEIGHBORS}")
nx.draw(subgraph, pos=nx.spring_layout(subgraph, k=0.9, iterations=40), node_size=400, alpha=1.0, with_labels=True, node_color=colors)
# nx.draw(subgraph, pos=nx.kamada_kawai_layout(subgraph), node_size=400, alpha=1.0, with_labels=True, node_color=colors)
# nx.draw(nx.balanced_tree(3, 3), pos=graphviz_layout(subgraph, prog='dot'), node_size=400, alpha=1.0, with_labels=True, node_color=colors)
plt.tight_layout()
plt.show()
각 subgraph마다 anchor node가 batch_size = 2
개씩 선정되었고, 각 anchor node에서 len(num_neighbors) = 2
-hop 거리의 이웃까지 sampling하게 된다.
위 plot에서 edge가 양 방향인 경우는 선택된 이웃 node에서 1개의 이웃 node를 선택했는데, 그 이웃이 anchor node인 경우이다.
(disjoint 인자를 따로 명시하지 않았으므로 default인 disjoint = False
가 되었기 때문)
Example (extreme case)
sampling이 어떻게 수행되어서 각 batch마다 어떤 computation graph가 생성되는지는 앞선 예제를 통해 알아보았다.
그렇다면, graph가 조금 많이 극단적인 형태라면 sampling이 어떻게 수행될까?
역시 백문이 불여일RUN, 예제를 작성하여 확인해보자.
10 node, 9 edge로 구성된 임의의 그래프를 생성하였다.
import torch
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.loader import NeighborLoader
from torch_geometric.utils import from_networkx, to_networkx
from torch_geometric.seed import seed_everything
seed_everything(62)
G = nx.Graph()
G.add_nodes_from([0,1,2,3,4,5,6,7,8,9])
G.add_edges_from([(0,2), (1,2), (1,6), (2,3), (2,7), (3,5), (4,0), (8,3), (9,3)])
color_map = ['red', 'red', 'green', 'cyan', 'pink', 'pink', 'red', 'pink', 'pink', 'green']
# print(G)
# plt.figure(figsize=(12, 8))
# nx.draw(G, pos=nx.kamada_kawai_layout(G), node_size=500, alpha=1.0, with_labels=True, node_color=color_map)
# plt.show()
data = from_networkx(G)
data.n_id = torch.arange(data.num_nodes)
NUM_NEIGHBORS = [2,3]
BATCH_SIZE = 2
loader = NeighborLoader(
data = data,
num_neighbors = NUM_NEIGHBORS,
batch_size = BATCH_SIZE,
)
# sampling 수행
subgraph_list = []
for batch in loader:
subgraph_list.append(batch)
(그나마) 앞선 예제보다 좀 많이 극단적인 형태이긴 하다.
이제 batch 마다 생성된 computation graph를 확인해보자.
def batch_wise_graph_plot(subgraph_index_num):
# 해당 위치의 subgraph를 가져오기
graph_in_batch = to_networkx(subgraph_list[subgraph_index_num])
# 변환작업
# 각 batch로 들어갈때, global node id는 사라지고 batch마다 고유한 id를 가지게 되므로
# 이를 original node와 mapping.
transformed_nodes = graph_in_batch.nodes()
original_nodes = subgraph_list[subgraph_index_num].n_id.tolist()
node_dict = dict(zip(transformed_nodes, original_nodes))
# batch로 가면서 변환된 node id를 original node id와 mapping한 node_dict을 통해 subgraph를 re-label.
transformed_graph = nx.relabel_nodes(graph_in_batch, node_dict)
# original graph에서의 node color를 그대로 get.
colormap_sample = []
for position in transformed_graph.nodes():
colormap_sample.append(color_map[position])
return transformed_graph, colormap_sample
plt.figure(figsize=(20,16))
plt.subplot(231).set_title(f"Original Graph \n node : {G.order()}, edge : {len(G.edges())}")
nx.draw(G, pos=nx.spring_layout(G), node_size=400, alpha=1.0, with_labels=True, node_color=color_map)
for i in range(len(subgraph_list)):
subgraph, colors = batch_wise_graph_plot(i)
plt.subplot(int(f'{23}{i+2}')).set_title(f"Anchor Node : {subgraph_list[i].n_id[:BATCH_SIZE].tolist()} \n # of neighbors = {NUM_NEIGHBORS} \n node : {subgraph.order()}, edge : {len(subgraph.edges)}" )
nx.draw(subgraph, pos=nx.spring_layout(subgraph, k=0.9, iterations=40), node_size=400, alpha=1.0, with_labels=True, node_color=colors)
plt.tight_layout()
plt.show()
함수를 만들어두면 역시 재활용하기가 편해서 좋다 :)
anchor node가 1개와만 연결된 경우(anchor node = 4,5 case), 1-hop 거리에서 현재 이웃이 1개밖에 존재하지 않으므로, 해당 이웃만을 선택하게 된다.
(node 4는 node 0을 선택, node 5는 node 3을 선택)
또한 선택된 node에서 1-hop 거리 (anchor node에서 2-hop 거리)에서 이웃들을 최대 3개까지 선택하게 된다.
(node 3의 이웃 2, 5, 8, 9 에서 2, 5, 9를 선택 // node 0의 이웃은 2, 4 밖에 없으므로 2, 4를 선택)
여기서 확인할 수 있는 점은 이웃의 수가 선택하게 되는 수 보다 작으면 전체 이웃을 선택하게 되고, 이웃 수가 선택할 수 보다 많으면 정해진 수 만큼을 선택한다는 것이다.(어떻게 보면 당연한 이치이다)
즉 num_neighbors
를 지정하여도 graph의 특성 상 각 batch마다 생성되는 computation graph의 크기가 달라질 수 있다는 점을 확인할 수 있으며, 이것이 GraphSAGE에서 제안한 neighborhood sampling의 특징이다.
(말 그대로 정해진 개수 만큼의 node를 무작위로 sampling)
Conclusion
이전 포스팅처럼 작성하다 보니 글이 조금 길어진 듯 하다.
PyG가 역시 Python/PyTorch style을 잘 따르고 있어 실제로 sampling하여 mini-batch를 구성하는 code가 단 몇 줄 만으로 완성되며 상당히 직관적으로 되어있다.
예제를 작성하기 위해 22년, 23년에 작성하였던 testing code를 다시 열어서 확인하니 그 당시의 기억이 새록새록 떠오르는 듯 하다.
이뻐보이는 grpah를 그려보기 위해 node간의 연결성을 여러 가지로 조합도 해 보고, 각 node마다 색깔도 입혀보고...
현재 포스팅은 온전히 NeighborLoader를 통해 mini-batch가 어떻게 형성되는지를 기술하였는데, 실제 GNN code를 작성하여 mini-batch train을 수행하는 과정은 기존 PyTorch code와 동일하다:
다음 포스팅은 GNN field에서의 MNIST라고 불리는 Cora 데이터 셋과 NeighborLoader를 이용해 간단한 실험을 진행했던 내용을 기술하고자 한다.
E.O.D.
'학업 이야기 > Programming' 카테고리의 다른 글
PyTorch Geometric(PyG) - Cora EDA & mini-batch testing(with NeighborLoader) (1) | 2025.02.17 |
---|---|
PyTorch Geometric(PyG) 개요 (0) | 2024.08.05 |