Chapter 30: TensorFlow.js Example 2
What is “TensorFlow.js Example 2”?
Example 2 trains a convolutional neural network (CNN) on the MNIST handwritten-digit dataset to classify images of the digits 0–9.
- Dataset: 60,000 training images + 10,000 test images (each 28×28 grayscale pixels)
- Goal: Feed a new 28×28 image → model outputs “this is a 5” (or 3, 8, 9, etc.)
- Model type: simple CNN (Conv2D + MaxPooling + Dense layers)
- Output: 10 probabilities (one for each digit) → highest one wins
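To make "10 probabilities → highest one wins" concrete, here is a minimal sketch in plain JavaScript (no TensorFlow.js needed). The scores in `exampleLogits` are made up for illustration; in the real model they come out of the last dense layer.

```javascript
// Softmax: turn raw scores (logits) into probabilities that sum to 1.
function softmax(logits) {
  const max = Math.max(...logits); // subtract the max for numerical stability
  const exps = logits.map((z) => Math.exp(z - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}

// Argmax: index of the largest value, i.e. the predicted digit.
function argMax(values) {
  return values.reduce((best, v, i) => (v > values[best] ? i : best), 0);
}

// Hypothetical scores for the digits 0..9 (illustrative values only).
const exampleLogits = [0.1, 0.3, 0.2, 1.5, 0.0, 4.2, 0.4, 0.1, 0.9, 0.6];
const probs = softmax(exampleLogits);

console.log(probs.map((p) => p.toFixed(3)).join(' '));
console.log('Predicted digit:', argMax(probs)); // index 5 has the highest score
```

The model's final `softmax` layer plus `argMax(-1)` in the prediction code do the same two steps on tensors.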
This is the first time you work with:
- Image data (raw pixels → tensors)
- Convolutional layers (learn edges, shapes, patterns)
- Multi-class classification (10 possible answers)
- Softmax + categorical crossentropy
- High accuracy (97–99%) in the browser — feels like real AI
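What "convolutional layers learn edges" means can be sketched with one hand-written 3×3 kernel sliding over a tiny image. This is an illustration only: the toy image and the kernel values are assumptions chosen for clarity, while a real Conv2D layer learns its kernel values during training and also adds a bias and an activation.

```javascript
// "Valid" (no padding) 2D convolution of one grayscale image with one kernel.
function conv2dValid(image, kernel) {
  const k = kernel.length;
  const outH = image.length - k + 1;
  const outW = image[0].length - k + 1;
  const out = [];
  for (let y = 0; y < outH; y++) {
    const row = [];
    for (let x = 0; x < outW; x++) {
      let sum = 0;
      for (let i = 0; i < k; i++) {
        for (let j = 0; j < k; j++) {
          sum += image[y + i][x + j] * kernel[i][j];
        }
      }
      row.push(sum);
    }
    out.push(row);
  }
  return out;
}

// 5x5 toy image: dark on the left, bright on the right (values in 0..1).
const toyImage = [
  [0, 0, 1, 1, 1],
  [0, 0, 1, 1, 1],
  [0, 0, 1, 1, 1],
  [0, 0, 1, 1, 1],
  [0, 0, 1, 1, 1],
];

// Hand-written vertical-edge detector (hypothetical, not a learned kernel).
const verticalEdge = [
  [-1, 0, 1],
  [-1, 0, 1],
  [-1, 0, 1],
];

const featureMap = conv2dValid(toyImage, verticalEdge);
console.log(featureMap); // 3x3 map: large values where dark meets bright
```

The output is 3×3 because a 3×3 kernel fits into a 5×5 image in only 3 positions per axis. The same rule turns the 28×28 MNIST input into a 26×26 feature map after the first conv layer.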
Why Example 2 Comes Right After Example 1
- Example 1 → linear regression (1 input, 1 output, 2 parameters)
- Example 2 → image classification (28×28×1 input, 10 outputs, roughly 225,000 parameters in the CNN used here)
- Same core pattern: data → model → compile → fit → predict
- But now you see deep learning power — convolutions, pooling, dropout
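To see how much bigger this model is than Example 1, the sketch below walks through the layer stack used in this chapter (two 3×3 conv layers with 32 and 64 filters, 2×2 max pooling after each, then dense layers of 128 and 10 units) and counts weights and biases. It assumes "valid" padding and pooling stride equal to the pool size, which are the TensorFlow.js defaults; the helper names are made up for this example.

```javascript
// Track the feature-map shape and the parameter count layer by layer.
let side = 28;      // height and width of the current feature map
let channels = 1;   // grayscale input
let totalParams = 0;

function conv(filters, kernel) {
  const params = filters * kernel * kernel * channels + filters; // weights + biases
  side = side - kernel + 1; // "valid" padding shrinks the map
  channels = filters;
  totalParams += params;
  console.log(`conv2d   -> ${side}x${side}x${channels}, params: ${params}`);
}

function pool(size) {
  side = Math.floor(side / size); // no parameters to learn
  console.log(`maxPool  -> ${side}x${side}x${channels}, params: 0`);
}

let units = 0; // width of the flattened vector / previous dense layer

function flatten() {
  units = side * side * channels;
  console.log(`flatten  -> ${units}`);
}

function dense(outUnits) {
  const params = units * outUnits + outUnits; // weights + biases
  units = outUnits;
  totalParams += params;
  console.log(`dense    -> ${units}, params: ${params}`);
}

conv(32, 3);
pool(2);
conv(64, 3);
pool(2);
flatten();
dense(128);
dense(10); // dropout sits between the dense layers and adds no parameters

console.log('Total trainable parameters:', totalParams);
```

Most of the parameters sit in the first dense layer (1,600 inputs × 128 units), not in the convolutions.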
Full Runnable Code for Example 2 (Browser – No Server Needed)
Save this as tfjs-example2.html and double-click to open in Chrome/Edge/Firefox.
```html
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8" />
  <title>TensorFlow.js Example 2 – MNIST Digit Recognition</title>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest/dist/tf.min.js"></script>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis@latest"></script>
  <style>
    body { font-family: Arial, sans-serif; text-align: center; padding: 40px; background: #f8f9fa; }
    h1 { color: #4285f4; }
    button { padding: 14px 32px; font-size: 1.3em; background: #4285f4; color: white; border: none; border-radius: 8px; cursor: pointer; margin: 20px; }
    button:hover { background: #3367d6; }
    #status { font-size: 1.5em; margin: 30px; min-height: 60px; color: #444; }
    #log { background: #1e1e1e; color: #d4d4d4; padding: 20px; border-radius: 10px; max-width: 900px; margin: 20px auto; text-align: left; font-family: monospace; max-height: 400px; overflow-y: auto; white-space: pre-wrap; }
    canvas { border: 3px solid #4285f4; border-radius: 10px; margin: 30px; background: black; image-rendering: pixelated; }
    #prediction { font-size: 3em; margin: 30px; color: #4285f4; font-weight: bold; }
  </style>
</head>
<body>
  <h1>TensorFlow.js Example 2: Recognize Handwritten Digits (MNIST)</h1>
  <p style="max-width:800px; margin:20px auto; font-size:1.1em;">
    This is the classic second example in TensorFlow.js tutorials.<br>
    Train a small CNN → it learns to read handwritten digits 0–9.<br>
    After training, draw your own digit on the canvas → click Predict!
  </p>
  <button onclick="trainModel()">Start Training (1–3 minutes)</button>
  <div id="status">Waiting for you to start...</div>
  <div id="log">Training log will appear here...</div>
  <h3>Draw a digit (0–9) here</h3>
  <canvas id="canvas" width="280" height="280"></canvas><br>
  <button onclick="clearCanvas()">Clear Canvas</button>
  <button onclick="predictDigit()" disabled>Predict My Drawing</button>
  <div id="prediction"></div>

  <script>
    let model;
    let isDrawing = false;
    const canvas = document.getElementById('canvas');
    const ctx = canvas.getContext('2d');
    ctx.lineWidth = 24;
    ctx.lineCap = 'round';
    ctx.strokeStyle = 'white';

    canvas.addEventListener('mousedown', () => isDrawing = true);
    canvas.addEventListener('mouseup', () => { isDrawing = false; ctx.beginPath(); });
    canvas.addEventListener('mousemove', draw);
    canvas.addEventListener('mouseleave', () => { isDrawing = false; ctx.beginPath(); });

    function draw(e) {
      if (!isDrawing) return;
      // offsetX/offsetY are relative to the canvas itself, so the stroke
      // stays under the cursor even when the page is scrolled
      ctx.lineTo(e.offsetX, e.offsetY);
      ctx.stroke();
      ctx.beginPath();
      ctx.moveTo(e.offsetX, e.offsetY);
    }

    function clearCanvas() {
      ctx.fillStyle = 'black';
      ctx.fillRect(0, 0, canvas.width, canvas.height);
    }
    clearCanvas();

    function log(msg) {
      console.log(msg);
      const logEl = document.getElementById('log');
      logEl.textContent += '\n' + msg;
      logEl.scrollTop = logEl.scrollHeight;
    }

    // TensorFlow.js has no built-in tf.data.mnist() loader, so load the data by hand.
    // Assumption: the sprite image and label file used by the official TensorFlow.js
    // MNIST tutorial are still hosted at these URLs. The sprite packs 65,000 images
    // (one flattened 28×28 image per row); labels are stored one-hot, 10 bytes each.
    // If the browser blocks these downloads on a file:// page, serve this file
    // from a local web server instead.
    const MNIST_IMAGES_URL = 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png';
    const MNIST_LABELS_URL = 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8';
    const IMAGE_SIZE = 28 * 28;
    const NUM_CLASSES = 10;
    const NUM_IMAGES = 65000;
    const NUM_TRAIN = 55000; // the remaining 10,000 images are held out for testing

    async function loadMnist() {
      const sprite = new Image();
      sprite.crossOrigin = 'anonymous';
      const spriteLoaded = new Promise((resolve, reject) => {
        sprite.onload = resolve;
        sprite.onerror = () => reject(new Error('Could not load the MNIST image sprite'));
      });
      sprite.src = MNIST_IMAGES_URL;
      const [, labelsResponse] = await Promise.all([spriteLoaded, fetch(MNIST_LABELS_URL)]);
      const labels = new Uint8Array(await labelsResponse.arrayBuffer());

      // Read the sprite in chunks of 5,000 rows to stay below canvas size limits
      const pixels = new Float32Array(NUM_IMAGES * IMAGE_SIZE);
      const chunkRows = 5000;
      const scratch = document.createElement('canvas');
      scratch.width = IMAGE_SIZE;
      scratch.height = chunkRows;
      const scratchCtx = scratch.getContext('2d');
      for (let row = 0; row < NUM_IMAGES; row += chunkRows) {
        scratchCtx.drawImage(sprite, 0, row, IMAGE_SIZE, chunkRows, 0, 0, IMAGE_SIZE, chunkRows);
        const rgba = scratchCtx.getImageData(0, 0, IMAGE_SIZE, chunkRows).data;
        for (let j = 0; j < rgba.length / 4; j++) {
          pixels[row * IMAGE_SIZE + j] = rgba[j * 4] / 255; // red channel, 0–255 → 0–1
        }
      }

      // Wrap a slice of the arrays as a tf.data dataset of {xs, ys} examples
      const makeDataset = (start, end) => tf.data.generator(function* () {
        for (let i = start; i < end; i++) {
          yield {
            xs: tf.tensor3d(pixels.subarray(i * IMAGE_SIZE, (i + 1) * IMAGE_SIZE), [28, 28, 1]),
            ys: tf.tensor1d(Float32Array.from(labels.subarray(i * NUM_CLASSES, (i + 1) * NUM_CLASSES)))
          };
        }
      });
      return { train: makeDataset(0, NUM_TRAIN), test: makeDataset(NUM_TRAIN, NUM_IMAGES) };
    }

    async function trainModel() {
      const status = document.getElementById('status');
      status.innerHTML = 'Loading MNIST dataset...';
      const {train, test} = await loadMnist();

      // loadMnist() already scales pixels to 0–1 and shapes each image as [28, 28, 1]
      const trainData = train.shuffle(1000).batch(32);
      const testData = test.batch(32);

      status.innerHTML = 'Building CNN model...';

      // The Example 2 Model – small but powerful CNN
      model = tf.sequential();

      // Conv layer 1: learn basic edges/patterns
      model.add(tf.layers.conv2d({
        inputShape: [28, 28, 1],
        filters: 32,
        kernelSize: 3,
        activation: 'relu'
      }));
      model.add(tf.layers.maxPooling2d({poolSize: [2, 2]}));

      // Conv layer 2: learn more complex shapes
      model.add(tf.layers.conv2d({
        filters: 64,
        kernelSize: 3,
        activation: 'relu'
      }));
      model.add(tf.layers.maxPooling2d({poolSize: [2, 2]}));
      model.add(tf.layers.flatten());

      // Dense layers – combine features & decide digit
      model.add(tf.layers.dense({units: 128, activation: 'relu'}));
      model.add(tf.layers.dropout({rate: 0.2})); // prevent overfitting
      model.add(tf.layers.dense({units: 10, activation: 'softmax'})); // 10 classes

      model.compile({
        optimizer: 'adam',
        loss: 'categoricalCrossentropy', // expects one-hot labels
        metrics: ['accuracy']
      });

      log('Model summary: see the browser console');
      model.summary(); // prints the layer table to the console

      status.innerHTML = 'Training CNN... (watch the Visor panel + log)';

      // Live visualization with tfjs-vis
      const surface = tfvis.visor().surface({tab: 'Training Progress', name: 'Metrics'});
      await model.fitDataset(trainData, {
        epochs: 6,
        validationData: testData,
        callbacks: tfvis.show.fitCallbacks(
          surface,
          ['loss', 'val_loss', 'acc', 'val_acc'],
          {callbacks: ['onEpochEnd']}
        )
      });

      status.innerHTML = 'Training finished! Accuracy should be 97–99%';
      log('Training complete – open the Visor panel to see the loss & accuracy curves');

      // Enable prediction
      document.querySelector('button[onclick="predictDigit()"]').disabled = false;
    }

    async function predictDigit() {
      if (!model) {
        alert('Train the model first!');
        return;
      }
      // tf.tidy disposes every intermediate tensor created inside the callback
      const best = tf.tidy(() => {
        const input = tf.browser.fromPixels(canvas, 1) // [280, 280, 1], values 0–255
          .resizeBilinear([28, 28])                    // [28, 28, 1]
          .cast('float32')
          .div(255.0)                                  // normalize to 0–1
          .expandDims(0);                              // add batch dimension → [1, 28, 28, 1]
        return model.predict(input).argMax(-1);
      });
      const digit = (await best.data())[0];
      best.dispose();
      document.getElementById('prediction').innerHTML = `I think this is: <b>${digit}</b>`;
    }
  </script>
</body>
</html>
```
What Happens When You Run This
- Click Start Training → the MNIST data loads (about 10–30 sec)
- Model builds → the tfjs-vis Visor opens automatically (a panel at the side of the page)
- Training runs 6 epochs → you see:
  - Loss dropping (from ~2.3 → ~0.05–0.1)
  - Accuracy rising (from ~10% → 97–99%)
  - Validation curves (shows it generalizes)
- After ~1–3 min → accuracy usually 97.5–99%
- Draw any digit 0–9 on canvas (mouse) → click Predict My Drawing → model tells you the digit
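The starting loss of about 2.3 is not arbitrary. An untrained model spreads its guess roughly evenly over the 10 digits, and categorical cross-entropy for a uniform guess is ln(10) ≈ 2.3026. The sketch below computes the loss for a single example in plain JavaScript; the "confident" probabilities are made-up values for illustration.

```javascript
// Categorical cross-entropy for one example: -sum(yTrue[i] * ln(yPred[i])).
function categoricalCrossentropy(oneHot, probs) {
  const eps = 1e-7; // guard against ln(0)
  return -oneHot.reduce(
    (sum, y, i) => sum + y * Math.log(Math.max(probs[i], eps)),
    0
  );
}

// One-hot label for the digit 5.
const label5 = [0, 0, 0, 0, 0, 1, 0, 0, 0, 0];

// Untrained model: every digit gets probability 0.1.
const uniformGuess = new Array(10).fill(0.1);

// Hypothetical well-trained model: 95% on the correct digit, rest shared evenly.
const confidentGuess = new Array(10).fill(0.05 / 9);
confidentGuess[5] = 0.95;

const lossUntrained = categoricalCrossentropy(label5, uniformGuess);
const lossTrained = categoricalCrossentropy(label5, confidentGuess);

console.log(lossUntrained.toFixed(4)); // 2.3026
console.log(lossTrained.toFixed(4));   // 0.0513
```

The same reasoning explains the ~10% starting accuracy: a uniform guess over 10 classes is right about one time in ten.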
Common Results You Should See
- Final test accuracy: 97.5–99.2% (very good for this small CNN)
- Visor graphs: smooth loss decrease, accuracy plateauing near 98–99%
- Your own drawing: if you draw clearly → correct most times
Why This is Called “Example 2”
- Example 1 → linear regression (1 neuron, regression)
- Example 2 → convolutional neural network (multiple layers, classification)
- Same training pattern (fit / fitDataset) → but now on real image data
- Introduces Conv2D, MaxPooling2D, softmax, categoricalCrossentropy — building blocks of modern vision AI
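Of the building blocks listed above, MaxPooling2D is the simplest to write out by hand: it keeps the largest value in each 2×2 window and halves the width and height. A minimal sketch in plain JavaScript, with a made-up 4×4 feature map:

```javascript
// 2x2 max pooling with stride 2 on a single-channel feature map.
function maxPool2x2(map) {
  const out = [];
  for (let y = 0; y + 1 < map.length; y += 2) {
    const row = [];
    for (let x = 0; x + 1 < map[0].length; x += 2) {
      row.push(Math.max(map[y][x], map[y][x + 1], map[y + 1][x], map[y + 1][x + 1]));
    }
    out.push(row);
  }
  return out;
}

// Hypothetical 4x4 feature map (illustrative values only).
const featureMap4x4 = [
  [1, 3, 2, 0],
  [4, 2, 1, 1],
  [0, 0, 5, 6],
  [1, 2, 7, 8],
];

const pooled = maxPool2x2(featureMap4x4);
console.log(pooled); // [[4, 2], [2, 8]]
```

When the input side is odd, the last row and column are dropped, which is why an 11×11 map becomes 5×5 after pooling.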
In Hyderabad 2026, this Example 2 is still the second thing every student, developer, and job interviewee runs — because once you see handwritten digits recognized in your browser with 98%+ accuracy, you know you can build real apps (photo tagging, digit recognition in banking apps, medical image tools, etc.).
Understood the jump from Example 1 to Example 2? 🌟
Want next?
- Improve this model to 99.2%+ accuracy?
- Add confusion matrix in Visor after training?
- Save/load this model & use it in another project?
Just tell me — next class is ready! 🚀
