1. Prepare the environment

  • windows
  • idea
  • jdk11
  • maven

This article uses Model Zoo models to run an object detection task. Model Zoo is a website that collects many deep learning models; it was built by Jingyu Koh from Singapore.

2. Create a new project

The final directory structure is as follows:

The source code is available at: examples.javacodegeeks.com/djl-spring-…

3. pom.xml

<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <!-- Build descriptor for the Spring Boot + DJL object detection example. -->
    <modelVersion>4.0.0</modelVersion>

    <groupId>ai.djl</groupId>
    <artifactId>image-object-detection</artifactId>
    <version>1.0.0-SNAPSHOT</version>

    <properties>
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
        <!-- Single place to bump the version of all DJL artifacts below. -->
        <djl.version>0.11.0</djl.version>
    </properties>

    <repositories>
        <repository>
            <id>djl.ai</id>
            <url>https://oss.sonatype.org/content/repositories/snapshots/</url>
        </repository>
    </repositories>

    <dependencyManagement>
        <dependencies>
            <!-- DJL bill of materials, imported so DJL artifact versions stay aligned. -->
            <dependency>
                <groupId>ai.djl</groupId>
                <artifactId>bom</artifactId>
                <version>${djl.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>
    <dependencies>
        <!-- Web layer: REST controller and multipart upload support. -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
            <version>2.3.4.RELEASE</version>
        </dependency>

        <!-- DJL model zoo plus the MXNet engine used to run the detection model. -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>model-zoo</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.mxnet</groupId>
            <artifactId>mxnet-model-zoo</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.mxnet</groupId>
            <artifactId>mxnet-engine</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.mxnet</groupId>
            <artifactId>mxnet-native-auto</artifactId>
            <version>1.6.0</version>
            <scope>runtime</scope>
        </dependency>
        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
            <version>1.2.3</version>
            <scope>provided</scope>
        </dependency>
        <!-- NOTE(review): "RELEASE" is a floating meta-version; consider pinning
             an explicit Lombok version for reproducible builds. -->
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <version>RELEASE</version>
            <scope>compile</scope>
        </dependency>
    </dependencies>
    <build>
        <finalName>${project.artifactId}</finalName>
        <plugins>
            <!-- Repackages the jar into an executable Spring Boot archive. -->
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
                <version>2.0.1.RELEASE</version>
                <executions>
                    <execution>
                        <goals>
                            <goal>repackage</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>
</project>
Copy the code

4. Source code

1. Spring Boot entry point

package com.jcg.djl;


import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;

@SpringBootApplication
public class ImageObjectDetectionApplication {

    public static void main(String[] args) { SpringApplication.run(ImageObjectDetectionApplication.class, args); }}Copy the code

2. Controller

package com.jcg.djl;

import ai.djl.Application;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.compress.utils.IOUtils;
import org.springframework.core.io.ClassPathResource;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.servlet.support.ServletUriComponentsBuilder;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Objects;

@Slf4j
@RestController
public class ImageDetectController {

    @PostMapping(value = "/upload", produces = MediaType.IMAGE_PNG_VALUE)
    public ResponseEntity<String> diagnose(@RequestParam("file") MultipartFile file) throws ModelException, TranslateException, IOException {
        byte[] bytes = file.getBytes();
        Path imageFile = Paths.get(Objects.requireNonNull(file.getOriginalFilename()));
        Files.write(imageFile, bytes);
        return predict(imageFile);
    }


    public ResponseEntity<String> predict(Path imageFile) throws IOException, ModelException, TranslateException {
        Image img = ImageFactory.getInstance().fromFile(imageFile);

        Criteria<Image, DetectedObjects> criteria =
                Criteria.builder()
                        .optApplication(Application.CV.OBJECT_DETECTION)
                        .setTypes(Image.class, DetectedObjects.class)
                        .optFilter("backbone"."resnet50")
                        .optProgress(new ProgressBar())
                        .build();

        try (ZooModel<Image, DetectedObjects> model = ModelZoo.loadModel(criteria)) {
            try (Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
                DetectedObjects detection = predictor.predict(img);
                returnsaveBoundingBoxImage(img, detection); }}}private ResponseEntity<String> saveBoundingBoxImage(Image img, DetectedObjects detection)
            throws IOException {
        Path outputDir = Paths.get("src/main/resources");
        Files.createDirectories(outputDir);

        // Make image copy with alpha channel because original image was jpg
        Image newImage = img.duplicate(Image.Type.TYPE_INT_ARGB);
        newImage.drawBoundingBoxes(detection);

        Path imagePath = outputDir.resolve("detected.png");
        // OpenJDK can't save jpg with alpha channel
        newImage.save(Files.newOutputStream(imagePath), "png");
        log.info("Detected objects image has been saved in:{}" , imagePath);


        String fileDownloadUri = ServletUriComponentsBuilder.fromCurrentContextPath()
                .path("get")
                .toUriString();
        return ResponseEntity.ok(fileDownloadUri);
    }


    @GetMapping( value = "/get", produces = MediaType.IMAGE_PNG_VALUE )
    public @ResponseBody
    byte[] getImageWithMediaType() throws IOException {
        InputStream in = new ClassPathResource( "detected.png").getInputStream();
        returnIOUtils.toByteArray(in); }}Copy the code

3. application.yml

# DJL settings for the object detection application.
# NOTE(review): key names and class names below were reconstructed from a
# garbled, single-line rendering; confirm them against the DJL Spring Boot
# starter documentation before relying on this file.
djl:
  # Set the type of application.
  application-type: OBJECT_DETECTION
  # Set the type of input data.
  input-class: java.awt.image.BufferedImage
  # Set the output data format.
  output-class: ai.djl.modality.cv.output.DetectedObjects
  # Set a filter to select your model.
  model-filter:
    size: 512
    # backbone:
  # NOTE(review): the source showed "Parameter:" here; presumably the
  # starter's "arguments" key. Verify.
  arguments:
    threshold: 0.5 # Only show predictions with a probability greater than or equal to 0.5

5. Usage

1. Run the program:

mvn spring-boot:run
Copy the code

2. Open the web page

http://localhost:8080

3. Upload the image to be recognized

4. Download the detection result

Open the address: http://localhost:8080/get