그래프 튜토리얼 7화 : GraphSAGE
- GraphSAGE의 neighbor sampling
- Inductive Learning
- Reddit 데이터로 node classification task 수행
안녕하세요, GUG 여러분! 그래프 튜토리얼의 가이드 Erica에요🍏
7번째 여행에서는 GraphSAGE 모델을 다룰 거에요. GraphSAGE는 neighbor sampling 전략을 통해, 학습 중에 본 적 없는 node에 대해서도 추론이 가능한 Inductive Learning 에 특화되어 있답니다.
📢 가독성을 위해 본문에는 코드 전문을 임포트하지 않았어요.
튜토리얼 코드 전체를 한번에 실행해 볼 수 있는 .ipynb 파일은 아래 링크에 있습니다.
Graph Travel 튜토리얼이 도움이 되고 있다면,
꾸준한 업로드의 원동력이 되도록 STAR🌟를 눌러주세요~!
PyTorch Version: 2.0.0+cpu
Torch Geometric Version: 2.3.1
Key Question
거대한 사이즈의 Graph Data를 미니 배치로 학습하려면, 노드를 어떻게 선택해야 할까요?
위 질문은 GraphSAGE 모델의 핵심 아이디어이자, 키 포인트이기도 합니다. 우리는 지금부터 모델의 학습 원리를 통해 위 질문에 대한 답을 하나씩 찾아 나갈 거예요.
다른 GNN 모델들과 마찬가지지만, GraphSAGE의 주요 아이디어 또한 이웃 노드의 정보를 집계해 target 노드의 임베딩을 업데이트하는 것입니다.
다만, GraphSAGE는 주변 노드의 정보를 집계하여 현재 노드의 임베딩을 업데이트하는 방식을 사용해요. 이를 통해 그래프의 노드 표현을 학습하는 inductive learning이 가능해집니다.
Neighbor Sampling
Graph-based neural networks에서는 일반적으로 모든 이웃 노드에 대한 정보를 집계합니다. 하지만 이 방식에는 두 가지 문제가 있어요.
- 매 layer를 거치는 hop마다 계산량이 기하급수적으로 증가한다.
2. node centality가 높은 hub node가 있을 시, 계산량이 기하급수적으로 증가한다.
따라서 computing cost가 너무나 많이 소비되고, 투입되는 자원에 비해 성능 효율은 떨어지는 상황이 발생하죠. 이런 문제는 거대한 Graph Data일수록 심각해집니다. 이를 해결하기 위해 고안된 모델이 바로 GraphSAGE인 것이죠.
GraphSAGE는 이 문제를 해결하기 위해, 각 layer마다 지정된 개수의 이웃만 샘플링합니다. 이 샘플링은 여러 layer(=hop)에 걸쳐 진행되고, 각 layer에서는 이전 layer의 이웃을 기반으로 샘플링됩니다. 이 방식의 장점은 scalability와 normalization입니다.
각 layer에서 샘플링되는 이웃의 수가 고정되어 있기 때문에, layer 깊이와 관계없이 연산 복잡도가 일정하게 유지됩니다. 따라서 그만큼의 확장성을 갖추게 되죠. 대규모 그래프에도 학습이 효율적으로 이루어집니다.
또한 무작위로 이웃을 샘플링하는 것은 학습 중에 약간의 노이즈를 도입하는 효과가 있습니다. 이는 모델이 과적합을 방지하는 데 도움이 되기도 합니다.
다만, 무작위 샘플링을 진행하기 때문에 학습 과정이 더 stochastic(확률적)합니다. 따라서 샘플링을 통해 도출한 정보를 잘 Aggregation하는 것이 중요하죠.
Aggregation?
각 노드에서 이웃의 정보를 집계하는 방식에는 3가지가 있습니다. 바로 mean, LSTM, pooling 방법이죠. 이렇게 집계된 정보와 target 노드의 정보를 결합해 새로운 노드 임베딩을 생성하는 과정이 필요합니다.
- Mean Aggregator: 이웃의 특징 벡터의 평균을 계산합니다. target 노드의 새로운 임베딩은 이웃의 임베딩 평균에 기반한 벡터와 현재 노드의 임베딩을 결합하여 얻습니다. 계산적으로 가장 간단합니다.
- LSTM Aggregator: 이웃들의 특징 벡터를 순서에 따라 LSTM에 통과시킵니다. LSTM은 시퀀스 데이터를 처리하는 데 특화된 신경망이므로, 이웃의 정보를 시퀀스로 처리합니다. 그래프의 구조가 순서를 갖는 경우 LSTM 집계가 특히 효과적이에요.
- Pool Aggregator: 각 이웃의 feature vector에 독립적인 신경망을 적용하고, 그 결과를 집계(예: max pooling)해 target 노드의 새로운 표현을 얻습니다.
Inductive Learning?
그렇다면 위에서 배운 컨셉과 핵심들이, 구체적으로 Inductive Learning과는 어떤 연관이 있는 걸까요? 이것에 대해 알아보기 전에, 우리는 먼저 Inductive Learning의 본질과 핵심 아이디어를 먼저 파악할 필요가 있겠네요.
Inductive Learning의 핵심 아이디어는 train data에 포함되지 않은 새로운 샘플에 대해 일반화 능력을 갖춘 모델을 학습하는 것입니다.
반면, Transductive Learning은 학습 데이터에 포함된 샘플들 사이의 관계를 학습하는 데 중점을 둡니다.
GNN에서의 Inductive Learning은 그래프가 동적으로 변화하고, 새로운 노드나 엣지가 추가될 수 있다는 점에서 가능한 접근입니다. GraphSAGE의 핵심 요소와 Inductive Learning 사이의 연관점은 다음과 같아요.
- Neighbor Sampling: GraphSAGE의 neighbor sampling 전략은 각 노드의 지역적인 이웃 정보만을 기반으로 노드의 임베딩을 업데이트합니다. 이로 인해, 학습 도중에 보지 못한 새로운 노드가 나타나더라도 해당 노드의 지역적인 이웃 정보를 사용해 즉시 임베딩을 생성할 수 있어요.
- Aggregator Functions: 집계 함수(Aggregator)는 여러 이웃의 정보를 통합하는 방식을 정의합니다. Inductive Learning에서는 이 집계 함수가 새로운 노드와 그 노드의 이웃들 사이의 관계를 모델링하는 데 중요한 역할을 합니다. 즉, 집계 함수는 새로운 노드가 그래프에 추가될 때 그 노드의 임베딩을 도출하는 방법을 제공합니다.
- Fixed-size Representations: GraphSAGE는 고정된 크기의 임베딩 벡터를 생성합니다. 따라서, 새로운 노드나 이웃이 추가되더라도 임베딩의 크기는 일정하게 유지됩니다. 이는 Inductive Learning에서 중요한 특성으로, 새로운 노드에 대해 일관된 형식의 표현을 제공하며 이를 다른 머신러닝 작업에 쉽게 통합할 수 있답니다.
이렇듯 GraphSAGE의 핵심 개념은 그래프 내의 새로운 node나 Subgraph에 즉시 적응하고, 임베딩을 생성하는 데 중점을 둡니다. 결과적으로 GraphSAGE는 학습 시점에서 보지 못한 새로운 노드나 구조에 대해 임베딩을 생성하는 Inductive Learning에 탁월한 성능을 보이는 거에요.
Inductive Learning with Reddit
Reddit 데이터셋은 게시물들을 node, 유사한 커뮤니티 내의 게시물 간의 관계를 edge로 나타냅니다. 각 게시물은 해당 게시물의 내용을 나타내는 단어 임베딩을 특징으로 가지며, 해당 게시물이 속한 커뮤니티를 나타내는 label을 가집니다.
우리는 node classification를 통해, 각 게시물의 내용과 target 게시물의 neighbor 게시물들의 정보를 기반으로 해당 게시물이 어떤 커뮤니티에 속하는지 예측할 수 있습니다. 그 과정에서 본 적 없는 node에 대한 Inductive Learning도 가능하죠.
이번 여행의 튜토리얼 코드에서는 Reddit 데이터셋을 사용하여 GraphSAGE 모델로 인덕티브 학습을 구현, 노드 분류를 진행하겠습니다. 해당 게시물의 커뮤니티 label을 예측해내는 task를 수행해 보도록 해요!
학습 코드를 작성하기 전에, Reddit Dataset이 과연 Inductive Learning에 적합한지의 여부를 확인해 보겠습니다. 위 코드는 train과 test 데이터 사이에 본 적 없는 노드가 있는지를 체크해요.
우리가 원하는 출력 결과는, train과 test 사이의 노드 중복도가 낮은 것입니다.
그리고 코드의 실행 결과는... 55703개의 본 적 없는 노드가 test set에 존재한다고 하네요! 충분한 개수의 unseen node가 있다는 것을 알았으니, 이제 다음으로 넘어가 코드를 구현해 봅시다.
dataset = Reddit(root='./data/Reddit')
data = dataset[0]
class GraphSAGE(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels):
super(GraphSAGE, self).__init__()
self.conv1 = SAGEConv(in_channels, hidden_channels, aggr="mean")
self.conv2 = SAGEConv(hidden_channels, out_channels, aggr="mean")
self.bn1 = torch.nn.BatchNorm1d(hidden_channels)
self.bn2 = torch.nn.BatchNorm1d(out_channels)
def forward(self, x, adjs):
# 첫 번째 convolution 층
x = self.conv1(x, adjs[0].t())
x = self.bn1(x)
x = F.relu(x)
x = F.dropout(x, p=0.5, training=self.training)
# 두 번째 convolution 층
x = self.conv2(x, adjs[-1].t())
x = self.bn2(x)
x = F.relu(x)
return x
우선, 데이터셋을 로드한 뒤 GraphSAGE 클래스를 정의합니다. 간단한 튜토리얼 코드이므로 2개의 convolution layer를 사용할 거에요.
self.conv1
를 사용해 입력 데이터 x
와 주어진 인접 리스트 adjs[0]
에 대한 convolution을 수행하는 걸 코드 내에서 확인할 수 있죠.
Batch Normalization과 dropout을 적용하고, ReLU로 비선형성을 추가합니다.
def train():
model.train()
total_loss = 0
for batch_size, n_id, adjs in train_loader:
adjs = [adj.to(device) for adj in adjs]
optimizer.zero_grad()
out = model(data.x[n_id].to(device), adjs)
loss = F.cross_entropy(out, data.y[n_id[:batch_size]].to(device))
loss.backward()
optimizer.step()
total_loss += loss.item()
return total_loss / len(train_loader)
학습 코드 부분에서는 예측값과 실제 라벨로부터 Cross Entropy Loss를 계산하는 부분을 주목해볼 수 있겠네요. softmax 대신 차용한 방법입니다.
train_loader
도 중요한 부분입니다. 우리가 이 다음 코드에서 구현할 loader를 통해, 노드와 노드의 이웃들을 효율적으로 neighbor sampling 해오게 되는 실질적인 코드라인이거든요.
train_loader
에는 batch size, node ID, adjs가 들어 있습니다. adjs가 필요한 이유는 정보 집계와 Aggregation 시, 노드의 이웃들을 파악하기 위해서에요.
train_loader = NeighborSampler(data.edge_index, sizes=[10, 10], batch_size=1024, shuffle=True, num_nodes=data.num_nodes)
test_loader = NeighborSampler(data.edge_index, sizes=[10, 10], batch_size=1024, shuffle=False, num_nodes=data.num_nodes)
가장 먼저, NeighborSampler
는 앞서 말했듯 그래프 학습에서 사용되는 데이터 로더입니다. 이것은 노드와 그 이웃들을 효율적으로 neighbor sampling하는 역할을 합니다.
sizes=[10, 10]
는 각 단계별로 몇 개의 이웃을 샘플링할지를 나타내요. 이 코드에서는 2 hop 이웃 샘플링을 수행하며, 각 layer 층마다 10개의 이웃을 샘플링한답니다.
이렇게 어느덧 대형 그래프를 다룰 수 있는 GraphSAGE 모델까지 다루게 되었네요. 정말 놀랍고 재미있지 않나요? Inductive learning이라면, Recommend System의 주요 문제 중 하나인 Cold-start Problem을 완화하는 데에도 도움이 된답니다.
여러분이 이쯤에서 GNN의 핵심이 '노드 정보 집계를 더 효율적으로 하는 방법론 찾기'라는 사실을 눈치채셨길 바라요.
🎫 Autoencoder?
정보 집계를 효율적으로 하는 거라면, ML 분야에서도 강력한 친구가 있죠. 맞아요. 바로 Autoencoder랍니다! 이 이름을 모르는 분은 없을 거라고 생각해요.
이 알고리즘을 Graph에도 적용하려는 시도가 당연히 존재했답니다.
다음 시간에는, 그래프의 복잡한 구조를 효과적인 latent representation vector (잠재 표현 벡터)로 압축해내는 GAE 모델에 대해 배워보도록 해요!