// NOTE(review): this file appears to be CORRUPTED. Everything from a
// less-than sign up to the next greater-than sign seems to have been
// stripped, as an HTML sanitizer would do. As a result:
//   - "for(int j=0;j=n || yi[ny-1]...)" is not valid Java, and
//   - "else if(nx=m || ...)" is not valid Java,
// so the program as shown does not compile. Restore the missing text from
// the original source before changing any logic.
//
// What the surviving fragments show (class Solution, method main):
//   - Reads two ints, m then n, from stdin with a Scanner.
//   - Allocates Integer arrays yi (length m-1) and xi (length n-1). The code
//     that fills them, and anything done to them before the main loop, is
//     lost.
//   - A loop (its header is lost) apparently picks between two branches:
//       * the first branch's guard ends in yi[ny-1] > xi[nx-1]; it adds nx
//         to the long accumulator c and increments ny;
//       * the else-if branch's guard ends in xi[nx-1] >= yi[ny-1]; it adds
//         ny to c and increments nx.
//   - Finally prints c - Math.abs((n-1)*(m-1)).
//
// NOTE(review): both branches of if(n==m) print exactly the same expression,
//   so the test is redundant.
// NOTE(review): (n-1)*(m-1) is evaluated in int arithmetic before it is
//   subtracted from the long c, so it can overflow for large n and m. Cast
//   first, e.g. (long) (n - 1) * (m - 1). Math.abs cannot repair a value
//   that has already overflowed.
// NOTE(review): c is increased only by the piece counts nx / ny, never by
//   the yi / xi values themselves. If this is meant to be a greedy
//   board-cutting cost sum (presumably: costs sorted, cost times number of
//   pieces), that looks wrong. TODO confirm against the problem statement
//   once the missing text is recovered.
import java.io.*; import java.util.*; public class Solution { public static void main(String[] args) { Scanner sc = new Scanner(new BufferedInputStream(System.in)); int m = sc.nextInt(); int n = sc.nextInt(); Integer[] yi = new Integer[m-1]; Integer[] xi = new Integer[n-1]; for(int j=0;j=n || yi[ny-1]>xi[nx-1])) { c= (c + ((long)nx)); ny++; } else if(nx=m || xi[nx-1]>=yi[ny-1])) { c= (c + ((long)ny)); nx++; } } if(n==m){ System.out.println(c - Math.abs(((n-1)*(m-1)))); }else{ System.out.println(c - Math.abs(((n-1)*(m-1)))); } } }