/* Program to accompany "Balls and urns Part 2: Multi-colored balls"
by Rick Wicklin, published 2OCT2015 on The DO Loop blog
http://blogs.sas.com/content/iml/2015/10/02/balls-and-urns2.html
*/
proc iml;
/* There are K[i] balls of color i in an urn.
Draw N[1] balls without replacement and report the number of
each color.
N[1] = sample size
N[2] = number of replications (optional. Default is 1)
K = vector that gives the number of balls of each color.
The total number of balls is sum(K).
The counts of the balls follow a multivariate hypergeometric distribution.
*/
start RandMVHyper(N, _K);
   /* Simulate draws without replacement from an urn of multi-colored balls
      (multivariate hypergeometric distribution).
      Arguments:
         N  : N[1] = number of balls drawn per replicate (sample size)
              N[2] = number of replicates (optional; default is 1)
         _K : vector; _K[i] is the number of balls of color i in the urn
      Returns:
         nRep x ncol(K) matrix. Row r holds the count of each color in
         replicate r; every row sums to N[1].
      Method: colors are handled one at a time. The count for color i is a
      univariate hypergeometric draw from the balls of colors i, i+1, ...
      that remain, using the draws not yet assigned to earlier colors.
      The last color receives whatever draws are left over.
      NOTE(review): no validation is done on the arguments; the caller is
      presumably expected to pass N[1] <= sum(_K) -- confirm. */
   if nrow(N)*ncol(N)>1 then nRep = N[2];
   else nRep = 1;
   K = rowvec(_K);                    /* K[i] is number of items for category i */
   nColors = ncol(K);
   nDraw = j(nRep, 1, N[1]);          /* draws still unassigned, per replicate */
   ItemsLeft = j(nRep, 1, sum(K));    /* balls of the colors not yet processed */
   x = j(nRep, nColors, 0);           /* each row is draw from MV hyper */
   h = j(nRep, 1, 0);                 /* vec for hypergeometric values */
   do i = 1 to nColors-1;
      Kvec = j(nRep, 1, K[i]);
      idx0 = loc(nDraw=0);            /* replicates that have no draws left */
      if ncol(idx0)=0 then do;        /* usual case: every replicate still draws */
         call randgen(h, "Hyper", ItemsLeft, Kvec, nDraw);
         x[,i] = h;
         nDraw = nDraw - h;
      end;
      else do;
         /* For some replicate, all balls have been drawn but there are
            still more colors to process (their counts are zero). This can
            happen when the last few colors are sparsely represented, or in
            small samples such as drawing 5 balls from K = {5 2 2}. */
         x[idx0,i] = 0;
         idx1 = loc(nDraw>0);         /* replicates that still have draws left */
         /* The test must use idx1. The name idx is never assigned inside this
            module, so ncol(idx) is always 0 and the draw would be skipped for
            every replicate, pushing all leftover draws into the last color. */
         if ncol(idx1)>0 then do;
            hh = j(ncol(idx1), 1, 0); /* allocate result for the active replicates */
            call randgen(hh, "Hyper", ItemsLeft[idx1], Kvec[idx1], nDraw[idx1]);
            x[idx1,i] = hh;
            nDraw[idx1] = nDraw[idx1] - hh;
         end;
      end;
      /* Color i is finished for every replicate: remove its balls from the
         pool that later colors are drawn from. */
      ItemsLeft = ItemsLeft - K[i];
   end;
   x[,nColors] = nDraw;               /* remaining draws all go to the last color */
   return (x);
finish;
/* Driver: exercise RandMVHyper on a typical case and on an edge case. */
call randseed(1234);                  /* fixed seed so the printed output is reproducible */

/* TEST:         nDraws nRep   K1 K2 K3 */
y = RandMVHyper({100 1000}, {100 40 60});
print (y[1:5,])[c={"black" "white" "red"} L="Draws w/o Replacement"];
mean = mean(y);                       /* column means: average count of each color */
print mean;

/* TEST edge conditions    K1 K2 K3 */
y = RandMVHyper({5 100}, { 5 2 2});
idx = loc(y[,2]=0);                   /* replicates in which no white ball was drawn */
/* LOC returns an empty matrix when nothing matches, and subscripting with an
   empty matrix is an error, so print only when at least one row qualifies. */
if ncol(idx)>0 then
   print (y[idx,])[c={"black" "white" "red"} L="Draws w/o Replacement"];
mean = mean(y);
print mean;