학업 이야기/Data Science

Graph Representation Learning (2): Graph Neural Networks

행복한 하늘 2024. 7. 17. 21:27

 

 

Graph Machine Learning

Graph ML, random walk approach에 관한 글은 이전 포스팅에 기술되어 있다.

 

 

 

Graph Neural Network

random walk를 이용한 접근 방식에도 당연히 문제점이 존재했다.

 

  • embedding 계산을 위해 graph에 존재하는 전체 node 수 만큼의 복잡도를 가진다.
  • node마다 가지고 있는 정보(feature, attribute)를 온전히 활용하지 못한다.
  • Transductive한 방법이기에 학습하는 과정에서 unseen node에 대한 embedding 계산을 하지 못한다.

 

따라서 encoding을 수행하는 mapping function $f$를 신경망을 통해 진행하자는 연구가 진행되었고, GNN(Graph Neural Network)의 연구가 진행되었다.

 

널리 활용되던 CNN/RNN을 graph 데이터에 적용하기가 힘들었는데, 이전 포스팅에서 언급한 바와 같이 graph는 이미지나 텍스트와 같이 grid, sequence 형태로 표현하기가 힘드며, graph를 구성하는 node가 고정된 순서/기준이 존재하지 않아 관점에 따라 다른 모양이 될 수도 있다는 것이 특징이다.

(이러한 특징으로 인해 주어진 두 그래프가 동형인지를 확인하는 Weisfeiler-Leman(WL) test가 있다.)

 

graph 데이터에 신경망을 활용하기 위한 Naïve한 방법으로 다음과 같은 설계를 해볼 수 있을 것이다.

인접 행렬과 feature를 결합하여 신경망의 input으로 전달 (from CS224W)

 

이전처럼 traditional하게, graph의 구조적인 정보를 담고 있는 인접 행렬(adjacency matrix)과 각 node마다 가지고 있는 특징 정보(feature vector)CONCAT한 후, 이를 신경망의 입력으로 전달하는 방법이다.

 

어찌 보면 graph의 구조적인 정보도 학습할 수 있고, 각 node가 가지고 있는 feature 정보 또한 학습할 수 있을 것이다.

 

하지만, 당연히 이는 아래와 같은 문제점이 존재한다.

 

  • graph에 존재하는 모든 node를 처리하기 위해 신경망의 입력, 즉 parameter 수가 많아지게 된다.
  • graph의 크기가 달라지거나 다른 형태의 graph가 주어진다면 동일한 model을 활용할 수 없게 된다.
  • node의 순서가 변경되는 경우에 sensitive해지게 된다.
    (graph의 구조가 변경되거나, 전처리 과정에서 순서가 변경되거나, 일반적인 mini-batch training을 하는 경우 node의 순서가 변경될 수 있다.)

 

그렇다면 어떻게 graph의 구조에 sensitive하지 않으면서 신경망의 parameter 수를 적절히 handle할 수 있을까?

 

CNN의 경우를 생각해보자.

일반적인 CNN의 학습 과정 (from CS224W)

 

이미지가 입력으로 주어졌을 때, CNN은 filter를 이동시키는 일명 sliding window라는 기법을 이용해 이미지 전체에 대한 feature를 학습하게 된다.

 

주어진 graph가 이미지라고 생각해본다면, CNN에서 사용한 sliding window와 같은 기법을 이용해 graph의 지역적인 구조에 대한 학습을 차근차근 진행할 수 있을 것이다.

 

하지만 graph는 지역성(locality)에 대한 개념이 애매모호하며, sliding window 기법을 바로 적용하기 힘든 문제가 있다.

 

graph에 sliding window 기법을 적용하기가 힘든 이유 (from CS224W)

 

위 예시처럼 graph는 그 구조가 아주 다양하며, sliding window를 적용한다 하더라도 매번 달라지는 node의 순서와 그 갯수로 인해 적용이 상당히 어려울 것이다.

 

 

Graph Convolutional Network(GCN)

(GCN paper review 포스팅)

 

기존의 CNN은 filter라는 개념을 이용해 이를 sliding하면서 feature extraction 및 학습을 진행하였다.

 

즉, 여러 개의 grid로부터 feature를 추출 및 결합하여 한층 더 압축된 feature를 생성하였고, 이를 반복하였다.

 

이에 intuition을 받아 다음과 같은 접근이 제안되었다.

 

CNN과 유사하게 인접 grid(node)로부터 feature를 추출 및 결합 (from CS224W)

 

하나의 node를 기준으로, 인접한 node로부터 정보를 모으는 방식이다. (neighborhood aggregation)

 

즉, 각 node마다 연결된 이웃 node들의 feature(message)를 추출 및 결합하여 학습을 진행하는 것이다.

 

이 과정을 그림으로 표현하면 다음과 같이 표현할 수 있다:

 

neighborhood aggregation (from CS224W)

 

신경망을 이용하여, 각 node마다 연결된 모든 이웃 node의 feature를 결합하고 변형을 진행한다.

 

위 그림에선 node A를 기준으로 연결된 node인 node B, C, D의 feature를 결합 및 변형하고, 각 node B, C, D는 마찬가지로 연결된 node들로부터 feature를 결합 및 변형한다.

 

즉, 하단 그림처럼 각 node마다 연결된 node들을 이용해 일종의 계산 그래프(computation graph)를 형성하여 node마다 feature를 추출 및 결합하여 학습을 진행하는 방법이다.

 

각 node마다 feature aggregation을 위해 형성되는 computation graph (from CS224W)

(CS224W가 정말 양질의 자료...)

 

GCN paper review 포스팅에서도 언급했듯이, (어떻게 보면) 이게 전부다.

 

feature aggregation을 위해 layer(Graph Convolution layer)를 $n$개 적재하여 기준 node로 부터 $n$-hop 이웃 node 까지의 computation graph를 생성 $\rightarrow$ feature aggregation(neighborhood aggregation, message passing) 진행 $\rightarrow$ 기준 node의 feature update, 학습 진행

 

전체적인 이 과정을 수식으로 표현하면 다음과 같다:

 

GCN propagation rule (from CS224W)

 

이전 layer, 즉 현재 target node를 기준으로 연결된 모든 이웃 node의 representation/feature를 결합한 후, layer의 weight를 적용한 후 non-linear activation을 취하여 현재 layer(계산/갱신 될 현재 target node)의 representation/feature를 update한다.

 

정말 이게 전부다 🙂

 

다시 한번 정리해보면 다음과 같이 정리해볼 수 있을 것이다.

 

  1. 현재 target node를 기준으로 연결된 모든 이웃 node를 탐색
  2. 이웃 node들의 representation을 aggregate
  3. aggregate 된 이웃 node들의 representation과 target node의 이전 representation을 결합하여 target node의 representation을 update

 

GCN paper에서는 위 수식과 같은 vector representation이 아닌 matrix representation으로 설명하고 있는데, matrix를 분해하여 계산보면 결국 위 수식과 같은 summation(average)의 연산을 수행하는 것을 알 수 있다.

 

GCN paper에 기재된 propagation rule

 

이전 DeepWalk, node2vec과 같은 random walk 기반의 방법과는 다르게 별도의 전처리 과정 수고가 줄어들게 되었고, 기존의 CNN/RNN과 같은 신경망 처럼 layer를 적재하고 parameter tuing을 수행하여 graph에 대한 representation을 보다 간편하고 쉽게 구할 수 있게 되었다.

(GCN paper review 포스팅 을 보면 기존의 random walk 기반의 방법들에 비해 downstream task의 성능이 훨씬 더 잘 나오는 것을 확인할 수 있다.)

 

 

GraphSAGE

(GraphSAGE paper review 포스팅)

 

GCN은 획기적이면서도 빠르고 간편한 representation을 생성할 수 있는 방법을 제안했다.

 

하지만 마찬가지로 GCN에도 치명적인 문제가 있었는데, 한 target node를 기준으로 연결된 모든 이웃 node의 feature aggregation을 진행하는 것이다.

 

어찌 보면 feature aggregation을 위해서라면 당연히 필요한 과정이지만, 이는 large graph에서 scalable하지 못하기가 쉽다.

(계산 복잡도가 지수적으로 증가하는 문제)

 

그러하여 GraphSAGE는 GCN을 변형하여 large graph에서도 동작할 수 있는 방법을 제시하였고, (결과적으로 보면) GCN에 비해 좋은 성능을 달성할 수 있음을 보여주었다.

 

GraphSAGE의 연산 과정을 수식으로 보면 다음과 같다:

 

GraphSAGE propagation rule (from CS224W)

 

GCN과 차이점이 뭘까?

 

선택된(sampling 된) 이웃 node들의 representation을 aggregate 한 후에($\mathrm{AGG}$), target node의 representation과 CONCAT하여 non-linear activation을 거쳐 현재 layer(계산/갱신 될 현재 target node)의 representation/feature를 update한다.

 

$\mathrm{AGG}(\cdot)$으로는 GCN과 같은 MEAN을 사용할 수도 있고, paper에선 permutation invariant한 pooling 또는 LSTM을 사용할 수 있음을 소개하였다.

 

또 다른 핵심은 앞서 언급한 "선택된(sampling 된) 이웃 node들" 인데, 단순히 target node와 연결된 모든 이웃 node가 아닌 그 중 일부분만을 random하게 sample하여 aggregation에 활용한다는 것이다.(neighborhood sampling/node-wise sampling)

(sampling 관련 부분은 GraphSAGE paper review 포스팅에 좀 더 상세하게 기재되어 있다.)

 

 

Graph Attention Network (GAT)

GCN의 한계점을 해결하기 위해 GraphSAGE가 등장하였었다.

 

하지만 GCN, GraphSAGE에도 문제점이 있다.

 

바로 "각 node의 중요도"를 충분히 고려하지 않았다는 점이다.

 

앞서 기술한 수식들만 봐도 단순히 이웃 node들의 representation을 모아 평균(mean)을 취하고 있는데, GAT에서는 이 문제점을 지적한 것이다.

(실제 SNS만 보아도 각 사용자의 영향력, 즉 node마다 가지고 있는 고유한 중요도가 다를 수 있다.)

 

따라서 GAT에서는 (당시 화제가 되었던) attention mechanism을 이용해 각 node마다의 중요도를 다르게 계산하여, feature aggregation을 수행할 때 그 중요도를 각기 다르게 부여해 representation을 보다 더 정교하게 다듬을 수 있는 방법을 제안하였다.

(attention mechanism에 대한 설명은 정말 잘 정리되어 있는 문서가 있기에 그것으로 대체하고자 한다.)

 

수식을 보면 다음과 같다.

GAT propagation rule

 

단순히 각 node마다의 중요도를 다르게 두기 위한 attention weight가 추가되었는데, 이 attention weight를 보면 다음과 같다.

 

attention weight를 계산하는 과정 (from CS224W)

 

attention mechanism에서 계산하는 attention weight의 계산과 거의 동일하다.

 

말 그대로 node 간의 중요도를 계산하여 그 중요도에 따라 representation 계산에 차등을 주어야 한다 라는 접근이다.

 

기존의 attention 처럼 multi-head attention을 수행할 수 도 있는데, 이는 위의 attention 연산을 head의 수 만큼을 진행하여 각기 다른 attention weight를 계산하여 각 head에서의 reprsentation을 계산한 후, 이를 CONCAT하여 최종 representation을 계산하는 것이다.

 

GAT paper를 보면 실제로 기존 방법들에 비해 우수한 성능을 보여준 것을 알 수 있다.

 

하지만 attention 연산이 들어간 시점에서, (attention 연산을 수행하는 거의 모든 architecture와 마찬가지로) 연산에 있어 상당한 complexity가 있다는 점을 짐작할 수 있다.

(이는 추후 포스팅 할 paper에서도 자주 언급되는 문제이다.)

 

 

 

Conclusion

Graph representation learning을 알기에 앞서 기존의 representation learning이 무엇인지, 어떻게 발전되었는지 간략하게 짚어보았고, graph domain에서 representation learning이 어떻게 이루어지는지를 간략하게 정리하였다.

 

기존엔 random walk 기반의 representation learning이 주 를 이루었지만 그 한계점이 존재하였고, 이를 해결하기 위해 신경망(GNN)을 활용하는 방법이 제안되고 발전됨에 따라 graph domain에서의 representation learning은 (아마도 거의) GNN을 통해 이루어진다고도 볼 수 있다.

 

실제로 graph domain의 대표격이라 볼 수 있는 생물/화학 분야(molecule, PPI 등)에서 GNN/GT(Graph Transformer)와 같은 신경망 기반(deep-learning based)의 방법이 주를 이루고 있음을 알 수 있다.

 

다시 한번 일반적인 GNN의 연산 과정을 살펴보면 다음과 같이 정리할 수 있을 것이다:

image by google research

 

  • 인접 행렬(adjacency matrix)을 이용해 target node에서 (sampling 된 or 모든) 연결된 이웃 node를 탐색
  • 이웃 node의 representation을 집계 (neighborhood aggredation/message passing)
  • 현재 target node의 representation과 집계된 이웃 node의 representation을 결합
  • layer의 weight와 non-linear activation을 거쳐 target node의 새로운 representation을 계산

 

추가적으로, GNN을 공부하던 시절 도움이 되었던 자료를 함께 첨부하였다:

 

Stanford CS224W: Machine Learning with Graphs (개인적으로 강추하는 자료)

PyTorch Geometric(PyG) Tutorial (DGL보다는 조금 늦게 개발되었지만 기존 PyTorch와 거의 동일하게 간편한 PyG 라이브러리에서 제공하는 튜토리얼 영상/자료)

Graph Representation Learning textbook (GraphSAGE 저자가 발간한 pdf)

CS224W 정리본 (wandb에서 작성)

 

 

 


 

꼭 한번 다시 정리해야지 하는 내용을 간략하게나마(?) 정리하였다.

(이전 포스팅과 마찬가지로 어디까지나 개인적으로 아카이빙, 정리 용으로 작성한 글 이기에 다소 오류/문제점이 있을 수 있습니다. 언제든지 지적해주시면 감사하겠습니다.)

 

GNN을 처음 접하고 공부하던 시절엔 인터넷에 자료가 그렇기 많지 않았고, PyG 또한 개발이 막 시작되던 때라 자료를 검색하면 십중팔구 외국 자료가 많이 나왔었기에 당시에 많은 도움을 받았던 자료들을 함께 첨부했다.

 

포스팅을 작성하고 있는 지금 다시 검색해보면 한글로 된 여러 자료가 많이 나오는데, 그 몇 년 사이에 국내에서도 GNN에 조금이나마(?) 붐이 오긴 왔었던 듯 하다.

 

시간이 지나면서 조금 흥미로웠던 점은 학계에서는 PyG를 많이 선호하는 듯 하고, 산업에서는 DGL을 많이 선호하는 듯 하다는 것이다.

(아무래도 Jure 교수님이 거의 대놓고 PyG를 support하기도 하고, Amazon에서 DGL을 support하기도 하고...)

 

예전 GNN 관련 프로젝트를 진행하기에 앞서 프로젝트에 참가하는 분들에게 간략하게 소개하였던 자료도 함께 첨부하였다.

 

Graph_representation_learning_GNN_overview.pdf
0.93MB

 

 

E.O.D.