# Problem 3: Analyze the Model by Confusion Matrix

Problem Description:
* Put the predictions and true labels of your split validation data into a confusion matrix.
* Describe what you observe.

Hint:
* You can pick some images and record their probability distributions over the 7 classes.

## Example

**[Note] Please do not submit the TA's figure as your own homework.**

![r3](http://i.imgur.com/nTWMqGn.png)

## TA hour

Suppose you have already trained a reasonably good model and now run it on the validation data.

<i class="fa fa-diamond"></i> Keywords: `sklearn.metrics.confusion_matrix`, `keras.models.load_model`, `predict_classes` (available only in older Keras releases)

<div class="highlight"><pre>#!/usr/bin/env python
# -*- coding: utf-8 -*-
import itertools
import os
import matplotlib.pyplot as plt
import numpy as np
from keras.models import load_model
from sklearn.metrics import confusion_matrix
from marcos import exp_dir  # local module from the TA's project, not a library
# ---- plotting helper ----
def plot_confusion_matrix(cm, classes,
                          title='Confusion matrix',
                          cmap=plt.cm.jet):
    """
    Plot the row-normalized confusion matrix (each row sums to 1).
    """
    # Divide every row by the number of true samples of that class.
    cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)
    # Write each rate into its cell; switch text color for readability.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, '{:.2f}'.format(cm[i, j]),
                 horizontalalignment="center",
                 color="white" if cm[i, j] &gt; thresh else "black")
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
# ---- entry point ----
def main():
    # store_path, read_dataset and get_labels are not defined in this snippet;
    # they are expected to come from your own project code.
    model_path = os.path.join(exp_dir, store_path, 'model.h5')
    emotion_classifier = load_model(model_path)
    np.set_printoptions(precision=2)
    dev_feats = read_dataset('valid')
    # predict_classes() exists only in older Keras releases. For a softmax
    # output, the argmax of predict() yields the same class indices.
    predictions = np.argmax(emotion_classifier.predict(dev_feats), axis=-1)
    te_labels = get_labels('valid')
    conf_mat = confusion_matrix(te_labels, predictions)
    plt.figure()
    plot_confusion_matrix(conf_mat, classes=["Angry", "Disgust", "Fear", "Happy", "Sad", "Surprise", "Neutral"])
    plt.show()
if __name__ == '__main__':
    main()
</pre></div>
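
The hint about recording probability distributions can be sketched as follows. This is a minimal example under stated assumptions: `probs` stands in for the output of `emotion_classifier.predict(dev_feats)` and is filled with made-up numbers here, and the helper name `describe_misclassified` is hypothetical.

```python
import numpy as np

CLASSES = ["Angry", "Disgust", "Fear", "Happy", "Sad", "Surprise", "Neutral"]


def describe_misclassified(probs, labels, classes=CLASSES):
    """Return (index, true class, predicted class, distribution) per misclassified sample."""
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels)
    preds = probs.argmax(axis=1)             # most probable class per image
    wrong = np.flatnonzero(preds != labels)  # indices where the model is wrong
    return [(int(i), classes[labels[i]], classes[preds[i]], probs[i]) for i in wrong]


# Made-up probabilities for three validation images (each row sums to 1).
probs = np.array([
    [0.05, 0.02, 0.08, 0.70, 0.05, 0.05, 0.05],  # confident "Happy"
    [0.10, 0.02, 0.38, 0.03, 0.35, 0.02, 0.10],  # "Fear" narrowly beats "Sad"
    [0.32, 0.05, 0.10, 0.05, 0.15, 0.05, 0.28],  # "Angry" narrowly beats "Neutral"
])
labels = np.array([3, 4, 6])  # assumed true labels: Happy, Sad, Neutral

for idx, true_name, pred_name, dist in describe_misclassified(probs, labels):
    print("image %d: true=%s, predicted=%s" % (idx, true_name, pred_name))
    for name, p in zip(CLASSES, dist):
        print("  %-8s %.2f" % (name, p))
```

Looking at the full distribution, rather than only the predicted class, shows whether a mistake was a near miss between two similar emotions or a confident error.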
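
For the "describe what you observed" part, it may help to read the matrix numerically as well as visually. The sketch below uses only NumPy to compute per-class recall (the diagonal of the row-normalized matrix) and to list the largest off-diagonal entries. The counts in `conf_mat` are invented for a 3-class toy case, and the function name `top_confusions` is hypothetical.

```python
import numpy as np


def top_confusions(conf_mat, classes, k=3):
    """Return per-class recall and the k largest off-diagonal (true, predicted, rate) entries."""
    cm = np.asarray(conf_mat, dtype=float)
    row_sums = cm.sum(axis=1, keepdims=True)
    # Row-normalize; rows without samples stay at zero instead of dividing by zero.
    norm = np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums > 0)
    recall = dict(zip(classes, np.diag(norm)))
    off = norm.copy()
    np.fill_diagonal(off, -1.0)  # exclude correct predictions from the ranking
    order = np.argsort(off, axis=None)[::-1][:k]
    rows, cols = np.unravel_index(order, off.shape)
    pairs = [(classes[i], classes[j], off[i, j]) for i, j in zip(rows, cols)]
    return recall, pairs


# Invented counts: rows are true labels, columns are predicted labels.
classes = ["Angry", "Fear", "Sad"]
conf_mat = [[80, 5, 15],
            [10, 60, 30],
            [5, 20, 75]]

recall, pairs = top_confusions(conf_mat, classes, k=2)
for name, r in recall.items():
    print("%-6s recall %.2f" % (name, r))
for true_name, pred_name, rate in pairs:
    print("%s -> %s: %.2f" % (true_name, pred_name, rate))
```

With real validation counts, the top pairs point to the class confusions worth discussing in the report.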