Chapter 33: Example 2 Training
What is “Example 2 Training”?
Example 2 Training = the model.fitDataset() (or model.fit()) call that trains the convolutional neural network (CNN) on the full MNIST training set (60,000 handwritten-digit images).
In simple words:
- Before training → model has random weights → it basically guesses randomly (≈10% accuracy — like picking a digit out of 10)
- During training → the model sees all 60,000 images once per epoch, over several epochs (6 in this example)
- For each batch of images (32 at a time in this example):
  - Makes predictions (10 probabilities per image)
  - Compares predictions to true labels
  - Calculates how wrong it was (categorical crossentropy loss)
  - Computes gradients (backpropagation) → the Adam optimizer nudges all ~225,000 weights a little
- After each epoch → loss drops, accuracy climbs
- After 5–6 epochs → accuracy on the test set typically reaches 97–99%
After training → you can draw your own digit on the canvas → the model tells you what it thinks it is, and it usually gets it right!
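The "compute gradients → update weights" step can be seen in miniature with a single weight. This is a minimal sketch of the standard Adam update rule (running averages of the gradient and of the squared gradient, with bias correction) minimizing a toy loss (w − 3)², not TF.js's internal code. The name adamMinimize and the learning rate of 0.1 are illustrative assumptions chosen so the toy converges quickly; the CNN itself uses whatever defaults TF.js's 'adam' optimizer ships with.

```javascript
// Hypothetical one-parameter Adam sketch (standard Kingma & Ba update rule).
// Toy loss: L(w) = (w - 3)^2, so dL/dw = 2 * (w - 3) and the minimum is at w = 3.
function adamMinimize(gradFn, w0, { lr = 0.1, beta1 = 0.9, beta2 = 0.999, eps = 1e-8, steps = 500 } = {}) {
  let w = w0;
  let m = 0; // running average of gradients (momentum)
  let v = 0; // running average of squared gradients (per-weight step scale)
  for (let t = 1; t <= steps; t++) {
    const g = gradFn(w);                       // in a real network, backpropagation supplies this
    m = beta1 * m + (1 - beta1) * g;
    v = beta2 * v + (1 - beta2) * g * g;
    const mHat = m / (1 - beta1 ** t);         // bias correction: m and v start at 0
    const vHat = v / (1 - beta2 ** t);
    w -= lr * mHat / (Math.sqrt(vHat) + eps);  // small, normalized step downhill
  }
  return w;
}

const wFinal = adamMinimize((w) => 2 * (w - 3), 0);
console.log('w after 500 Adam steps:', wFinal); // should land very close to 3
```

In the real model, backpropagation produces one gradient per weight, and Adam applies this same update to each weight independently after every batch.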
Why This Training Feels So Impressive
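- 60,000 images × 6 epochs = 360,000 passes through the network, each costing millions of multiply-adds: well over a trillion arithmetic operations, all in your browser!
- Loss starts high (~2.3 on the first batches) → drops to ~0.05–0.1
- Accuracy starts near 10% on the first batches → climbs to 97–99%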
- Validation accuracy (on unseen test images) follows closely → shows real learning, not memorizing
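The starting loss of about 2.3 is not arbitrary. Categorical crossentropy is −ln of the probability the model assigns to the correct digit, so an untrained network that spreads probability evenly over 10 digits scores −ln(0.1) = ln 10 ≈ 2.3026. A minimal plain-JavaScript sketch: categoricalCrossentropy and oneHot are hypothetical helpers written for illustration (TF.js computes the same quantity internally, averaged over each batch), and the clipping value 1e-7 is an assumption.

```javascript
// Categorical crossentropy for one example: -sum_i yTrue[i] * ln(yPred[i]).
// Probabilities are clipped away from 0 so ln() never returns -Infinity.
function categoricalCrossentropy(yTrue, yPred, eps = 1e-7) {
  let loss = 0;
  for (let i = 0; i < yTrue.length; i++) {
    const p = Math.min(Math.max(yPred[i], eps), 1 - eps);
    loss -= yTrue[i] * Math.log(p);
  }
  return loss;
}

// One-hot label vector, e.g. 7 -> [0,0,0,0,0,0,0,1,0,0]
function oneHot(label, numClasses = 10) {
  const v = new Array(numClasses).fill(0);
  v[label] = 1;
  return v;
}

// Untrained network: roughly uniform guesses, 0.1 for every digit
const uniform = new Array(10).fill(0.1);
const lossUntrained = categoricalCrossentropy(oneHot(7), uniform);
console.log(lossUntrained.toFixed(4)); // 2.3026 (= ln 10)

// Trained network: 90% on the correct digit, the rest spread evenly
const confident = new Array(10).fill(0.1 / 9);
confident[7] = 0.9;
const lossTrained = categoricalCrossentropy(oneHot(7), confident);
console.log(lossTrained.toFixed(4)); // 0.1054
```

Read the other way, a final loss of 0.05 means the model gives the true digit about e^−0.05 ≈ 95% probability (geometric mean across images).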
Full Runnable Code – Example 2 Training (with Live Monitoring)
Save as tfjs-example2-training.html and open it in a browser (Chrome/Edge/Firefox). The page loads TF.js and tfjs-vis from a CDN, so it needs an internet connection.
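```html
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8" />
  <title>TensorFlow.js Example 2 Training – MNIST CNN</title>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest/dist/tf.min.js"></script>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis@latest"></script>
  <style>
    body { font-family: Arial, sans-serif; text-align: center; padding: 40px; background: #f8f9fa; }
    h1 { color: #4285f4; }
    button { padding: 14px 32px; font-size: 1.3em; background: #4285f4; color: white; border: none; border-radius: 8px; cursor: pointer; margin: 20px; }
    button:hover { background: #3367d6; }
    button:disabled { background: #9aa0a6; cursor: not-allowed; }
    #status { font-size: 1.5em; margin: 30px; min-height: 60px; color: #444; }
    #log { background: #1e1e1e; color: #d4d4d4; padding: 20px; border-radius: 10px; max-width: 900px; margin: 20px auto; text-align: left; font-family: monospace; max-height: 400px; overflow-y: auto; white-space: pre-wrap; }
    canvas { border: 3px solid #4285f4; border-radius: 10px; margin: 30px; background: black; }
    #prediction { font-size: 3em; margin: 30px; color: #4285f4; font-weight: bold; }
  </style>
</head>
<body>
  <h1>Example 2 Training: Watch a CNN Learn to Read Handwritten Digits</h1>
  <p style="max-width:800px; margin:20px auto; font-size:1.1em;">
    Click below → the model trains on 60,000 handwritten digits.<br>
    Watch loss drop &amp; accuracy climb live in the tfjs-vis Visor panel + log.<br>
    After training, draw any digit 0–9 → click Predict!
  </p>
  <button id="trainBtn" onclick="trainExample2()">Start Example 2 Training (1–3 min)</button>
  <div id="status">Waiting for you to start...</div>
  <div id="log">Training log will appear here...</div>

  <h3>Draw a digit (0–9) here after training</h3>
  <canvas id="canvas" width="280" height="280"></canvas><br>
  <button onclick="clearCanvas()">Clear</button>
  <button id="predictBtn" onclick="predictDigit()" disabled>Predict</button>
  <div id="prediction"></div>

  <script>
    let model;
    let isDrawing = false;
    const canvas = document.getElementById('canvas');
    const ctx = canvas.getContext('2d');
    ctx.lineWidth = 24;
    ctx.lineCap = 'round';
    ctx.lineJoin = 'round';
    ctx.strokeStyle = 'white';

    // offsetX/offsetY are measured inside the canvas, so strokes stay under the
    // cursor even when the page is scrolled (clientX - offsetLeft breaks then).
    canvas.addEventListener('mousedown', (e) => {
      isDrawing = true;
      ctx.beginPath();
      ctx.moveTo(e.offsetX, e.offsetY);
    });
    canvas.addEventListener('mousemove', (e) => {
      if (!isDrawing) return;
      ctx.lineTo(e.offsetX, e.offsetY);
      ctx.stroke();
    });
    canvas.addEventListener('mouseup', () => { isDrawing = false; });
    canvas.addEventListener('mouseleave', () => { isDrawing = false; });

    function clearCanvas() {
      ctx.fillStyle = 'black';
      ctx.fillRect(0, 0, canvas.width, canvas.height);
      document.getElementById('prediction').textContent = '';
    }
    clearCanvas();

    function log(msg) {
      console.log(msg);
      const el = document.getElementById('log');
      el.textContent += '\n' + msg;
      el.scrollTop = el.scrollHeight;
    }

    // ---- MNIST data -------------------------------------------------------
    // TF.js has no built-in MNIST loader (there is no tf.data.mnist()), so this
    // reads the hosted MNIST sprite used by the official TF.js MNIST codelab.
    // Assumption: these URLs stay online. The sprite stores 65,000 digits, one
    // 784-pixel (28x28) image per row; the labels file stores one one-hot
    // uint8 vector of length 10 per image.
    const IMAGES_URL = 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png';
    const LABELS_URL = 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8';
    const IMAGE_SIZE = 784;   // 28 * 28
    const NUM_CLASSES = 10;
    const NUM_IMAGES = 65000;
    const NUM_TRAIN = 60000;  // first 60,000 images → training, last 5,000 → test

    async function loadMnist() {
      const img = new Image();
      img.crossOrigin = '';   // required so getImageData() may read the sprite
      const imgLoaded = new Promise((resolve, reject) => {
        img.onload = resolve;
        img.onerror = () => reject(new Error('Could not load ' + IMAGES_URL));
      });
      img.src = IMAGES_URL;
      const labelsLoaded = fetch(LABELS_URL).then((r) => {
        if (!r.ok) throw new Error('Could not load ' + LABELS_URL);
        return r.arrayBuffer();
      });
      const [, labelsBuffer] = await Promise.all([imgLoaded, labelsLoaded]);

      // Copy the sprite 5,000 rows (= 5,000 images) at a time through a canvas,
      // keeping only the red channel because the sprite is grayscale.
      const images = new Uint8Array(NUM_IMAGES * IMAGE_SIZE);
      const chunk = 5000;
      const sprite = document.createElement('canvas');
      sprite.width = IMAGE_SIZE;
      sprite.height = chunk;
      const sctx = sprite.getContext('2d');
      for (let start = 0; start < NUM_IMAGES; start += chunk) {
        sctx.drawImage(img, 0, start, IMAGE_SIZE, chunk, 0, 0, IMAGE_SIZE, chunk);
        const rgba = sctx.getImageData(0, 0, IMAGE_SIZE, chunk).data;
        for (let j = 0; j < IMAGE_SIZE * chunk; j++) images[start * IMAGE_SIZE + j] = rgba[j * 4];
      }
      return { images, labels: new Uint8Array(labelsBuffer) };
    }

    // Dataset of ready-made batches {xs: [b, 28, 28, 1], ys: [b, 10]} with pixels
    // scaled from 0–255 to 0–1. tf.data.generator re-runs the generator every
    // epoch, so the training order is reshuffled once per epoch.
    function makeDataset(images, labels, first, count, batchSize, shuffle) {
      return tf.data.generator(function* () {
        const order = Array.from({ length: count }, (_, i) => first + i);
        if (shuffle) tf.util.shuffle(order);
        for (let b = 0; b < count; b += batchSize) {
          const ids = order.slice(b, b + batchSize);
          const xs = new Float32Array(ids.length * IMAGE_SIZE);
          const ys = new Float32Array(ids.length * NUM_CLASSES);
          ids.forEach((id, k) => {
            for (let p = 0; p < IMAGE_SIZE; p++) xs[k * IMAGE_SIZE + p] = images[id * IMAGE_SIZE + p] / 255;
            ys.set(labels.subarray(id * NUM_CLASSES, (id + 1) * NUM_CLASSES), k * NUM_CLASSES);
          });
          yield {
            xs: tf.tensor4d(xs, [ids.length, 28, 28, 1]),
            ys: tf.tensor2d(ys, [ids.length, NUM_CLASSES])
          };
        }
      });
    }

    // ---- Training ---------------------------------------------------------
    async function trainExample2() {
      const status = document.getElementById('status');
      const trainBtn = document.getElementById('trainBtn');
      trainBtn.disabled = true;
      try {
        await tf.ready();
        log('TF.js backend: ' + tf.getBackend()); // 'webgl' = GPU-accelerated
        status.textContent = 'Loading MNIST images...';
        const { images, labels } = await loadMnist();
        const trainData = makeDataset(images, labels, 0, NUM_TRAIN, 32, true);
        const testData = makeDataset(images, labels, NUM_TRAIN, NUM_IMAGES - NUM_TRAIN, 32, false);
        log(`Loaded ${NUM_TRAIN} training + ${NUM_IMAGES - NUM_TRAIN} test images`);

        status.textContent = 'Building Example 2 CNN model...';
        if (model) model.dispose(); // free the previous model if training is re-run
        model = tf.sequential();
        // Conv1: learns edges & lines
        model.add(tf.layers.conv2d({ inputShape: [28, 28, 1], filters: 32, kernelSize: 3, activation: 'relu' }));
        model.add(tf.layers.maxPooling2d({ poolSize: [2, 2] }));
        // Conv2: learns shapes & combinations of edges
        model.add(tf.layers.conv2d({ filters: 64, kernelSize: 3, activation: 'relu' }));
        model.add(tf.layers.maxPooling2d({ poolSize: [2, 2] }));
        model.add(tf.layers.flatten());
        model.add(tf.layers.dense({ units: 128, activation: 'relu' }));
        model.add(tf.layers.dropout({ rate: 0.2 }));
        model.add(tf.layers.dense({ units: 10, activation: 'softmax' }));
        model.compile({ optimizer: 'adam', loss: 'categoricalCrossentropy', metrics: ['accuracy'] });
        log('Example 2 model ready (small CNN):');
        model.summary(undefined, undefined, log); // print the layer table into the log

        status.textContent = 'Training... (watch the Visor panel + log below)';
        const surface = tfvis.visor().surface({ tab: 'Example 2 Training', name: 'Live Metrics' });
        await model.fitDataset(trainData, {
          epochs: 6,
          validationData: testData,
          callbacks: [
            // Live charts: per-batch loss/acc plus per-epoch curves including validation
            tfvis.show.fitCallbacks(surface, ['loss', 'val_loss', 'acc', 'val_acc']),
            // Text log: one line per epoch with every metric TF.js reports
            {
              onEpochEnd: (epoch, logs) => log(`Epoch ${epoch + 1}: ` +
                Object.entries(logs).map(([k, v]) => `${k}=${v.toFixed(4)}`).join('  '))
            }
          ]
        });
        status.textContent = 'Training finished! Test accuracy is usually 97–99%.';
        log('Training complete: open the Visor to see the loss & accuracy curves');
        document.getElementById('predictBtn').disabled = false;
      } catch (err) {
        status.textContent = 'Something went wrong: ' + err.message;
        log(String(err));
      } finally {
        trainBtn.disabled = false;
      }
    }

    // ---- Prediction -------------------------------------------------------
    async function predictDigit() {
      if (!model) return alert('Please train the model first!');
      // tf.tidy disposes every intermediate tensor created inside it
      const probs = tf.tidy(() => {
        const x = tf.browser.fromPixels(canvas, 1) // [280, 280, 1], white strokes = 255
          .resizeBilinear([28, 28])               // [28, 28, 1]
          .expandDims(0)                          // [1, 28, 28, 1]
          .div(255);                              // 0–1, same scaling as the training data
        return model.predict(x);                  // [1, 10] softmax probabilities
      });
      const p = await probs.data();
      probs.dispose();
      const digit = p.indexOf(Math.max(...p));
      document.getElementById('prediction').innerHTML =
        `Predicted digit: <b>${digit}</b> (${(p[digit] * 100).toFixed(1)}%)`;
    }
  </script>
</body>
</html>
```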
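To see how big the Example 2 network actually is, the sketch below walks through the layers built in trainExample2() and counts trainable parameters and forward-pass multiply-adds (MACs) per image, using the standard formulas for a stride-1, 'valid'-padding conv2d, a 2×2 max pool with stride 2, and a dense layer. describeModel and the plain-object layer list are hypothetical helpers, not TF.js APIs; the parameter total should agree with the "Total params" line that model.summary() prints.

```javascript
// Plain-object description mirroring the CNN built in the page (hypothetical format).
const exampleLayers = [
  { type: 'conv2d', filters: 32, kernel: 3 },
  { type: 'maxPooling2d', pool: 2 },
  { type: 'conv2d', filters: 64, kernel: 3 },
  { type: 'maxPooling2d', pool: 2 },
  { type: 'flatten' },
  { type: 'dense', units: 128 },
  { type: 'dropout' },
  { type: 'dense', units: 10 },
];

// Tracks output shape, trainable parameters and multiply-adds for one image.
// Assumes stride 1 + 'valid' padding for conv2d and stride = pool size for pooling,
// which are the TF.js defaults for the options the page passes.
function describeModel(layers, inputShape) {
  let shape = inputShape.slice();
  let totalParams = 0;
  let totalMacs = 0;
  const rows = [];
  for (const layer of layers) {
    let params = 0;
    let macs = 0;
    if (layer.type === 'conv2d') {
      const [h, w, c] = shape;
      const oh = h - layer.kernel + 1;
      const ow = w - layer.kernel + 1;
      params = layer.kernel * layer.kernel * c * layer.filters + layer.filters; // weights + biases
      macs = oh * ow * layer.filters * layer.kernel * layer.kernel * c;
      shape = [oh, ow, layer.filters];
    } else if (layer.type === 'maxPooling2d') {
      const [h, w, c] = shape;
      shape = [Math.floor(h / layer.pool), Math.floor(w / layer.pool), c];
    } else if (layer.type === 'flatten') {
      shape = [shape.reduce((a, b) => a * b, 1)];
    } else if (layer.type === 'dense') {
      params = shape[0] * layer.units + layer.units; // weights + biases
      macs = shape[0] * layer.units;
      shape = [layer.units];
    } // dropout: no parameters, shape unchanged
    totalParams += params;
    totalMacs += macs;
    rows.push(`${layer.type.padEnd(13)} ${shape.join('x').padEnd(10)} ${params}`);
  }
  return { rows, totalParams, totalMacs };
}

const report = describeModel(exampleLayers, [28, 28, 1]);
console.log(report.rows.join('\n'));
console.log('Total params:', report.totalParams);             // 225034
console.log('MACs per image (forward):', report.totalMacs);   // 2631040
```

Almost all of the ~225,000 weights sit in the first dense layer (1600 × 128), while most of the compute sits in the second conv layer. Multiplying 2,631,040 MACs by 60,000 images × 6 epochs gives roughly 9.5 × 10^11 multiply-adds for the forward passes alone, before backpropagation adds its share.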
What You Should See During Example 2 Training
- Click “Start Example 2 Training” → the MNIST data loads (typically 10–30 sec)
- The tfjs-vis Visor opens automatically as a panel on the right side of the page, under an “Example 2 Training” tab
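- Graphs update live: per-batch charts after every batch, per-epoch charts after every epoch
  - Loss: starts ~2.3 on the first batches → drops to ~0.05–0.1
  - Accuracy: starts ~10% on the first batches → climbs to 97–99%
  - Validation curves (per-epoch charts only) follow closely (good generalization)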
- Log shows progress
- After 6 epochs (~1–3 min on a decent laptop) → test accuracy usually 97.5–99.2%
- Draw any digit → “Predict” → model guesses correctly most times
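Behind the Predict button, the 280×280 drawing has to become the 28×28, 0–1 input the network was trained on. The page does this with resizeBilinear plus a divide by 255; the sketch below swaps in plain block averaging (every 10×10 patch of canvas pixels becomes one input pixel) because it makes the arithmetic easy to follow. blockDownsample is a hypothetical helper, not a TF.js function, and it works on a plain array of grayscale values rather than on a tensor.

```javascript
// Hypothetical helper: shrink a square grayscale image by averaging
// non-overlapping blocks, then rescale 0–255 to 0–1.
function blockDownsample(pixels, srcSize, dstSize) {
  const factor = srcSize / dstSize; // 280 / 28 = 10
  if (!Number.isInteger(factor)) throw new Error('srcSize must be a multiple of dstSize');
  const out = new Float32Array(dstSize * dstSize);
  for (let r = 0; r < dstSize; r++) {
    for (let c = 0; c < dstSize; c++) {
      let sum = 0;
      for (let dr = 0; dr < factor; dr++) {
        for (let dc = 0; dc < factor; dc++) {
          sum += pixels[(r * factor + dr) * srcSize + (c * factor + dc)];
        }
      }
      out[r * dstSize + c] = sum / (factor * factor) / 255; // block mean, scaled to 0–1
    }
  }
  return out;
}

// Example: a black 280x280 canvas with one white 10x10 square in the top-left corner
const drawing = new Uint8Array(280 * 280);
for (let y = 0; y < 10; y++) for (let x = 0; x < 10; x++) drawing[y * 280 + x] = 255;
const input = blockDownsample(drawing, 280, 28);
console.log(input[0], input[1]); // 1 0
```

One consequence worth knowing: bilinear downscaling by a factor of 10 interpolates from only a few neighboring canvas pixels per output pixel instead of averaging the whole patch, so very thin strokes can partly vanish. That is one reason the page draws with a thick 24 px line.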
Typical Training Progress (What Numbers to Expect)
- Epoch 1: loss ≈ 0.4–0.6, acc ≈ 85–90%
- Epoch 3: loss ≈ 0.1–0.15, acc ≈ 96–97%
- Epoch 6: loss ≈ 0.05–0.08, acc ≈ 98–99%
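Accuracy and loss measure different things, which is why these numbers move at different speeds. The 'acc' metric only checks whether the highest-probability digit is the right one. A sketch of that computation in plain JavaScript, assuming TF.js's 'accuracy' metric with one-hot labels behaves this way (categorical accuracy); argMax and accuracy are hypothetical helpers:

```javascript
// Index of the largest value (the predicted class).
function argMax(values) {
  let best = 0;
  for (let i = 1; i < values.length; i++) if (values[i] > values[best]) best = i;
  return best;
}

// Share of examples whose predicted class matches the one-hot label.
function accuracy(predictions, oneHotLabels) {
  let correct = 0;
  predictions.forEach((probs, i) => {
    if (argMax(probs) === argMax(oneHotLabels[i])) correct++;
  });
  return correct / predictions.length;
}

// Four toy 3-class predictions: three correct, one wrong
const preds = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.3, 0.3, 0.4], [0.6, 0.3, 0.1]];
const labels = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0]];
console.log(accuracy(preds, labels)); // 0.75
```

Between epoch 3 and epoch 6, accuracy gains only a couple of points while loss roughly halves: most images are already classified correctly, and the remaining loss drop comes from the model becoming more confident about them.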
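If accuracy stays low (or your drawings are misread) → common reasons:
- Too few epochs (try 10 instead of 6)
- Drawing problems, which affect only Predict, not training accuracy (draw slowly, with thick, centered strokes)
- Browser running on the CPU backend instead of WebGL: this makes training much slower rather than less accurate; check with tf.getBackend() (Chrome usually reports 'webgl')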
Final Teacher Summary
Example 2 Training = the model.fitDataset() step where the small CNN learns to classify 60,000 handwritten digits.
- Starts almost random (~10% accuracy)
- Sees each of the 60,000 images 6 times (once per epoch)
- Conv layers learn edges → shapes → digit patterns
- Loss drops fast at first, then slowly fine-tunes
- Ends with 97–99% accuracy — model can now read digits you draw!
This training step is where you first feel real deep learning power in the browser — from guessing randomly to almost perfect digit recognition in minutes.
