그래프 튜토리얼 6화 : Graph Attention Network (GAT)

그래프 튜토리얼 6화 : Graph Attention Network (GAT)
  1. Graph Attention Network 구현
  2. Amazon 상품 카테고리 예측하기

안녕하세요, GUG 여러분! 그래프 튜토리얼의 가이드 Erica에요🥁

6번째 여행에서는 GAT 모델을 다룰 거에요. GAT는 Transformer 모델처럼 Attention Mechanism을 활용하고, 그만큼 뛰어난 유용성과 범용성을 지니고 있답니다.

📢 가독성을 위해 본문에는 코드 전문을 임포트하지 않았어요.
튜토리얼 코드 전체를 한번에 실행해 볼 수 있는 .ipynb 파일은 아래 링크에 있습니다.

GitHub - H4Y3J1N/Graph-Travel: with GUG, Let’s explore the Graph Neural Network!
with GUG, Let’s explore the Graph Neural Network! Contribute to H4Y3J1N/Graph-Travel development by creating an account on GitHub.

Graph Travel 튜토리얼이 도움이 되고 있다면,
꾸준한 업로드의 원동력이 되도록 STAR🌟를 눌러주세요~!

PyTorch Version: 2.0.0+cpu
Torch Geometric Version: 2.3.1

Attention?

오늘 이 튜토리얼을 통해서는 [Attention Mechanism, Self Attention, Multi Head Attention, Attention Score]가 각각 무엇인지, 그리고 어떤 연관으로 GAT가 작동하는지 큰 그림을 파악하는 통찰력을 키우게 될 거에요.

Main Concept of GAT

가장 먼저, GAT 모델에 대해 이야기할 때 빠질 수 없는 주요 개념 단어 4가지에 대해 간략하게 알아보겠습니다. 이 개념을 먼저 파악하는 게 중요한 이유는, Attention이라는 단어가 빈번하게 사용되기 때문에, 개념 단어가 제대로 확립되어 있지 않으면 모델 작동의 원리를 이해할 때 혼란스럽기 때문이에요.

  1. Attention Mechanism : 각 노드가 이웃 노드와 얼마나 연관이 있는지, 여러 노드 중 어떤 노드가 중요한지를 판단하는 방법론입니다.
  2. Self Attention : 노드 특성 정보를 업데이트할 때, 본인 노드의 정보도 고려하는 방법입니다.
  3. Multi Head Attention : Attention Mechanism을 여러 개 병렬로 사용하여, 다양한 측면에서 정보수집을 하는 방법입니다.
  4. Attention Score : 노드와 그 이웃 노드들 사이의 상대적 중요성을 측정하는 값으로, 이를 통해 정보를 어떻게 통합할지 결정하는 값입니다.

지금까지의 튜토리얼 여행을 통해서, Graph Neural Network의 다양한 알고리즘을 관통하는 주요 포인트가 [노드 정보를 취합하는 방법론]이라는 걸 눈치채셨을 거에요.
이번 GAT 모델을 공부할 때도 그 점을 꼭 염두에 두고, Attention이란 개념이 어떻게 정보를 취합해내는지를 확인해 보세요!

Overall learning order

Graph Attention Layer가 뭔지, 수식과 함께 진부하게 설명하는 방법도 있겠지만... Erica가 가장 중요하게 생각하는 튜토리얼 공부의 목적은 모델 작동의 원리와 프로세스를 완벽히 이해시키고, 그를 통해 베이스라인 고도화를 혼자서도 가능하게 만드는 것입니다.
그러기 위해 세부 항목을 깊게 파고들기 전, 모델 작동의 프로세스와 각 단계의 역할을 전체적으로 조망하도록 할게요.
세부적으로 나눠보았을 때, GAT의 전체 학습 순서는 아래와 같은 11단계로 나누어집니다.

  1. Input Preparation
  2. Linear Transformation
  3. Concatenation
  4. Weight Application
  5. Activation Funciton
  6. Softmax Normalization
  7. Aggregation
  8. Multi Head Attention
  9. Layer stacking
  10. Loss Calculation & Back Propagation
  11. Prediction & Evalution

각 단계가 전체 학습에서 어떤 역할을 하는지, 지금부터 하나씩 차근차근 알아봅시다.

1. Input Preparation

모델 학습을 위해, Graph Structure과 Node Feature가 입력으로 주어집니다.

2. Linear Transformation

Node Feature를 선형 변환합니다. 선형 변환이란 Node Feature에 가중치 W를 곱하는 것을 말해요. 따라서 선형 변환을 했을 때, 최종적으로는 가중치 Wmatrix와 node Feature Matrix인 X가 곱해진 형태가 됩니다. WxX 처럼 말이죠. 가중치 행렬 W는 X를 회전시키고 스케일링해서 새로운 공간으로 매핑하는 역할을 합니다. 그로 인해 원래의 특성보다 더 유용한 정보를 추출할 수 있게 됩니다.

3. Concatenation

선형 변환된 두 개의 노드, i와 j의 Feature Vector를 concatnate합니다. 두 노드의 feature 정보를 동시에 고려함으로써, 두 노드 사이의 관계를 더 잘 고려할 수 있게 되죠.

4. Weight Application

Concatenate된 vector에, 어텐션 스코어 계산을 위한 가중치 벡터
a를 곱합니다. 이 벡터 a는 모델 학습 과정에서 최적화되며, 각 노드 쌍에 대해 어떤 feature가 더 중요한지 학습합니다. 이 작업은 노드 i, 노드 j의 유사도를 계산하는 일종의 방법이라고 볼 수 있습니다.

  • 유사도 계산이라구요? 그게 무슨 뜻인가요?
    위와 같은 질문이 생길 수도 있겠네요. 이 말을 이해하려면, Attention Mechanism의 목적과 작동 방식을 먼저 잘 생각해봐야 합니다. 어텐션 매커니즘은 여러 노드 중, 어떤 노드가 중요한지를 판단합니다. 그리고 노드의 중요도를 Attention Score로 나타내죠. Attention Score는 노드 i와 노드 j의 Feature Vector가 얼마나 유사한지, 또는 얼마나 관련성이 큰지, 다시 말해 "얼마나 중요한지"를 측정하는 지표로 볼 수 있습니다.

5. Activation Funciton

Weight Application을 통해 얻은 스칼라 값에, Activation Function을 적용합니다. ReLU나 Leacky Relu, eLU 등을 사용할 수 있죠. 비선형성을 도입함으로써 모델은 더 복잡한 패턴을 학습할 수 있게 됩니다. Attention Score의 계산은 이런 방식으로 진행됩니다.

6. Softmax Normalization

이렇게 얻어낸 Attention Score를 비교하기 위해서는 한 단계의 작업이 더 필요합니다. 바로 비교 가능하도록 정규화를 하는 것이죠. 그래서 우리는 Attention Score에 Softmax Normalization을 진행합니다.

Softmax는 값을 0과 1사이의 실수로 제시하고, 합치면 1이 된다는 특징을 가지고 있습니다. 이 값은 확률로도 이해할 수 있죠. 정규화를 거친 다음에는 비교를 통해 각 노드가 다른 노드에 대해 얼마나 중요도가 높은지 명확하게 파악할 수 있습니다.

7. Aggregation

softmax normalization을 거친 Attention Score를 Attention Weight라고 부르기도 합니다. 이 값은 각 이웃 노드의 정보가 현재 노드에 얼마나 영향을 주는지의 정도로 해석할 수도 있어요. 왜냐하면 Attention Weight를 이용해 이웃 노드들의 정보를 현재 노드에 통합(Aggregation)하기 때문입니다.

즉, 이웃 노드의 특성 정보를 Attention Weight로 가중평균(Weighted Sum)해서 target 노드의 Feature Vector를 계산, 업데이트합니다. 바로 이때 Self-Attention이라는 개념도 함께 사용됩니다. 업데이트 시, 자기 자신의 정보도 반영한다는 것이죠.

8. Multi Head Attention

지금까지 강의한 일련의 과정 하나를 Attention Mechanism이라고 합니다. Multi Head Attention이란 개념도 이젠 훨씬 이해하기 쉬워졌네요. 멀티 헤드 어텐션이란 병렬로 여러 개의 Attention Mechanism(=Head)을 사용하는 것입니다. 그리고 그 결과들을 Concatenation이나 Average를 통해 하나로 합칩니다.
더 쉽게 풀어 이야기하자면, 기본적인 어텐션 매커니즘은 한 번에 하나의 표현(representation)만 학습하지만, Multi-Head Attention은 여러 개의 독립적인 어텐션 "헤드(heads)"를 사용하여 동시에 여러 종류의 표현을 학습합니다. 아래와 같은 3가지의 특징을 가기죠.

  1. 병렬 : 노드 i, 노드 j 한 쌍, concat된 하나의 입력에 대해 여러 개의 Attention Head를 동시 적용합니다.
  2. 가중치 분리 : 각 Head는 서로 독립적인 가중치 W, 가중치 벡터 a를 사용합니다. 따라서 각 Head는 서로 다른 특성을 추출합니다.
  3. 어텐션 연산 : 각 헤드는 독립적으로 Attention Score를 계산합니다. 이렇게 하면 모델은 동시에 여러 종류의 관계나 패턴을 포착할 수 있죠.
  4. 결과 통합 : 그렇게 얻어낸 다양한 측면의 정보를 하나로 통합합니다. Concatenate나 Average를 통해서요.

9. Layer stacking

이런 Attention 계정을 여러 개 쌓는 것이 가능합니다. 하나의 Attention 계층의 output이 다음 계층의 input이 되는 형식이죠. 그리고 모델의 최종 출력은 분류, 회귀, 노드 임베딩 등 다양한 형태를 취할 수 있습니다.

자, 이렇게 GAT의 주요 개념과 전체 학습 프로세스를 이해하는 시간을 가졌습니다. 더 이상 헷갈리는 일은 없겠네요. 10과 11은 너무 기본적이고 쉬운 내용이니 굳이 설명하지 않고 넘어가겠습니다.

Code

그럼 이제 본격적으로 코드를 살펴보도록 합시다.

오늘 우리가 사용할 데이터셋은 Amazon Product Co-purchasing Network입니다. 이 오픈 데이터셋의 node는 상품이고, edge는 함께 구매된 적 있는지의 관계 여부를 나타내요. label은 제품 카테고리입니다.

이 label을 활용해 Node Classification을 하는 예측 task를 수행할 수 있어요.

 # 데이터를 training과 test로 분리

num_nodes = data.num_nodes
train_mask = torch.zeros(num_nodes, dtype=torch.bool)
train_mask[:int(0.8 * num_nodes)] = 1  # 80%의 노드를 학습에 사용
test_mask = ~train_mask

가장 먼저, 임포트한 데이터셋을 split하고 shuffle하겠습니다.

데이터 셔플이 항상 필요한 것은 아니지만, 일반적으로 셔플링은 모델이 데이터의 순서에 의존하지 않도록 도와줍니다. 그 결과로 학습과 검증 과정에서 더 일반적인 성능을 보이게 할 수 있어요.

PyTorch Geometric (PyG)에는 데이터를 직접 셔플하는 내장 메서드가 없으므로 위와 같은 간단한 방식으로 그래프 노드를 분할하겠습니다.

 # GATv2 모델 정의

class GATv2Net(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GATv2Net, self).__init__()
        self.conv1 = GATv2Conv(in_channels, 128, heads=4)
        self.bn1 = torch.nn.BatchNorm1d(128 * 4)
        self.conv2 = GATv2Conv(128 * 4, out_channels, heads=1, concat=False)
        self.bn2 = torch.nn.BatchNorm1d(out_channels)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = self.bn1(x)
        x = F.elu(x)
        x = F.dropout(x, p=0.6, training=self.training)

        x = self.conv2(x, edge_index)
        x = self.bn2(x)
        return F.log_softmax(x, dim=1)

모델 Class를 정의할 때는 GATv2Conv를 사용했습니다. GATConv를 사용해도 문제는 없지만, 이번 튜토리얼에서는 더 간단하고 효율적인 매커니즘을 도입한 GATv2Conv 레이어로 학습을 진행하고 성능을 보도록 할게요.

Activation Function으로는 eLU를 선택했고, dropout, Learning Rate Scheduling, Early Stopping, Batch Normalization을 적용했습니다.

# 학습 루프

model.train()
for epoch in range(100):
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[train_mask], data.y[train_mask])
    loss.backward()
    optimizer.step()
    scheduler.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item()}")

    # Early Stopping
    if best_loss is None:
        best_loss = loss.item()
    elif best_loss > loss.item():
        best_loss = loss.item()
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve == patience:
            print("Early stopping!")
            break

Epoch 100, Loss: 0.38430055975914

optimizer는 AdamW를 사용했습니다. 간단한 모델을 설정한 다음 100 epoch를 돌려 보겠습니다. 간단한 모델에서도 얼마나 괜찮은 성능을 낼 지가 궁금하네요.

# 모델을 평가 모드로 설정
model.eval()
_, pred = model(data).max(dim=1)

# 테스트 데이터에서의 예측과 실제 라벨
pred_test = pred[test_mask].cpu().numpy()
y_test = data.y[test_mask].cpu().numpy()

# 정확도 계산
correct = pred[test_mask].eq(data.y[test_mask]).sum().item()
accuracy = correct / test_mask.sum().item()

# 정밀도, 재현율, F1 점수 계산
precision = precision_score(y_test, pred_test, average='macro')
recall = recall_score(y_test, pred_test, average='macro')
f1 = f1_score(y_test, pred_test, average='macro')

print(f"Test accuracy: {accuracy:.4f}")
print(f"Test precision: {precision:.4f}")
print(f"Test recall: {recall:.4f}")
print(f"Test F1 score: {f1:.4f}")

결과는 아래와 같습니다.

Test accuracy: 0.9055
Test precision: 0.8806
Test recall: 0.9137
Test F1 score: 0.8960

별도의 하이퍼 파라미터 튜닝이 없었고, head가 4개에 불과한데도 상당히 높은 성능 지표를 보여주네요.

새삼스럽게 Attention 매커니즘이 얼마나 강력한지, GAT의 성능을 실감할 수 있었습니다. 우리는 이번 튜토리얼 코드를 통해 Amazon 상품의 카테고리를 훌륭하게 분류해냈어요!

오늘 활용한 Amazon Product Co-purchasing Network 데이터셋의 크기는 아래와 같아요.

  • Data(x=[13752, 767], edge_index=[2, 491722], y=[13752])

그렇게 크지 않은 데이터셋임을 알 수 있죠. 하지만 현실 세계의, 그리고 현업 데이터는 실시간으로 적재되기에 굉장히 큰 Graph를 형성합니다. 이번 코드까지는 CPU에서 작업해도 무방한 수준이었지만, 만약 거대 그래프를 학습해야 하는 상황이 온다면 어떨까요?

🎫 GraphSAGE?

다음 튜토리얼에서는 대형 Graph에서도 상당한 효율성을 가진 모델.

미니 배치(mini-batch) 방식의 학습, Inductive Learning이 가능해 학습 시점에서 보지 못한 새로운 노드나 새로운 그래프에 대해서도 임베딩을 생성해내는 GraphSAGE에 대해 배워보기로 합시다!

Subscribe for daily recipes. No spam, just food.