[문제]

 

[내 코드]

N, K = map(int, input().split())
words = [input() for _ in range(N)]
basics = ['t', 'i', 'c', 'a', 'n']

if K < len(basics):
    print(0)
elif K == 26:
    print(len(words))
else:
    result = 0
    unlearned = list(set([w for word in words for w in word if w not in basics]))
    
    # unlearned 중에서 r개를 새로 배운다
    n = len(unlearned)
    r = K - len(basics)
    
    def dfs(idx, ls):
        global result
        if len(ls) == r:
            tmp = 0
            learned = basics + ls
            for word in words:
                if set(learned) > set(word):
                    tmp += 1
            result = max(result, tmp)
            return
        else:
            for i in range(idx, n):
                dfs(i+1, ls+[unlearned[i]])
    
    dfs(0, [])
    print(result)

- 최소로 알아야 되는 다섯 글자(basics)도 모르면 0을 출력하고, 모든 글자를 배우면 있는 단어의 갯수를 모두 출력한다

- 기본 글자를 제외한 나머지 글자들(unlearned)의 조합마다 최대로 배울 수 있는 단어의 개수를 출력하도록 했다

- 깊이우선탐색 알고리즘을 썼다

- 코드를 조금씩 바꿔 보았으나 계속 시간 초과가 떠서 정답을 보기로 했다.. 시간 제한 없었다면 맞았을까 궁금해지지만 ㅠㅠ

무수한 시간 초과의 흔적

 

[개선안]

이 문제는 비트 마스크를 활용해서 푸는 것이 이상적이다

 

글자의 조합을 2진수 비트로 나타낸 후, 정수로 변환할 수 있다

 

집합의 비교 또한 비트 연산이 가능하다

 

위 예시에서는 [a]가 [a, b]의 부분집합이므로 and 연산 결과가 1이다

 

한편, 참고한 정답은 dfs 대신에 itertools 라이브러리로 조합을 구했다

 

import itertools

# n, m 입력
n,m = map(int, sys.stdin.readline().split())

# words : 각 단어의 비트마스킹한 정수를 저장
words = [0] * n
ans = 0
for i in range(n):
    temp = input()
    # word 배열에 각 문자의 비트마스킹 저장
    for x in temp:
        words[i] |= (1 << (ord(x) - ord('a')))
        
# 만일 m이 5미만이면 필수 글자를 다 배울 수 없기 때문에 한 단어도 읽지 못한다
if m < 5:
    print(0)
else:
    # candidiate : 필수 글자를 제외한 알파벳
    # need : 필수 알파벳
    candidiate = ['b','d','e','f','g','h','j','k','l','m','o','p','q','r','s','u','v','w','x','y','z']
    need = ['a','c','t','i','n']
    for i in list(itertools.combinations(candidiate, m - 5)):
        each = 0
        res = 0
        # 각 조합에 대한 비트마스킹
        for j in need:
            each |= (1 << (ord(j) - ord('a')))
        for j in i:
            each |= (1 << (ord(j) - ord('a')))
            
        # 단어와 각 조합의 비교
        for j in words:
            if each & j == j:
                res += 1
                
        # 최대값 갱신
        if ans < res:
            ans = res
    print(ans)

 

부분으로 나눠서 살펴보겠다

# words : 각 단어의 비트마스킹한 정수를 저장
words = [0] * n
ans = 0
for i in range(n):
    temp = input()
    # word 배열에 각 문자의 비트마스킹 저장
    for x in temp:
        words[i] |= (1 << (ord(x) - ord('a')))

다음 세 단어가 주어졌을 때

antarctica
antahellotica
antacartica

 

처음 words는 [0, 0, 0] 으로 초기화된다

단어별로 iteration을 돌면서, 단어 안의 글자 별로 iteration을 또 돈다

(1) 차례에 온 글자의 유니코드 정수에서 알파벳 'a'의 유니코드 정수를 뺀다

즉 a부터 0, 1, 2, ... 가 되는 수로 바꾼 다음에

 

(2) 비트 1 을 그 수만큼 왼쪽으로 이동한다

따라서 a 는 0b1, b는 0b10, c는 0b100, ...

(3) word 리스트에서 그 단어에 해당되는 인덱스의 값을 OR 연산하여 업데이트 한다

antarctica 라고 한다면

초기값이 0 이므로, 0과 (1 << 0) 의 OR 연산 -> 1

1과 (1<<13) 의 OR 연산 -> 8193

8193과 (1<<19) 의 OR 연산 -> 532481

... 이런 식으로 antarctica를 나타낸 것은 

(오른쪽부터 a, b, c... 셀 수 있음) a, c, i, n, r, t 가 1로 표시된 결과

 

candidiate = ['b','d','e','f','g','h','j','k','l','m','o','p','q','r','s','u','v','w','x','y','z']
    need = ['a','c','t','i','n']
    for i in list(itertools.combinations(candidiate, m - 5)):
        each = 0
        res = 0
        # 각 조합에 대한 비트마스킹
        for j in need:
            each |= (1 << (ord(j) - ord('a')))
        for j in i:
            each |= (1 << (ord(j) - ord('a')))
            
        # 단어와 각 조합의 비교
        for j in words:
            if each & j == j:
                res += 1
                
        # 최대값 갱신
        if ans < res:
            ans = res

candidate 은 꼭 알아야 되는 다섯 글자(a, c, t, i, n)이 제외된 나머지 알파벳들이다

itertools의 combination으로 조합을 구해, 각 조합에 대해 비트마스킹을 수행한다

 

각 조합 i에 대해 each와 res를 초기화한다.

먼저 필수로 배우는 다섯 글자를 하나씩 뽑아서 or 연산을 한다

필수 글자가 표시된 비트가 연산된다

 

그 다음 조합 i 에서 글자를 하나씩 뽑아서 or 연산을 한다

조합의 글자가 'r', 'k'라고 했을 때 'r', 'k' 자리에 1이 표시되었음

 

이렇게 만든 each와 단어 리스트의 단어를 하나씩 비교하는데, (단어 리스트는 앞서 글자의 비트로 표현해두었다)

이때는 AND 연산을 수행한 값이 그 단어일 때, 배운 글자 안에 단어의 글자가 포함된다는 의미이므로 res 에 1점을 더한다

 

... 그렇게 조합마다 값을 구해 최대값을 갱신한다

 

 

 

*참고

파이썬의 ord 함수: 문자에 해당하는 유니코드 정수(10진수)를 반환

hex : 10진수를 16진수로 변경

result1 = ord('a')
result2 = ord('ㄱ')
result3 = hex(ord('b'))
print(f"ord('a') : {result1}")
print(f"ord('ㄱ') : {result2}")
print(f"hex(ord('b')) : {result3}\n")

 

비트 연산자

연산자 설명
a & b a와 b의 비트를 AND 연산
a | b a와 b의 비트를 OR 연산
a ^ b a와 b의 비트를 XOR 연산
~a a의 비트를 뒤집음
a << b a의 비트를 b번 왼쪽으로 이동
a >> b a의 비트를 b번 오른쪽으로 이동

 

복사했습니다!