본문 바로가기

연구/pytorch

[Transformer] torch.nn.MultiheadAttention 모듈의 mask 인자 개념

반응형

이 글은 다소 오류가 있을 수 있으니 틀린 내용을 발견하셨다면 언제라도 피드백 환영합니다.

 

nn.MultiheadAttention 의 forward 인자는 다음과 같다.

forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True)

이 중에서 mask 가 2개 있다는 점이 코드 구현중에 헷갈리는 부분이었다.

 

mask는 각각 key_padding_mask, attn_mask 가 있고 이 둘의 차이는 무엇일까?

 

Transformer는 본래 이미지처리 분야가 아닌 자연어처리 분야에서 고안된 모듈이다. 하지만 자연어처리에서 입력은 이미지처럼 입력 사이즈가 고정되지 않고, 문장마다 길이가 제각각이다. 하지만 Transformer 모델에 문장들이 batch 단위로 입력이 될텐데, batch안에 있는 문장들의 길이(Sequence Length)가 다르면 모델에 입력조차 할 수 없는 문제가 발생한다.

따라서 입력 문장의 길이를 맞춰주기 위해 짧은 문장의 뒤에는 padding을 넣어서 batch 안에 있는 입력들의 Sequence Length 를 통일시킨다.

예를 들면 아래와 같다.

나는 학교에 간다.
내가 학교에 가면 공부도 하고 운동도 할 거야.

여기서 윗 문장과 아랫문장은 sequence length가 각각 3, 8로 차이가 나니까 다음과 같이 문장의 길이를 맞춰준다.

나는 학교에 간다. <0> <0> <0> <0> <0>
내가 학교에 가면 공부도 하고 운동도 할 거야.

 

이렇게 하면 문장들을 transformer에 입력할 수 있는 상태가 되지만, 또 문제가 발생한다. 뒤에 패딩토큰 <0> 들이 실제로는 아무 의미가 없는 토큰이라는 점이다. 따라서 Transformer 가 attention 을 계산하는 과정에서 패팅토큰들이 실제 의미있는 단어들에게 관여하지 못하게 하기 위해 masking 작업이 필요하다.

그게 바로 key_padding_mask 이다.

하지만 나같은 이미지처리 혹은 컴퓨터 비전 분야에서 transformer 를 사용하는 사람이라면 아마 이 개념을 쓸 일이 별로 없지 않을까 싶다. 왜냐하면 모델에 입력되는 이미지들은 사이즈가 보통 고정되어 있기 때문이다.(224x224, 256x256, 384x384, ... )

 

그렇다면 attn_mask 는 무엇일까?

패딩토큰들 말고, 실제 의미가 있는 토큰들 간의 attention을 계산할때 어떤 query의 경우 특정한 key들로부터 영향을 계산하고 싶지 않은 경우가 있다. 이럴때 query 가 key로부터 영향을 받지 않도록 masking하는 역할을 하는것이 attn_mask 이다.

예를 들어 FastMETRO 논문에서는 Vertex-Vertex 간 attention을 계산할때 거리가 1인 vertex 로부터만 attention 계산을 의도하고, 거리가 먼 vertex 로부터는 attention 을 계산하지 않도록 설계를 했다.

 

참고문헌:

1. FastMETRO: Cross-Attention of disentangled Modalities for 3D Human Mesh Recovery with Transformers (https://arxiv.org/abs/2207.13820)

2. https://stackoverflow.com/questions/62629644/what-the-difference-between-att-mask-and-key-padding-mask-in-multiheadattnetion

반응형