트러블슈팅

데이터 입력 사이즈에 따라 positional encoding 차원의 값 변형 문제

J-Chris 2024. 11. 22. 11:55

#1 문제 상황

데이터 입력사이즈(W-Width, H-height, Channel)가 달라짐에 따라 positional encoding 차원의 값들도 달라질 것 입니다. 

그러면 어떻게 맞춰주는게 좋을지 고민하게 되었습니다. 

특히, ViT 입력으로 들어가기 위해서는 입력 이미지에 대해서 ViT의 patch_size 에 맞춰주는 작업이 필요합니다.

 

해당 부분을 아래 방법으로 해결하면 될 것 같다고 생각됩니다.

 

#1 해결 방법

 

pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
                    pos_tokens = torch.nn.functional.interpolate(
                        pos_tokens, size=(H, W), mode='bicubic', align_corners=False)
                    new_pos_embed = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)

 

 

출처 : Spatial.py 코드 일부

 

if 'pos_embed' in state_dict:
                pos_embed_checkpoint = state_dict['pos_embed']
                embedding_size = pos_embed_checkpoint.shape[-1]
                H, W = self.patch_embed.patch_shape
                num_patches = self.patch_embed.num_patches
                num_extra_tokens = 0
                # height (== width) for the checkpoint position embedding
                orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
                # height (== width) for the new position embedding
                new_size = int(num_patches ** 0.5)
                # class_token and dist_token are kept unchanged
                if orig_size != new_size:
                    pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
                    pos_tokens = pos_tokens[:, 1:]  # [1, 196, 768] # 197번째 요소를 제거해서 196으로 맞추기
                    print('pos_tokens shape : ', pos_tokens.shape) # [1, 197, 768] df
                    pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
                    pos_tokens = torch.nn.functional.interpolate(
                        pos_tokens, size=(H, W), mode='bicubic', align_corners=False)
                    new_pos_embed = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
                    # new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
                    state_dict['pos_embed'] = new_pos_embed
                else:
                    state_dict['pos_embed'] = pos_embed_checkpoint[:, num_extra_tokens:]