densepose简化调用

对 DensePose 官方提供的代码做了简化，去除了各种封装和日志（log）输出，只保留核心功能。

使用之前，需要先用 pip 安装 densepose 包：

pip install git+https://github.com/facebookresearch/detectron2@main#subdirectory=projects/DensePose

代码抽取自 detectron2 仓库中的 projects/DensePose/apply_net.py：

import torch
import cv2
import numpy as np
import sys
import os

from detectron2.config import CfgNode, get_cfg
from detectron2.data.detection_utils import read_image
from detectron2.engine.defaults import DefaultPredictor
from detectron2.structures.instances import Instances
from detectron2.utils.logger import setup_logger

from densepose import add_densepose_config
from densepose.structures import DensePoseChartPredictorOutput, DensePoseEmbeddingPredictorOutput
from densepose.utils.logger import verbosity_to_level
from densepose.vis.base import CompoundVisualizer
from densepose.vis.bounding_box import ScoredBoundingBoxVisualizer
from densepose.vis.densepose_outputs_vertex import (
    DensePoseOutputsTextureVisualizer,
    DensePoseOutputsVertexVisualizer,
    get_texture_atlases,
)
from densepose.vis.densepose_results import (
    DensePoseResultsContourVisualizer,
    DensePoseResultsFineSegmentationVisualizer,
    DensePoseResultsUVisualizer,
    DensePoseResultsVVisualizer,
)
from densepose.vis.densepose_results_textures import (
    DensePoseResultsVisualizerWithTexture,
    get_texture_atlas,
)
from densepose.vis.extractor import (
    CompoundExtractor,
    DensePoseOutputsExtractor,
    DensePoseResultExtractor,
    create_extractor,
)

from typing import Any, Dict, List

from types import SimpleNamespace


def setup_config(config_fpath: str, model_fpath: str, args, opts: List[str]):
    """Build a frozen detectron2 config with the DensePose extensions added.

    Args:
        config_fpath: Path of the YAML config file to merge in.
        model_fpath: Weights path stored into ``MODEL.WEIGHTS``.
        args: Object whose ``opts`` attribute is a flat key/value override
            list, merged right after the YAML file.
        opts: Extra flat key/value overrides; merged after ``args.opts``
            only when non-empty.

    Returns:
        The frozen config node.
    """
    config = get_cfg()
    add_densepose_config(config)
    config.merge_from_file(config_fpath)

    # Overrides are applied in order: args.opts first, then opts (if any),
    # so later entries overwrite earlier ones for the same key.
    override_lists = [args.opts]
    if opts:
        override_lists.append(opts)
    for overrides in override_lists:
        config.merge_from_list(overrides)

    config.MODEL.WEIGHTS = model_fpath
    config.freeze()
    return config


def create_context(args, cfg: CfgNode) -> Dict[str, Any]:
    """Build the compound extractor/visualizer used to render predictions.

    Args:
        args: Namespace read for ``texture_atlas``, ``texture_atlases_map``
            and ``output``. If it also carries a non-empty ``visualizations``
            attribute (comma-separated keys of ``VISUALIZERS`` below, e.g.
            ``"bbox,dp_segm"``), those visualizers are built; otherwise only
            ``"dp_segm"`` is used, matching the previous hard-coded choice.
        cfg: Config passed to every visualizer constructor.

    Returns:
        Dict with ``"extractor"`` (a ``CompoundExtractor``), ``"visualizer"``
        (a ``CompoundVisualizer``), ``"out_fname"`` (copied from
        ``args.output``) and ``"entry_idx"`` (always 0).

    Raises:
        KeyError: If a requested visualization name is not a key of
            ``VISUALIZERS``.
    """
    VISUALIZERS = {
        "dp_contour": DensePoseResultsContourVisualizer,
        "dp_segm": DensePoseResultsFineSegmentationVisualizer,
        "dp_u": DensePoseResultsUVisualizer,
        "dp_v": DensePoseResultsVVisualizer,
        "dp_iuv_texture": DensePoseResultsVisualizerWithTexture,
        "dp_cse_texture": DensePoseOutputsTextureVisualizer,
        "dp_vertex": DensePoseOutputsVertexVisualizer,
        "bbox": ScoredBoundingBoxVisualizer,
    }

    # Optional attribute so existing callers that never set it keep the
    # old single "dp_segm" visualization.
    spec_string = getattr(args, "visualizations", None) or "dp_segm"
    vis_specs = [spec.strip() for spec in spec_string.split(",") if spec.strip()]

    # These depend only on args, not on the visualization being built,
    # so compute them once instead of once per spec.
    texture_atlas = get_texture_atlas(args.texture_atlas)
    texture_atlases_dict = get_texture_atlases(args.texture_atlases_map)

    visualizers = []
    extractors = []
    for vis_spec in vis_specs:
        vis = VISUALIZERS[vis_spec](
            cfg=cfg,
            texture_atlas=texture_atlas,
            texture_atlases_dict=texture_atlases_dict,
        )
        visualizers.append(vis)
        # One extractor per visualizer, kept in the same order.
        extractors.append(create_extractor(vis))

    visualizer = CompoundVisualizer(visualizers)
    extractor = CompoundExtractor(extractors)
    return {
        "extractor": extractor,
        "visualizer": visualizer,
        "out_fname": args.output,
        "entry_idx": 0,
    }


def densepose_iuv(
    image,
    *,
    config_path: str = '/mnt/d/Code/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_s1x.yaml',
    model_path: str = '/mnt/d/model_final_162be9.pkl',
    min_score: float = 0.8,
    nms_thresh=None,
):
    """Run DensePose on one image and return the rendered visualization.

    Args:
        image: Input image in BGR channel order (as returned by ``cv2.imread``).
        config_path: DensePose YAML config file.
        model_path: Model weights file.
        min_score: Detection score threshold, applied through
            ``MODEL.ROI_HEADS.SCORE_THRESH_TEST``.
        nms_thresh: Optional NMS threshold, applied through
            ``MODEL.ROI_HEADS.NMS_THRESH_TEST``; the config value is kept
            when None.

    Returns:
        The visualizer output, drawn over a 3-channel grayscale copy of
        ``image``.

    Note:
        The predictor (and its weights) is rebuilt on every call; callers
        processing many images may want to cache it.
    """
    # Only the fields read by setup_config/create_context are needed.
    args = SimpleNamespace()
    args.cfg = config_path
    args.model = model_path
    args.min_score = min_score
    args.nms_thresh = nms_thresh
    args.opts = []
    args.output = 'a_out.jpg'
    args.texture_atlas = None
    args.texture_atlases_map = None

    # Thresholds only take effect through config overrides; without these
    # the config's default score threshold would silently be used instead
    # of min_score.
    opts = ['MODEL.ROI_HEADS.SCORE_THRESH_TEST', str(args.min_score)]
    if args.nms_thresh is not None:
        opts += ['MODEL.ROI_HEADS.NMS_THRESH_TEST', str(args.nms_thresh)]

    cfg = setup_config(args.cfg, args.model, args, opts)
    context = create_context(args, cfg)
    predictor = DefaultPredictor(cfg)

    with torch.no_grad():
        outputs = predictor(image)['instances']

        visualizer = context["visualizer"]
        extractor = context["extractor"]
        # Draw on a grayscale copy (replicated to 3 channels) so the
        # colored overlay stands out; the caller's image is not modified.
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        background = np.tile(gray[:, :, np.newaxis], [1, 1, 3])
        data = extractor(outputs)
        image_vis = visualizer.visualize(background, data)

    return image_vis

if __name__ == '__main__':
    # Demo entry point: read a.jpg, run DensePose, save the visualization.
    input_path = 'a.jpg'
    output_path = 'a_out.jpg'
    image = cv2.imread(input_path)
    # cv2.imread returns None (rather than raising) when the file is
    # missing or cannot be decoded.
    if image is None:
        raise FileNotFoundError(f"could not read image: {input_path}")
    vis = densepose_iuv(image)
    cv2.imwrite(output_path, vis)

发表评论