人們可以從構建結果向上逐位底端。從最低位開始,嘗試0和1作爲a
的最低位,並查看sum-xor的最低位是否與S的相應位匹配。然後嘗試下一個最低位,傳播上一步中的任何進位。
按照這個算法,a
的每一位可能有0,1或2個選擇,所以在最壞的情況下,我們可能需要探索不同的分支並選擇一個給出最小結果的分支。爲了避免指數行爲,我們緩存先前看到的進位結果。這產生了O(kn)的最壞情況複雜度,其中k是結果中的最大比特數,並且n是給定輸入列表長度爲n的進位的最大值。
下面是實現這一些Python代碼:
max_shift = 80
def xor_sum0(xs, S, shift, carry, cache, sums):
if shift >= max_shift:
return 1e100 if carry else 0
key = shift, carry
if key in cache:
return cache[key]
best = 1e100
for i in xrange(2):
ss = sums[i][shift] + carry
if ss & 1 == (S >> shift) & 1:
best = min(best, i + 2 * xor_sum0(xs, S, shift + 1, ss >> 1, cache, sums))
cache[key] = best
return cache[key]
def xor_sum(xs, S):
sums = [
[sum(((x >> sh)^i) & 1 for x in xs) for sh in xrange(max_shift)]
for i in xrange(2)]
return xor_sum0(xs, S, 0, 0, dict(), sums)
在這種情況下也沒有解決,代碼返回一個大的(> = 1e100)浮點數。
這裏是一個測試,在你給出的範圍內選取隨機值,隨機挑選a
並計算S,然後求解。請注意,由於a
的值不總是唯一的,因此有時代碼會找到比用於計算S的代碼更小的a
。
import random
xs = [random.randrange(0, 1 << 61) for _ in xrange(random.randrange(10 ** 5))]
a_original = random.randrange(1 << 61)
S = sum(x^a_original for x in xs)
print S
print xs
a = xor_sum(xs, S)
assert a < 1e100
print 'a:', a
print 'original a:', a_original
assert a <= a_original
print 'S', S
print 'SUM', sum(x^a for x in xs)
assert sum(x^a for x in xs) == S