## 2011年11月5日 星期六

### [ Algorithm ] Viterbi algorithm : Finding the most likely sequence of hidden states

The Viterbi algorithm relies on three assumptions:

* First, both the observed events and the hidden events must form sequences. These sequences often correspond to time.
* Second, the two sequences must be aligned: each instance of an observed event corresponds to exactly one instance of a hidden event.
* Third, computing the most likely hidden sequence up to a point t must depend only on the observed event at point t and the most likely sequence up to point t − 1.

The Viterbi algorithm was conceived by Andrew Viterbi in 1967 as a decoding algorithm for convolutional codes over noisy digital communication links. For more details on the history of the algorithm's development, see David Forney's article [1]. The algorithm has found universal application in decoding the convolutional codes used in both CDMA and GSM digital cellular, dial-up modems, satellite and deep-space communications, and 802.11 wireless LANs. It is now also commonly used in speech recognition, keyword spotting, computational linguistics, and bioinformatics. For example, in speech-to-text (speech recognition), the acoustic signal is treated as the observed sequence of events, and a string of text is considered to be the "hidden cause" of the acoustic signal. The Viterbi algorithm finds the most likely string of text given the acoustic signal.

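Example model: the hidden state is the weather (Rainy or Sunny), and the observation is what a person does that day (Walk, Shop, or Clean).

Transition probabilities (Start gives the initial distribution; each row sums to 1):
P(Rainy|Start) = 0.6 ; P(Sunny|Start) = 0.4 # 0.4 + 0.6 = 1
P(Rainy|Rainy) = 0.7 ; P(Sunny|Rainy) = 0.3 # 0.7 + 0.3 = 1
P(Sunny|Sunny) = 0.6 ; P(Rainy|Sunny) = 0.4 # 0.6 + 0.4 = 1

Emission probabilities (each row sums to 1):
P(Walk|Rainy) = 0.1 ; P(Shop|Rainy) = 0.4 ; P(Clean|Rainy) = 0.5 # 0.1 + 0.4 + 0.5 = 1
P(Walk|Sunny) = 0.6 ; P(Shop|Sunny) = 0.3 ; P(Clean|Sunny) = 0.1 # 0.6 + 0.3 + 0.1 = 1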

Coding 實作 :

- Viterbi.java :
package alg.others;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Stack;

/**
 * Viterbi decoding for a two-state weather hidden Markov model.
 *
 * <p>Hidden states are {@code Rainy} and {@code Sunny}; {@code Start} is a virtual state whose
 * transition row is the initial distribution and which never appears in a decoded path.
 * Observations are the daily {@link Action}s.
 */
public class Viterbi {
    enum Weather {Start, Rainy, Sunny}
    enum Action {Walk, Shop, Clean}

    /** The real hidden states, in tie-breaking order (an earlier state wins on equal probability). */
    private static final Weather[] STATES = {Weather.Rainy, Weather.Sunny};

    /**
     * Returns the larger of two floats.
     *
     * @return {@code a} when {@code a >= b}, otherwise {@code b}
     */
    public static float max(float a, float b) { return a >= b ? a : b; }

    /**
     * Builds transition table A, where {@code A.get(from).get(to)} is P(to | from).
     * The {@code Start} row is the initial-state distribution; every row sums to 1.
     */
    static Map<Weather, Map<Weather, Float>> buildTransitionTable() {
        Map<Weather, Float> startProb = new EnumMap<>(Weather.class);
        startProb.put(Weather.Rainy, 0.6f); // Start to Rainy is 0.6
        startProb.put(Weather.Sunny, 0.4f); // Start to Sunny is 0.4

        Map<Weather, Float> rainyProb = new EnumMap<>(Weather.class);
        rainyProb.put(Weather.Rainy, 0.7f); // Rainy to Rainy is 0.7
        rainyProb.put(Weather.Sunny, 0.3f); // Rainy to Sunny is 0.3

        Map<Weather, Float> sunnyProb = new EnumMap<>(Weather.class);
        sunnyProb.put(Weather.Rainy, 0.4f); // Sunny to Rainy is 0.4
        sunnyProb.put(Weather.Sunny, 0.6f); // Sunny to Sunny is 0.6

        Map<Weather, Map<Weather, Float>> tableA = new EnumMap<>(Weather.class);
        tableA.put(Weather.Start, startProb);
        tableA.put(Weather.Rainy, rainyProb);
        tableA.put(Weather.Sunny, sunnyProb);
        return tableA;
    }

    /**
     * Builds emission table B, where {@code B.get(state).get(action)} is P(action | state).
     * Every row sums to 1.
     */
    static Map<Weather, Map<Action, Float>> buildEmissionTable() {
        Map<Action, Float> rainyActProb = new EnumMap<>(Action.class);
        rainyActProb.put(Action.Walk, 0.1f);
        rainyActProb.put(Action.Shop, 0.4f);
        rainyActProb.put(Action.Clean, 0.5f);

        Map<Action, Float> sunnyActProb = new EnumMap<>(Action.class);
        sunnyActProb.put(Action.Walk, 0.6f);
        sunnyActProb.put(Action.Shop, 0.3f);
        sunnyActProb.put(Action.Clean, 0.1f);

        Map<Weather, Map<Action, Float>> tableB = new EnumMap<>(Weather.class);
        tableB.put(Weather.Rainy, rainyActProb);
        tableB.put(Weather.Sunny, sunnyActProb);
        return tableB;
    }

    /**
     * Finds the most likely sequence of hidden weather states for the observed actions.
     *
     * @param actions observed actions, one per day; may be empty or {@code null}
     * @param tableA  transition table with rows for {@code Start}, {@code Rainy} and {@code Sunny}
     * @param tableB  emission table with rows for {@code Rainy} and {@code Sunny}
     * @param verbose when true, prints every intermediate probability to standard output
     * @return the most likely state for each day, Day1 first; empty when there are no actions
     */
    static List<Weather> decode(Action[] actions,
                                Map<Weather, Map<Weather, Float>> tableA,
                                Map<Weather, Map<Action, Float>> tableB,
                                boolean verbose) {
        List<Weather> path = new ArrayList<>();
        if (actions == null || actions.length == 0) {
            return path;
        }

        // delta.get(s): probability of the best path that ends in state s on the current day.
        Map<Weather, Float> delta = new EnumMap<>(Weather.class);
        for (Weather s : STATES) {
            float p = tableA.get(Weather.Start).get(s) * tableB.get(s).get(actions[0]);
            delta.put(s, p);
            if (verbose) {
                System.out.printf("\t[Initialization] Start to %s is %4.3f...\n", s.name(), p);
            }
        }

        // One map per day after the first: today's state -> best state on the previous day.
        List<Map<Weather, Weather>> backPointers = new ArrayList<>();
        for (int i = 1; i < actions.length; i++) {
            Action today = actions[i];
            if (verbose) {
                for (Weather s : STATES) {
                    System.out.printf("\t[Recursion] Cur%sProb=%4.4f...\n", s.name(), delta.get(s));
                }
                System.out.printf("\t[Recursion] Today do %s...\n", today.name());
            }
            Map<Weather, Float> next = new EnumMap<>(Weather.class);
            Map<Weather, Weather> back = new EnumMap<>(Weather.class);
            for (Weather from : STATES) {
                for (Weather to : STATES) {
                    float p = delta.get(from) * tableA.get(from).get(to) * tableB.get(to).get(today);
                    if (verbose) {
                        System.out.printf("\t[Recursion] %s to %s is %4.4f...\n",
                                from.name(), to.name(), p);
                    }
                    // Strict '>' keeps the earlier predecessor on ties (Rainy before Sunny).
                    Float best = next.get(to);
                    if (best == null || p > best) {
                        next.put(to, p);
                        back.put(to, from);
                    }
                }
            }
            delta = next;
            backPointers.add(back);
        }

        // Start from the most probable final state and follow the back-pointers to Day1.
        Weather state = STATES[0];
        for (Weather s : STATES) {
            if (delta.get(s) > delta.get(state)) {
                state = s;
            }
        }
        Deque<Weather> stack = new ArrayDeque<>();
        stack.push(state);
        for (int i = backPointers.size() - 1; i >= 0; i--) {
            state = backPointers.get(i).get(state);
            stack.push(state);
        }
        path.addAll(stack); // head of the deque is Day1
        return path;
    }

    /** Decodes the sample observation sequence and prints the trace and the decoded days. */
    public static void main(String[] args) {
        Action[] actions = {Action.Walk, Action.Walk, Action.Shop, Action.Clean};
        List<Weather> path = decode(actions, buildTransitionTable(), buildEmissionTable(), true);
        // Days are numbered from 1 to match the observations; the virtual Start state is not a day.
        for (int day = 1; day <= path.size(); day++) {
            System.out.printf("\t[Result] Day%d is %s...\n", day, path.get(day - 1).name());
        }
    }
}

[Initialization] Start to Rainy is 0.060...
[Initialization] Start to Sunny is 0.240...
[Initialization] Day1 choose Sunny! (0.240, > 0.060)
[Recursion] CurRainyProb=0.0600...
[Recursion] CurSunnyProb=0.2400...
[Recursion] Today do Walk...
[Recursion] Rainy to Rainy is 0.0042...
[Recursion] Rainy to Sunny is 0.0108...
[Recursion] Sunny to Rainy is 0.0144...
[Recursion] Sunny to Sunny is 0.0576...
[Recursion] CurRainyProb=0.0144...
[Recursion] CurSunnyProb=0.0576...
[Recursion] Today do Shop...
[Recursion] Rainy to Rainy is 0.0040...
[Recursion] Rainy to Sunny is 0.0013...
[Recursion] Sunny to Rainy is 0.0138...
[Recursion] Sunny to Sunny is 0.0069...
[Recursion] CurRainyProb=0.0138...
[Recursion] CurSunnyProb=0.0069...
[Recursion] Today do Clean...
[Recursion] Rainy to Rainy is 0.0048...
[Recursion] Rainy to Sunny is 0.0004...
[Recursion] Sunny to Rainy is 0.0021...
[Recursion] Sunny to Sunny is 0.0003...
[Result] Day0 is Start...
[Result] Day1 is Sunny...
[Result] Day2 is Sunny...
[Result] Day3 is Rainy...
[Result] Day4 is Rainy...

- ViterbiAlg.groovy
/**
 * Discrete hidden Markov model helpers: forward/backward evaluation of P(O)
 * and Viterbi decoding of the most likely hidden-state sequence.
 *
 *   A[i][j] : probability of moving from state i to state j   (SS x SS)
 *   B[i][k] : probability of emitting observation k in state i (SS x OS)
 *   π[i]    : probability of starting in state i
 *
 * States and observations are 0-based list indices. run/run2 subtract
 * `shift` from each incoming observation and add `shift` to every returned
 * state; forward/backward do not apply `shift`.
 */
class ViterbiAlg {
    def LZERO = -1*Math.pow(10, 10)     // log-domain stand-in for log(0)
    def LSMALL = -0.5*Math.pow(10, 10)  // in logAdd, a dominant term below this collapses to LZERO
    def minLogExp = -Math.log(-LZERO)   // in logAdd, differences below this are negligible
    def isDebug = false                 // when true, dprintf prints trace output
    def shift=0                         // label offset applied by run/run2

    def A = []                          // State transition matrix
    def B = []                          // Observation (emission) matrix
    def π = []                          // Initial state probabilities
    def SS = 0                          // State Size
    def OS = 0                          // Observation Size

    /** printf that only prints when isDebug is true. */
    void dprintf(String fmt, Object... Os)
    {
        if(isDebug) printf(fmt, Os)
    }

    /**
     * Adds two probabilities that are given in the log domain.
     * Input: x = log(x'); y = log(y')
     * Return: log(x'+y'), or LZERO when the result is below LSMALL
     */
    double logAdd(double x, double y)
    {
        double temp, diff, z;
        if (x < y)
        {
            temp = x; x = y; y = temp;
        }
        diff = y-x; // notice that diff <= 0
        if (diff < minLogExp)   // if y' is far smaller than x'
            return (x < LSMALL) ? LZERO : x;
        else
        {
            z = Math.exp(diff);
            return x + Math.log(1.0 + z);
        }
    }

    /**
     * Input: d1, d2, ... dn
     * Return: log(d1)+log(d2)+...+log(dn), i.e. log(d1*d2*...*dn).
     * This is a log-domain product, not a log-domain addition (see logAdd).
     */
    double logSum(Double ...ds)
    {
        double lSum=0
        ds.each {
            lSum+=Math.log(it)
        }
        return lSum
    }

    /** Sets the transition matrix, emission matrix and initial distribution at once. */
    void setPM(def A, def B, def π){setA(A); setB(B); this.π = π;}

    /** Sets the transition matrix; its row count defines the number of states. */
    void setA(def A)
    {
        this.A = A
        SS = A.size()
    }

    /** Sets the emission matrix; its column count defines the number of observation symbols. */
    void setB(def B)
    {
        this.B = B
        OS = B[0].size()
    }

    /**
     * Computes P(O) with the backward algorithm.
     * @param olist observation indices (no shift applied)
     * @return P(olist); 0 for an empty list
     */
    double backward(List olist)
    {
        if(olist.size()==0) return 0
        int T = olist.size()
        // β_T(s) = 1 for every state
        List β = (0..<SS).collect { 1.0d }
        // β_t(i) = Σ_j a_ij * b_j(o_{t+1}) * β_{t+1}(j), for t = T-1 down to 1
        for(int t = T - 2; t >= 0; t--)
        {
            def nextβ = β
            def o = olist[t + 1]
            β = (0..<SS).collect { ps ->
                double p = 0
                SS.times { cs -> p += (double)A[ps][cs] * B[cs][o] * nextβ[cs] }
                p
            }
        }
        // P(O) = Σ_s π(s) * b_s(o_1) * β_1(s)
        double p = 0
        SS.times { s -> p += (double)π[s] * B[s][olist[0]] * β[s] }
        return p
    }

    /**
     * Computes P(O) with the forward algorithm.
     * @param olist observation indices (no shift applied)
     * @return P(olist); 0 for an empty list
     */
    double forward(List olist)
    {
        if(olist.size()==0) return 0
        // α_1(s) = π(s) * b_s(o_1)
        List α = (0..<SS).collect { s -> (double)π[s] * B[s][olist[0]] }
        // α_t(j) = [ Σ_i α_{t-1}(i) * a_ij ] * b_j(o_t)
        for(o in olist.tail())
        {
            def prevα = α
            α = (0..<SS).collect { cs ->
                double inflow = 0
                SS.times { ps -> inflow += (double)prevα[ps] * A[ps][cs] }
                inflow * B[cs][o]
            }
        }
        // P(O) = Σ_s α_T(s)
        double prob = 0
        α.each { prob += it }
        return prob
    }

    /**
     * Viterbi decoding in the log domain (products become sums of logs,
     * which avoids underflow on long observation sequences).
     * Ties are broken in favour of the lowest state index.
     * @param O observation indices, offset by `shift`
     * @return the most likely state sequence, offset by `shift`; empty for empty O
     */
    List run(List O)
    {
        if(!O) return []
        def S = []                          // Best State Transition
        def PV = []                         // Previous Viterbi scores δ(t-1, ·)
        def PVP = []                        // Back-pointer map (state -> best predecessor) per step
        def t=1
        double mv=Double.NEGATIVE_INFINITY  // max Viterbi value at the current step
        def ms=0                            // state reaching mv
        O.each{ o->
            if(shift>0) o-=shift
            double cp
            mv=Double.NEGATIVE_INFINITY
            ms=0
            def V = []                      // Current Viterbi scores
            def VS = [:]                    // Current back-pointers
            SS.times { s->
                dprintf("\tδ%d(%d)=max P(q%d=%d|o%d=%d)\n", t, s+shift, t, s+shift, t, o+shift)
                if(t>1)
                {
                    dprintf "\t\t=max δ%d(i)*ai%d*b%d(%d)\n", t-1, s+shift, s+shift, o+shift
                    dprintf "\t\t=max ("
                    // log(0) is -Infinity, so start below every reachable score
                    double tmv=Double.NEGATIVE_INFINITY
                    def tms=0
                    SS.times {
                        def tv = PV[it]+logSum(A[it][s],B[s][o])
                        dprintf "δ%d(%d)*a%d%d*b%d(%d)=%.04f ", t-1, it+shift, it+shift, s+shift, s+shift, o+shift, tv
                        if(tv>tmv)
                        {
                            tmv = tv
                            tms = it
                        }
                    }
                    VS[s]=tms
                    cp = tmv
                    dprintf ")=%.03f (%d)\n", cp, tms+shift
                }
                else
                {
                    cp = logSum(π[s],B[s][o])
                    dprintf "\t\t=π%d b%d(%d)=%.03f\n", s+shift, s+shift, o+shift, cp
                }
                V[s] = cp
                if(cp>mv)
                {
                    mv = cp
                    ms = s
                }
            }
            dprintf "\tPickup S=%d at t(%d)!\n", ms+shift, t
            PVP << VS
            PV=V
            t++
        }
        // Backtrack from the best final state through the stored back-pointers.
        S << ms+shift
        for(int i=(PVP.size()-1); i>=1; i--)
        {
            ms = PVP[i][ms]
            S << ms+shift
        }
        S = S.reverse()
        dprintf "\tFinal State Sequence: %s (Prop.=%.09f)\n", S.join("->"), Math.exp(mv)
        return S
    }

    /**
     * Viterbi decoding with plain probabilities (may underflow on long
     * sequences; see run for the log-domain version).
     * Ties are broken in favour of the lowest state index.
     * @param O observation indices, offset by `shift`
     * @return the most likely state sequence, offset by `shift`; empty for empty O
     */
    List run2(List O)
    {
        if(!O) return []
        def S = []                          // Best State Transition
        def PV = []                         // Previous Viterbi scores δ(t-1, ·)
        def PVP = []                        // Back-pointer map (state -> best predecessor) per step
        def t=1
        double mv=Double.NEGATIVE_INFINITY  // max Viterbi value at the current step
        def ms=0                            // state reaching mv
        O.each{ o->
            if(shift>0) o-=shift
            double cp
            mv=Double.NEGATIVE_INFINITY
            ms=0
            def V = []                      // Current Viterbi scores
            def VS = [:]                    // Current back-pointers
            SS.times { s->
                dprintf("\tδ%d(%d)=max P(q%d=%d|o%d=%d)\n", t, s+shift, t, s+shift, t, o+shift)
                if(t>1)
                {
                    dprintf "\t\t=max δ%d(i)*ai%d*b%d(%d)\n", t-1, s+shift, s+shift, o+shift
                    dprintf "\t\t=max ("
                    double tmv=0
                    def tms=0
                    SS.times {
                        def tv = PV[it]*A[it][s]*B[s][o]
                        dprintf "δ%d(%d)*a%d%d*b%d(%d)=%.04f ", t-1, it+shift, it+shift, s+shift, s+shift, o+shift, tv
                        if(tv>tmv)
                        {
                            tmv = tv
                            tms = it
                        }
                    }
                    VS[s]=tms
                    cp = tmv
                    dprintf ")=%.03f (%d)\n", cp, tms+shift
                }
                else
                {
                    cp = π[s]*B[s][o]
                    dprintf "\t\t=π%d b%d(%d)=%.03f\n", s+shift, s+shift, o+shift, cp
                }
                V[s] = cp
                if(cp>mv)
                {
                    mv = cp
                    ms = s
                }
            }
            dprintf "\tPickup S=%d at t(%d)!\n", ms+shift, t
            PVP << VS
            PV=V
            t++
        }

        // Backtrack from the best final state through the stored back-pointers.
        S << ms+shift
        for(int i=(PVP.size()-1); i>=1; i--)
        {
            ms = PVP[i][ms]
            S << ms+shift
        }
        S = S.reverse()
        dprintf "\tFinal State Sequence: %s (Prop.=%.06f)\n", S.join("->"), mv
        return S
    }
}

// Two-state weather HMM: state 0 = Sunny, state 1 = Rainy;
// observations 0 = Walk, 1 = Shop, 2 = Clean.
def A = []
def B = []
def π = []
def SM = [:]                        // state index -> name
def OM = [:]                        // observation index -> name
SM[0]='Sunny'
SM[1]='Rainy'
OM[0]='Walk'
OM[1]='Shop'
OM[2]='Clean'
π[0] = 0.4; π[1] = 0.6              // initial distribution: Sunny, Rainy
A[0] = [0.6, 0.4]                   // from Sunny: to Sunny, to Rainy
A[1] = [0.3, 0.7]                   // from Rainy: to Sunny, to Rainy
B[0] = [0.6, 0.3, 0.1]              // Sunny emits Walk, Shop, Clean
B[1] = [0.1, 0.4, 0.5]              // Rainy emits Walk, Shop, Clean

ViterbiAlg va = new ViterbiAlg()
va.setPM(A, B, π)

def O = [0, 0, 1, 2]                // Walk > Walk > Shop > Clean
def S = va.run(O)
printf "%s\n", S.collect { SM[it]}.join(" > ")

Sunny > Sunny > Rainy > Rainy

1. 你的transition matrix寫錯了

