티스토리 뷰

TensorFlow.js는 브라우저 및 Node.js에서 머신러닝 모델을 개발하고 실행할 수 있는 JavaScript 라이브러리입니다. TFJS의 주요 특징과 간단한 예제( "손글씨 숫자 인식" )를 통해 알아보겠습니다.

 

1. TFJS의 주요 특징

  • 브라우저 기반 실행: 클라이언트 사이드에서 직접 모델을 학습하고 예측할 수 있습니다.
  • 모델 변환: 기존 TensorFlow 모델을 TensorFlow.js 형식으로 변환하여 웹에서 실행 가능.
  • 빠른 프로토타이핑: 웹 애플리케이션에서 손쉽게 머신러닝 기능을 구현할 수 있습니다.

 

2. 손글씨 숫자 인식

https://github.com/google/tfjs-mnist-workshop/tree/master/model
  • 위 경로에서 "group1-shard1of1" 과 "model.json" 파일을 다운로드합니다.

 

2.1. index.html 작성

<!DOCTYPE html>
<html lang="ko">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>손글씨 숫자 인식</title>
    <style>
        body {
            font-family: 'Arial', sans-serif;
            display: flex;
            flex-direction: column;
            align-items: center;
            justify-content: center;
            margin: 0;
            padding: 0;
            background-color: #f5f5f5;
        }
        h1 {
            font-size: 30px;
            color: #333;
            margin-bottom: 20px;
        }
        .canvas-container {
            display: flex;
            justify-content: center;
            align-items: center;
            gap: 20px;
            margin-bottom: 20px;
        }
        #canvas {
            border: 2px solid #333;
            background-color: #fff;
            cursor: crosshair;
        }
        #outputCanvas {
			border: 2px solid #ccc;
			margin-left: 20px;
			box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.1);
		}
        button {
            padding: 10px 20px;
            background-color: #4CAF50;
            color: white;
            border: none;
            cursor: pointer;
            border-radius: 5px;
            margin: 10px 0;
            font-size: 16px;
        }
        button:hover {
            background-color: #45a049;
        }
        #prediction {
            font-size: 20px;
            font-weight: bold;
            margin-top: 20px;
        }
        .flex-container {
            display: flex;
            justify-content: center;
            align-items: center;
            gap: 15px;
        }
    </style>
</head>
<body>
    <h1>손글씨 숫자 인식</h1>

    <!-- 캔버스와 전처리 후 캔버스를 나란히 배치 -->
    <div class="canvas-container">
        <!-- 입력 캔버스 -->
        <canvas id="canvas" width="280" height="280"></canvas>

        <!-- 전처리된 이미지 출력 캔버스 -->
        <canvas id="outputCanvas" width="28" height="28"></canvas>
    </div>

    <!-- 버튼들 -->
    <div class="flex-container">
        <button id="clearButton">캔버스 지우기</button>
        <button id="predictButton">예측하기</button>
    </div>

    <div id="prediction">예측 결과: </div>
    
    <!-- OpenCV.js 로드 -->
    <script async src="https://docs.opencv.org/4.x/opencv.js"></script>
    
    <!-- TensorFlow.js -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>

    <script src = "script.js"></script>
</body>
</html>
  • html 파일로 화면을 구성합니다.

 

2.2. script.js 작성

let model;
let canvas = document.getElementById('canvas');
let ctx = canvas.getContext('2d');
let isDrawing = false;

// 캔버스 초기화
function clearCanvas() {
	ctx.fillStyle = 'white';
	ctx.fillRect(0, 0, canvas.width, canvas.height);
	document.getElementById('prediction').innerText = "Prediction: ";
}

// 드로잉 이벤트 설정
canvas.addEventListener('mousedown', () => isDrawing = true);
canvas.addEventListener('mouseup', () => {
    isDrawing = false;
    ctx.beginPath();
});
canvas.addEventListener('mousemove', draw);

function draw(event) {
	if (!isDrawing) return;
	ctx.lineWidth = 20;
	ctx.lineCap = 'round';
	ctx.strokeStyle = 'black';
	ctx.lineTo(event.offsetX, event.offsetY);
	ctx.stroke();
	ctx.beginPath();
	ctx.moveTo(event.offsetX, event.offsetY);
}

// OpenCV.js가 로드되고 준비가 되면 실행
function onOpenCVReady() {
	console.log('OpenCV.js is ready');
	// 여기서부터 OpenCV를 사용할 수 있습니다.
}

// 모델 로딩
async function loadModel() {
	model = await tf.loadLayersModel('http://localhost:1234/model.json');
	console.log("Model Loaded Successfully");
}

// 예측 기능
async function predict() {
    if (!model) await loadModel();

    // TensorFlow.js에서 CPU로 설정
    tf.setBackend('cpu');  // WebGL을 비활성화하고 CPU로 실행

    // 캔버스 이미지 가져오기
    let imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);

    // OpenCV.js로 이진화 처리
    let src = cv.matFromImageData(imageData);
    let gray = new cv.Mat();
    let thresholded = new cv.Mat();

    // 그레이스케일로 변환
    cv.cvtColor(src, gray, cv.COLOR_RGBA2GRAY);
    // 이진화 처리 (배경을 흰색으로 처리)
    cv.threshold(gray, thresholded, 127, 255, cv.THRESH_BINARY_INV);

	// 잡음 제거
	//cv.medianBlur(thresholded, thresholded, 5);  // 5는 커널 크기

	// 엣지 검출 (선택적으로 사용)
	//cv.Canny(thresholded, thresholded, 50, 150);

    // 크기 조정 (28x28)
    let resized = new cv.Mat();
    cv.resize(thresholded, resized, new cv.Size(28, 28), 0, 0, cv.INTER_LINEAR);

    // 그레이스케일 이미지를 RGBA 이미지로 변환
    let rgba = new cv.Mat();
    cv.cvtColor(resized, rgba, cv.COLOR_GRAY2RGBA);

    // TensorFlow.js에서 사용할 수 있도록 ImageData로 변환
    let imageDataResized = new ImageData(new Uint8ClampedArray(rgba.data), rgba.cols, rgba.rows);

    // TensorFlow.js에 입력할 형태로 변환
    let tensor = tf.browser.fromPixels(imageDataResized, 1)  // 그레이스케일 이미지
        .toFloat()
        .expandDims(0)  // 배치 차원 추가 (배치 크기 1)
        .expandDims(-1) // 채널 차원 추가 (배치, 높이, 너비, 채널)
        .div(tf.scalar(255.0))  // 정규화: 0~255 -> 0~1로 변환
		.reshape([1, 784]);

    // 예측 수행
    const prediction = model.predict(tensor);
    const predictedValue = prediction.argMax(1).dataSync()[0];

    // 예측 결과 출력
	console.log("result : "+predictedValue)
    document.getElementById('prediction').innerText = `Prediction: ${predictedValue}`;

    // 전처리된 이미지를 캔버스에 출력
    let outputCanvas = document.getElementById('outputCanvas');
    let outputCtx = outputCanvas.getContext('2d');
    outputCanvas.width = resized.cols;
    outputCanvas.height = resized.rows;

    // 전처리된 이미지 그리기
    let outputImageData = new ImageData(new Uint8ClampedArray(rgba.data), rgba.cols, rgba.rows);
    outputCtx.putImageData(outputImageData, 0, 0);

    // 메모리 해제
    prediction.dispose();
    tensor.dispose();
    src.delete();
    gray.delete();
    thresholded.delete();
    resized.delete();
    rgba.delete();
}

// 버튼 이벤트 설정
document.getElementById('clearButton').addEventListener('click', clearCanvas);
document.getElementById('predictButton').addEventListener('click', predict);

// 페이지 로드 시 모델 로드
loadModel();

// OpenCV.js 로딩 후 초기화 호출
if (typeof cv !== 'undefined') {
	console.log('cv is loaded.');
	onOpenCVReady();
} else {
	document.addEventListener('opencvjsloaded', onOpenCVReady);
}

clearCanvas();
  • 캔버스 드로잉: 사용자가 손으로 그린 이미지를 캔버스에 그리며, 마우스 이벤트를 통해 드로잉을 관리합니다.
  • 이미지 전처리: OpenCV.js를 사용해 그려진 이미지를 그레이스케일로 변환하고, 이진화 및 크기 조정을 통해 28x28 크기의 이미지로 변환합니다.
  • 모델 예측: TensorFlow.js로 로드된 모델을 사용하여 이미지의 예측 결과를 반환합니다.
  • 결과 출력: 예측 결과를 화면에 표시하고, 전처리된 이미지를 출력 캔버스에 그립니다.

 

2.3. 실행

python -m http.server 1234
  • 저는 python을 이용하여 웹 서비스를 실행시켰습니다. 각자 상황에 맞게 진행해 주세요.

  • 학습데이터가 충분하지 않아 정확도가 다소 떨어집니다. 이번 포스팅에서는 사용방법에 포커스를 맞춰 진행하였으니 참고 부탁드립니다.

 

감사합니다.

최근에 올라온 글
Total
Today
Yesterday