트러블슈팅
데이터 입력 사이즈에 따라 positional encoding 차원의 값 변형 문제
J-Chris
2024. 11. 22. 11:55
#1 문제 상황
데이터 입력사이즈(W-Width, H-height, Channel)가 달라짐에 따라 positional encoding 차원의 값들도 달라질 것 입니다.
그러면 어떻게 맞춰주는게 좋을지 고민하게 되었습니다.
특히, ViT 입력으로 들어가기 위해서는 입력 이미지에 대해서 ViT의 patch_size 에 맞춰주는 작업이 필요합니다.
해당 부분을 아래 방법으로 해결하면 될 것 같다고 생각됩니다.
#1 해결 방법
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens, size=(H, W), mode='bicubic', align_corners=False)
new_pos_embed = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
출처 : Spatial.py 코드 일부
if 'pos_embed' in state_dict:
pos_embed_checkpoint = state_dict['pos_embed']
embedding_size = pos_embed_checkpoint.shape[-1]
H, W = self.patch_embed.patch_shape
num_patches = self.patch_embed.num_patches
num_extra_tokens = 0
# height (== width) for the checkpoint position embedding
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
# height (== width) for the new position embedding
new_size = int(num_patches ** 0.5)
# class_token and dist_token are kept unchanged
if orig_size != new_size:
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens[:, 1:] # [1, 196, 768] # 197번째 요소를 제거해서 196으로 맞추기
print('pos_tokens shape : ', pos_tokens.shape) # [1, 197, 768] df
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens, size=(H, W), mode='bicubic', align_corners=False)
new_pos_embed = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
# new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
state_dict['pos_embed'] = new_pos_embed
else:
state_dict['pos_embed'] = pos_embed_checkpoint[:, num_extra_tokens:]