Here is my solution but I still fail two test cases, does anyone have any recommendation to fix it, many thanks!

def countTriplets(arr, r):
dict={}
for i in arr:
if not dict.get(i):
dict[i]=[i*r,1]
else:
dict[i][1]+=1
temp=list(dict.keys())
print (temp)
print (dict)
result=0
for i in temp:
if r==1:
result+=dict.get(i)[1]*(dict.get(i)[1]-1)*(dict.get(i)[1]-2)/6
elif dict.get(dict.get(i)[0])!=None and dict.get(dict.get(dict.get(i)[0])[0])!=None:
first=i
second=dict.get(first)[0]
third=dict.get(second)[0]
result+=dict.get(first)[1]*dict.get(second)[1]*dict.get(third)[1]
return int(result)

## Count Triplets

