Junarata-tehtävä 2

Tiedosto fenwickTree.py

class FenwickTree:
    """Binary indexed (Fenwick) tree over indices 0..n-1.

    Supports point increments (incr) and prefix sums (getSum), both in
    O(log n) steps.
    """

    def __init__(self, n):
        # Internally 1-based: arr[0] is never used, hence n+1 slots.
        self.n = n+1
        self.arr = [0]*self.n

    def incr(self, ind, val):
        """Increment the value at index ind by val.

        @param ind: index in 0..n-1
        @param val: amount to add (may be negative)
        @raise IndexError: if ind is outside 0..n-1 (previously an index
            >= n was silently ignored)
        """
        if ind < 0 or ind >= self.n - 1:
            raise IndexError("Index out of bounds")
        curr = ind+1
        while curr < self.n:
            self.arr[curr] += val
            # Move to the next node whose range also covers this index.
            curr += curr & (-curr)

    def getSum(self, upto):
        """Get the prefix sum of the values at indices 0..upto (inclusive).

        upto = -1 (or less) gives the empty sum 0.
        """
        s = 0
        ind = upto+1
        while ind > 0:
            s += self.arr[ind]
            # Strip the lowest set bit: jump to the preceding disjoint range.
            ind -= ind & (-ind)
        return s

Tiedosto junarata6.py

from fenwickTree import FenwickTree

class IntvalSet:
    """A set of natural numbers stored as a list of closed intervals [a, b]
    (both ends inclusive)."""

    def __init__(self):
        # (a, b) pairs in insertion order; overlapping intervals are not merged.
        self.intvals = []

    def has(self, x):
        """Return True if some stored interval contains x, else False."""
        return any(lo <= x <= hi for lo, hi in self.intvals)

    def add(self, a, b):
        """Add the interval [a, b] to the set."""
        self.intvals.append((a, b))

    def getAll(self, lowerBdd=0):
        """Yield the members interval by interval, in insertion order,
        skipping every member smaller than lowerBdd (default 0)."""
        for lo, hi in self.intvals:
            start = max(lo, lowerBdd)
            for member in range(start, hi + 1):
                yield member

    def __str__(self):
        if not self.intvals:
            return "[]"
        pieces = ["[%d, %d]" % (lo, hi) for lo, hi in self.intvals]
        return " U ".join(pieces)


def getData(num):
    """Read the permutation stored in the input file ``junar<num>.in``.

    The first line of the file is skipped (presumably it holds the length
    n -- TODO confirm against the input format); every following non-blank
    line holds one integer.

    @param num: number of the input file to read
    @return: list of ints in file order
    """
    with open("junar"+str(num)+".in") as f:
        lines = f.read().split("\n")
    # Skip whitespace-only lines too (a trailing " ", or a lone "\r" when
    # CRLF files are read without newline translation); int() rejects them.
    return [int(x) for x in lines[1:] if x.strip()]

def getPoss(p):
    """Return the inverse of the permutation p of [1..n].

    The result has length n+1: entry v is the 0-based index of value v in p,
    and entry 0 is unused (None).
        e.g. getPoss([2,5,3,1,4]) = [None, 3, 0, 2, 4, 1]
    """
    positions = [None] * (len(p) + 1)
    for value, index in zip(p, range(len(p))):
        positions[value] = index
    return positions

"""
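Let's denote the types of orderings of the positions of k-1, k and k+1
(or of k and k+1 when k=1; of k-1 and k when k=n) like this. Each digit is
the rank of the position of k-1, k, k+1 respectively, and '_' marks a
missing neighbour (e.g. 312: k comes first, then k+1, then k-1):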

123: 0
132: 1
213: 2
231: 3
312: 4
321: 5
_12: 6
_21: 7
12_: 8
21_: 9
"""
def getTyyppi(a, b, c):
    """Return the ordering type (0-5) of the positions a, b, c of k-1, k, k+1.

    Each code's digits are the ranks of the positions of k-1, k and k+1
    respectively, e.g. 312 (type 4) means k comes first, then k+1, then k-1:

        0: 123   1: 132   2: 213   3: 231   4: 312   5: 321

    The positions come from a permutation and so are expected to be distinct.
    """
    if not a < b:
        # k lies before k-1: types 213, 312 or 321.
        if b > c:
            return 5
        return 2 if a < c else 4
    # k-1 lies before k: types 123, 132 or 231.
    if b < c:
        return 0
    return 1 if a < c else 3


def getNumOfRounds(poss):
    """How many rounds it takes to gather the numbers 1..n in order.

    A new round starts whenever k+1 lies to the left of k, so the count is
    1 + #{k | poss[k+1] < poss[k]}.

    @param poss: index-positions of the permutation as returned by getPoss
        (poss[0] is unused)
    @return: number of rounds; 0 for an empty permutation
    """
    n = len(poss) - 1
    if n < 1:
        # Nothing to gather (indexing poss[1] would fail here).
        return 0
    return 1 + sum(1 for k in range(1, n) if poss[k+1] < poss[k])


def countScore2Pairs(arr):
    """Count the swaps that decrease the number of rounds by 2
    (sweep-line version using a Fenwick tree).

    For every index i an interval set toConsider[i] is built from the
    ordering type of the positions of arr[i]-1, arr[i], arr[i]+1 (see
    getTyyppi). Presumably it holds the positions j to which moving arr[i]
    decreases the rounds by one. An unordered pair {i, j} is counted iff j
    is in toConsider[i], i is in toConsider[j], and the values arr[i] and
    arr[j] are not consecutive.

    @param arr: permutation of [1..n]
    @return: number of such unordered pairs, as an int
    """
    n = len(arr)
    if n < 2:
        # No neighbours to order against (the x == 1 branch below would
        # index poss[2]), and there is nothing to swap anyway.
        return 0
    poss = getPoss(arr)

    toConsider = []
    for i, x in enumerate(arr):
        if x == 1:
            t = 6 if poss[x+1] > i else 7
        elif x == n:
            t = 8 if poss[x-1] < i else 9
        else:
            t = getTyyppi(poss[x-1], i, poss[x+1])

        tempSet = IntvalSet()
        if t == 2:
            tempSet.add(poss[x-1], poss[x+1]-1)
        elif t == 5:
            tempSet.add(0, poss[x+1])
            tempSet.add(poss[x-1], n-1)
        elif t == 1:
            tempSet.add(poss[x-1]+1, poss[x+1])
        elif t == 7:
            tempSet.add(0, poss[x+1])
        elif t == 9:
            tempSet.add(poss[x-1], n-1)
        # The remaining types (0, 3, 4, 6, 8) get an empty set.
        toConsider.append(tempSet)

    # Sweep over positions k = 0..n-1. events[k] holds the tree updates to
    # apply at time k: index i is switched on (+1) when k enters one of the
    # intervals of toConsider[i] and off (-1) right after it leaves. At time
    # k the tree therefore has a 1 exactly at the indices i with k in
    # toConsider[i]. events[n] only collects switch-offs and is never applied.
    events = [[] for _ in range(n+1)]
    for i, c in enumerate(toConsider):
        for (j1, j2) in c.intvals:
            events[j1].append((i, 1))
            events[j2+1].append((i, -1))

    tree = FenwickTree(n)
    numPairs = 0
    for k in range(n):
        for (i, d) in events[k]:
            tree.incr(i, d)
        # Ordered pairs (k, i) with i in toConsider[k] and k in toConsider[i].
        for (j1, j2) in toConsider[k].intvals:
            numPairs += tree.getSum(j2) - tree.getSum(j1-1)
        # Remove mutual pairs whose values are consecutive (x-1 or x+1).
        x = arr[k]
        if (x > 1 and toConsider[k].has(poss[x-1])
                and toConsider[poss[x-1]].has(k)):
            numPairs -= 1
        if (x < n and toConsider[k].has(poss[x+1])
                and toConsider[poss[x+1]].has(k)):
            numPairs -= 1

    # Each unordered pair was counted once from each side. Floor division
    # keeps the result an int on Python 3 (plain / would give a float).
    return numPairs // 2


def junarata(arr):
    """Solve one instance and return the answer line "n r p".

    n is the length of arr, r is its number of rounds minus 2 (presumably
    the rounds left after the best swap, assuming one that saves 2 exists),
    and p is the count returned by countScore2Pairs.
    """
    rounds = getNumOfRounds(getPoss(arr))
    pairs = countScore2Pairs(arr)
    return "%d %d %d" % (len(arr), rounds - 2, pairs)



import random

#a = range(1, 10); random.shuffle(a); print(junarata(a));

# Solve the Putkaposti task: print the answer line for each of the input
# files junar1.in ... junar9.in.
for i in range(1, 10):
    data = getData(i)
    print(junarata(data))



#Test execution time
"""
import time
startT = time.time()
print (junarata(getData(9)))
print ("Time: "+str(time.time()-startT))
"""

Junarata-tehtävä

class IntvalSet:
    """A set of natural numbers stored as a list of closed intervals [a, b]
    (both ends inclusive)."""

    def __init__(self):
        # (a, b) pairs in insertion order; overlapping intervals are not merged.
        self.intvals = []

    def has(self, x):
        """Return True if some stored interval contains x, else False."""
        return any(lo <= x <= hi for lo, hi in self.intvals)

    def add(self, a, b):
        """Add the interval [a, b] to the set."""
        self.intvals.append((a, b))

    def getAll(self, lowerBdd=0):
        """Yield the members interval by interval, in insertion order,
        skipping every member smaller than lowerBdd (default 0)."""
        for lo, hi in self.intvals:
            start = max(lo, lowerBdd)
            for member in range(start, hi + 1):
                yield member

    def __str__(self):
        # Every member, comma-separated, in getAll() order.
        return ", ".join(map(str, self.getAll()))
    


def getData(num):
    """Read the permutation stored in the input file ``junar<num>.in``.

    The first line of the file is skipped (presumably it holds the length
    n -- TODO confirm against the input format); every following non-blank
    line holds one integer.

    @param num: number of the input file to read
    @return: list of ints in file order
    """
    with open("junar"+str(num)+".in") as f:
        lines = f.read().split("\n")
    # Skip whitespace-only lines too (a trailing " ", or a lone "\r" when
    # CRLF files are read without newline translation); int() rejects them.
    return [int(x) for x in lines[1:] if x.strip()]


def getPoss(p):
    """Return the inverse of the permutation p of [1..n].

    The result has length n+1: entry v is the 0-based index of value v in p,
    and entry 0 is unused (None).
        e.g. getPoss([2,5,3,1,4]) = [None, 3, 0, 2, 4, 1]
    """
    positions = [None] * (len(p) + 1)
    for value, index in zip(p, range(len(p))):
        positions[value] = index
    return positions

"""
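Let's denote the types of orderings of the positions of k-1, k and k+1
(or of k and k+1 when k=1; of k-1 and k when k=n) like this. Each digit is
the rank of the position of k-1, k, k+1 respectively, and '_' marks a
missing neighbour (e.g. 312: k comes first, then k+1, then k-1):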

123: 0
132: 1
213: 2
231: 3
312: 4
321: 5
_12: 6
_21: 7
12_: 8
21_: 9
"""
def getTyyppi(a, b, c):
    """Return the ordering type (0-5) of the positions a, b, c of k-1, k, k+1.

    Each code's digits are the ranks of the positions of k-1, k and k+1
    respectively, e.g. 312 (type 4) means k comes first, then k+1, then k-1:

        0: 123   1: 132   2: 213   3: 231   4: 312   5: 321

    The positions come from a permutation and so are expected to be distinct.
    """
    if not a < b:
        # k lies before k-1: types 213, 312 or 321.
        if b > c:
            return 5
        return 2 if a < c else 4
    # k-1 lies before k: types 123, 132 or 231.
    if b < c:
        return 0
    return 1 if a < c else 3


def getNumOfRounds(poss):
    """How many rounds it takes to gather the numbers 1..n in order.

    A new round starts whenever k+1 lies to the left of k, so the count is
    1 + #{k | poss[k+1] < poss[k]}.

    @param poss: index-positions of the permutation as returned by getPoss
        (poss[0] is unused)
    @return: number of rounds; 0 for an empty permutation
    """
    n = len(poss) - 1
    if n < 1:
        # Nothing to gather (indexing poss[1] would fail here).
        return 0
    return 1 + sum(1 for k in range(1, n) if poss[k+1] < poss[k])

"""
As the above algorithm shows, the number of rounds for a permutation is
1 + #{k | poss[k+1]<poss[k]}
A swap can decrease the rounds by at most 2.
This can be seen from the following:
For each k in [1..n], consider the positions into which moving k
decreases the rounds; moving a single element decreases them by at most 1.
A swap moves two elements, so if the other move is also a decreasing one
we gain 2, except when we swap some k and k+1, since that decreases the
rounds by only 1.

(Look at the types, and consider for each how it decreases rounds.)

Let's hope the answer will always be 'score 2' (decreases rounds by 2), as
such a swap is highly probable to exist in a large random permutation.
"""

def countScore2Pairs(arr):
    """Count the swaps that decrease the number of rounds by 2
    (brute-force version).

    For every index i an interval set toConsider[i] is built from the
    ordering type of the positions of arr[i]-1, arr[i], arr[i]+1 (see
    getTyyppi). Presumably it holds the positions j to which moving arr[i]
    decreases the rounds by one. An unordered pair {i, j} is counted iff j
    is in toConsider[i], i is in toConsider[j], and the values arr[i] and
    arr[j] differ by more than 1. Every member of every set is enumerated,
    so the worst case is O(n^2).

    @param arr: permutation of [1..n]
    @return: number of such unordered pairs
    """
    n = len(arr)
    if n < 2:
        # No neighbours to order against (the x == 1 branch below would
        # index poss[2]), and there is nothing to swap anyway.
        return 0
    poss = getPoss(arr)

    toConsider = []
    for i, x in enumerate(arr):
        if x == 1:
            t = 6 if poss[x+1] > i else 7
        elif x == n:
            t = 8 if poss[x-1] < i else 9
        else:
            t = getTyyppi(poss[x-1], i, poss[x+1])

        tempSet = IntvalSet()
        if t == 2:
            tempSet.add(poss[x-1], poss[x+1]-1)
        elif t == 5:
            tempSet.add(0, poss[x+1])
            tempSet.add(poss[x-1], n-1)
        elif t == 1:
            tempSet.add(poss[x-1]+1, poss[x+1])
        elif t == 7:
            tempSet.add(0, poss[x+1])
        elif t == 9:
            tempSet.add(poss[x-1], n-1)
        # The remaining types (0, 3, 4, 6, 8) get an empty set.
        toConsider.append(tempSet)

    numPairs = 0
    for i, x in enumerate(arr):
        # Only partners j > i, so each unordered pair is examined once.
        for j in toConsider[i].getAll(i+1):
            if toConsider[j].has(i) and abs(x - arr[j]) > 1:
                numPairs += 1
    return numPairs


def junarata(arr):
    """Solve one instance and return the answer line "n r p".

    n is the length of arr, r is its number of rounds minus 2 (presumably
    the rounds left after the best swap, assuming one that saves 2 exists),
    and p is the count returned by countScore2Pairs.
    """
    rounds = getNumOfRounds(getPoss(arr))
    pairs = countScore2Pairs(arr)
    return "%d %d %d" % (len(arr), rounds - 2, pairs)


##---------------------------------------------------
## Random tests

import random
def randomJuna(n):
    """Run junarata on a uniformly shuffled permutation of [1..n].

    @param n: permutation length
    @return: the three numbers of the answer line [n, rounds-2, pairs] as ints
    """
    # list() is required on Python 3, where range() is immutable and
    # random.shuffle cannot permute it in place (it raises TypeError).
    a = list(range(1, n+1))
    random.shuffle(a)
    return [int(x) for x in junarata(a).split()]

from collections import Counter
def genJunaData(n, simuN, col=2):
    """Simulate simuN random permutations of length n and tabulate one
    column of the randomJuna results.

    @param n: permutation length
    @param simuN: number of simulated permutations
    @param col: result column to tabulate: 1 = rounds-2, 2 = number of
        score-2 pairs (the default, matching the previous behavior)
    @return: list of [value, frequency] pairs sorted by value
    """
    results = [randomJuna(n) for _ in range(simuN)]
    counts = Counter(row[col] for row in results)
    return [[k, counts[k]] for k in sorted(counts)]
    
##-------------------------------------------------------


# Solve the Putkaposti task: print the answer line for each of the input
# files junar1.in ... junar9.in.
for i in range(1, 10):
    data = getData(i)
    print(junarata(data))


""" Test execution time
import time
startT = time.time()
print (junarata(getData(9)))
print ("Time: "+str(time.time()-startT))
"""