본문 바로가기

Dev/DL

Flask pytorch mnist 손글씨 인식 웹사이트 구동해보기

728x90

Flask와 pytorch로 웹페이지에서 손글씨를 받아 인식하고 label로 업데이터까지 해보자

 

mnist의 데이터셋이 28*28 인것 처럼 웹페이지에서 입력받을 손글씨의 픽셀을 28*28로 하자.

 

우선 28*28 사이즈의 table에 마우스 드래그로 숫자를 그릴수 있도록 html을 짜주자.

<!DOCTYPE html>
<html lang="en">

<head>
    <title>Number Recognition</title>
    
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <script src="https://code.jquery.com/jquery-3.6.4.min.js"></script>
    <style>
        table {
            border-collapse: collapse;
            width: 160 px;
        }

        table,
        td,
        th {
            border: 1px solid black;
        }

        td {
            padding: 5px
        }

        td.drawn {
            background-color: black;
        }
    </style>
</head>

<body>
    <h2>Number Recognition</h2>
    <h3>Draw one from 0 to 9</h3>
    <input type = "text" id = "text" >
    <button type="button" onclick="updatemodel()">Update</button>
    <button onclick="predictHandwriting()">Predict</button>
    <br><span id="output"></span><br>

    <script>
        const numRows = 28;
        const numCols = 28;
        let drawing = false;
        let binaryNumber = '';

        document.write('<table>');
        for (let i = 0; i < numRows; i++) {
            document.write('<tr>');
            for (let j = 0; j < numCols; j++) {
                document.write('<td onmousedown="startDrawing()" onmouseup="stopDrawing()" onmousemove="cellHovered(this)"></td>');
            }
            document.write('</tr>');
        }
        document.write('</table>');

        function startDrawing() {
            drawing = true;
        }

        function stopDrawing() {
            drawing = false;
            updateNumberInput();
        }

        function cellHovered(cell) {
            if (drawing) {
                cell.classList.add('drawn');
            }
        }

        function updateNumberInput() {
          const numRows = 28;
          const numCols = 28;
          const tableCells = document.querySelectorAll('td');
          binaryNumber = '';

          tableCells.forEach((cell, index) => {
              if (cell.classList.contains('drawn')) {
                  binaryNumber += '1';
              } else {
                  binaryNumber += '0';
              }

              // 현재 인덱스가 행의 끝에 도달하면 ',' 추가
              if ((index + 1) % numCols !== 0) {
                  binaryNumber += ',';
              } else {
                  binaryNumber += '\n';  // 행의 끝에 도달하면 줄 바꿈 추가
              }});}

        function predictHandwriting() {
          $.ajax({
              url: '/predict',
              type: 'POST',
              contentType: 'application/json',  // JSON 형식으로 데이터 보내기
              data: JSON.stringify({ inputdata: binaryNumber }),
              success: function(result) {
                  $('#output').text('Prediction Result: ' + result);
              }});}
        
        function updatemodel() {
          var label = document.getElementById("text").value;
          $.ajax({
              url: '/update',
              type: 'POST',
              contentType: 'application/json',
              data: JSON.stringify({ label: label, inputdata: binaryNumber }),
              success: function(result) {
                  console.log(result); 
              },
              error: function(error) {
                  console.error('Error:', error);
              }});}

    </script>
</body>

</html>

 

 

서버에 28*28 table을 검정이라면 0, 흰색이라면 1로 치환해, binaryNumber를 json 형식으로 보내자.

업데이트때도 필요한 label 데이터 역시 json 형식으로 전송한다. 

 

from flask import Flask, render_template, request, jsonify
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms

app = Flask(__name__)

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

model = CNN()
model.load_state_dict(torch.load('mnist_model.pth', map_location=torch.device('cpu')))
model.eval()

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 루트 URL에 대한 라우트 추가
@app.route('/')
def index():
    return render_template('index.html')

# '/predict' URL에 대한 라우트 추가
@app.route('/predict', methods=['POST'])
def predict():
    input_data = request.json['inputdata']
    print(input_data)
    image_matrix = process_tensor(input_data)
    image_matrix = image_matrix
    # 예측 수행
    with torch.no_grad():
        prediction = model(image_matrix)
        predicted_digit = torch.argmax(prediction).item()
    # 예측된 숫자를 문자열로 반환
    return str(predicted_digit)

def process_tensor(number_data):
    rows = number_data.strip().split('\n')
    tensor_data = []

    for row in rows:
        values = [float(val) for val in row.strip().split(',')]
        tensor_data.append(values)

    return torch.tensor(tensor_data, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

@app.route('/update', methods=['post'])
def update_model():
    input_data = request.json['inputdata']
    label = request.json['label']
    print(label)
    image_matrix = process_tensor(input_data)
    label = torch.tensor([int(label)], dtype = torch.long)

    output = model(image_matrix)
    loss = criterion(output, label)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    torch.save(model.state_dict(), 'mnist_model.pth')

    return str("model updated successfully")


# 앱 실행
if __name__ == '__main__':
    app.run(debug=True, host='0.0.0.0', port=80)

 

 json형식의 label과 inputdata를 각각 /update, /predict 엔트포인트에서 전달받는다

process_tensor에서 해당 데이터를 숫자형식으로 바꿔 28,28 형태의 텐서로 바꾼다

 

원래는 CNN을 통해 해당 데이터셋을 학습시키고자 했지만 cnn을 처리하는 과정에서

자꾸 차원 오류가나 그냥 ReLU의 단일 층 신경망을 이용해다

 

새로 추가될, 기존의 데이터 셋을 따로 관리하지 않고 바로바로 신경망 업데이트하고 휘발시키기에 

관련 과정을 구현해 보는것도 좋을 듯 하다.

728x90