Perform 6-DoF Pose Estimation for Bin Picking Using Deep Learning
This example shows how to perform six degrees-of-freedom (6-DoF) pose estimation by estimating the 3-D position and orientation of machine parts in a bin using RGB-D images and a deep learning network.
6-DoF pose is the rotation and translation of an object in three-dimensional space with respect to a reference frame. 6-DoF pose estimation using images and depth sensors is key to many computer vision systems in robotics and augmented reality, such as intelligent bin picking. In intelligent bin picking, the 6-DoF pose is the rotation and translation of an object from a canonical pose to the pose observed by a camera capturing the scene. A robotic arm paired with an RGB-D sensor, which captures both an RGB image and a depth measurement, can use a 6-DoF pose estimation model to simultaneously detect objects and estimate their position and orientation with respect to the sensor. You can use the estimated poses in downstream tasks, such as planning a path to an object or determining how best to grip an object with a particular robot gripper.
In this example, you perform 6-DoF pose estimation using a pretrained Pose Mask R-CNN network, which is a type of convolutional neural network (CNN) designed for 6-DoF pose estimation [1][2]. The 6-DoF pose consists of a rotation and a translation in three dimensions, stored as a rigidtform3d object. You then visualize the network predictions, apply geometry-based postprocessing to refine the initial pose predictions, and evaluate the results against the ground truth poses. Finally, you can optionally train the network using transfer learning on a bin-picking data set.
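As a brief illustration of how a 6-DoF pose is represented, you can construct a rigidtform3d object from a 3-by-3 rotation matrix and a 1-by-3 translation vector, and read both components back from its properties. The rotation angle and translation values in this sketch are arbitrary and used only for illustration.
% Build an illustrative 6-DoF pose: a 30-degree rotation about the z-axis
% and an arbitrary translation.
theta = 30;
R = [cosd(theta) -sind(theta) 0; sind(theta) cosd(theta) 0; 0 0 1];
t = [0.1 -0.05 0.4];
tform = rigidtform3d(R,t);

% The rotation, translation, and full homogeneous matrix are available as
% object properties.
disp(tform.R)            % 3-by-3 rotation matrix
disp(tform.Translation)  % 1-by-3 translation vector
disp(tform.A)            % 4-by-4 homogeneous transformation matrix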
This example requires the Computer Vision Toolbox™ Model for Pose Mask R-CNN 6-DoF Object Pose Estimation. You can install this model from the Add-On Explorer. For more information about installing add-ons, see Get and Manage Add-Ons. The Computer Vision Toolbox Model for Pose Mask R-CNN 6-DoF Object Pose Estimation requires Deep Learning Toolbox™ and Image Processing Toolbox™.
Load Sample Data
Load a sample RGB image, and its associated depth image and camera intrinsic parameters, into the workspace.
imRGB = imread("image_00050.jpg");
load("depth_00050.mat","imDepth")
load("intrinsics_00050.mat","intrinsics")
Load the 3-D point cloud models of the object shapes, which represent machine parts, into the workspace. Use these 3-D models to help visualize 6-DoF pose predictions.
load("pointCloudModels.mat","modelClassNames","modelPointClouds")
The ground truth annotations associated with this sample (bounding boxes, classes, segmentation masks, and 6-DoF poses) are the labels used to train the deep neural network. Load the annotations into the workspace.
load("groundTruth_00050.mat")
Visualize Ground Truth Annotations and Pose
Display the sample RGB image.
figure
imshow(imRGB);
title("RGB Image")
Display the depth image. The depth image shows the distance, or depth, from the camera to each pixel in the scene.
figure
imshow(imDepth);
title("Depth Image")
Visualize Ground Truth Data
Overlay the ground truth masks on the original image using the insertObjectMask function.
imRGBAnnot = insertObjectMask(imRGB,gtMasks,Opacity=0.5);
Overlay the object bounding boxes on the original image using the insertObjectAnnotation function.
imRGBAnnot = insertObjectAnnotation(imRGBAnnot,"rectangle",gtBoxes,gtClasses);
Display the ground truth masks and bounding boxes overlaid on the original image.
figure
imshow(imRGBAnnot);
title("Ground Truth Annotations")
Visualize Ground Truth Pose
To visualize the 3-D rotation and translation of the objects, apply the pose transformation to a point cloud of each object class, and display the point clouds on the original image. The pose-transformed point clouds, when projected onto the image, align with the positions and orientations of the machine parts in the scene.
Define the pose colors corresponding to each object class, the number of pose colors, and the number of objects. Create a copy of the RGB image, imGTPose, to contain the pose-transformed point clouds.
poseColors = ["blue","green","magenta","cyan"];
numPoseColors = length(poseColors);
numObj = size(gtBoxes,1);
imGTPose = imRGB;
To visualize the ground truth object poses, project the transformed point clouds onto the image using the world2img function, and display the projected point clouds overlaid on the sample image.
for gIndex = 1:numObj
    gClass = string(gtClasses(gIndex));

    % Retrieve the 3-D object point cloud data.
    ptCloud = modelPointClouds(modelClassNames == gClass);

    % Apply the pose transformation to the point cloud.
    ptCloudGT = pctransform(ptCloud,gtPoses(gIndex));

    % Subsample the point clouds to 5% of the original number of points
    % for a cleaner visualization.
    ptCloudGT = pcdownsample(ptCloudGT,"random",0.05);

    % Project the transformed point clouds onto the image using the camera
    % intrinsic parameters and identity transform for camera pose and position.
    projectedPoints = world2img(ptCloudGT.Location,rigidtform3d,intrinsics);

    % Overlay the 2-D projected points over the image.
    imGTPose = insertMarker(imGTPose,[projectedPoints(:,1) projectedPoints(:,2)], ...
        "circle",Size=1,Color=poseColors(modelClassNames==gClass));
end

figure
imshow(imGTPose);
title("Ground Truth Poses")
Predict 6-DoF Pose Using Pretrained Pose Mask R-CNN Model
Create a pretrained Pose Mask R-CNN model using the posemaskrcnn object.
net = posemaskrcnn("resnet50-pvc-parts");
Predict the 6-DoF poses of the machine parts in the image using the predictPose object function. Specify the prediction confidence threshold Threshold as 0.5. Using a GPU requires a Parallel Computing Toolbox™ license and a CUDA®-enabled NVIDIA® GPU.
[poses,labels,scores,boxes,masks] = predictPose(net, ...
    imRGB,imDepth,intrinsics,Threshold=0.5, ...
    ExecutionEnvironment="gpu");
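If a supported GPU is not available, you can run the prediction on the CPU instead. The following sketch assumes that the ExecutionEnvironment argument also accepts "cpu", as it does for related Computer Vision Toolbox detectors; prediction is slower on the CPU, but the outputs are the same.
% Run inference on the CPU when no CUDA-enabled GPU is available
% (assumes "cpu" is a valid ExecutionEnvironment value for predictPose).
[poses,labels,scores,boxes,masks] = predictPose(net, ...
    imRGB,imDepth,intrinsics,Threshold=0.5, ...
    ExecutionEnvironment="cpu");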
Visualize Predicted Results
Display the 6-DoF pose predictions using projected point clouds, as in the Visualize Ground Truth Pose section of this example. To visualize the predicted object poses, project the transformed point clouds onto the image using the world2img function, and display the projected point clouds overlaid on the sample image.
numPreds = size(boxes,1);
imPose = imRGB;
detPosedPtClouds = cell(1,numPreds);
for detIndex = 1:numPreds
    detClass = string(labels(detIndex));

    % Define predicted rotation and translation.
    detTform = poses(detIndex);
    detScore = scores(detIndex);

    % Retrieve the 3-D object point cloud of the predicted object class.
    ptCloud = modelPointClouds(modelClassNames == detClass);

    % Transform the 3-D object point cloud using the predicted pose.
    ptCloudDet = pctransform(ptCloud,detTform);
    detPosedPtClouds{detIndex} = ptCloudDet;

    % Subsample the point cloud for cleaner visualization.
    ptCloudDet = pcdownsample(ptCloudDet,"random",0.05);

    % Project the transformed point cloud onto the image using the camera
    % intrinsic parameters and identity transform for camera pose and position.
    projectedPoints = world2img(ptCloudDet.Location,rigidtform3d,intrinsics);

    % Overlay the 2-D projected points over the image.
    imPose = insertMarker(imPose,[projectedPoints(:,1), projectedPoints(:,2)], ...
        "circle",Size=1,Color=poseColors(modelClassNames==detClass));
end
Insert the annotations for the predicted bounding boxes, classes, and confidence scores into the image using the insertObjectAnnotation function.
LabelScoreStr = compose("%s-%.2f",labels,scores);
imPose = insertObjectAnnotation(imPose,"rectangle",boxes,LabelScoreStr);
Display the image with the overlaid predicted bounding boxes, classes, and confidence scores.
figure
imshow(imPose)
title("Pose Mask R-CNN Prediction Results")
By visualizing the Pose Mask R-CNN network predictions in this way, you can see that the network correctly predicts the 2-D bounding boxes and classes, as well as the 3-D translations of the parts in the bin. Because 3-D rotation is a more challenging task, the predictions are a bit noisy. To refine the results, combine the deep learning results with traditional point cloud processing techniques in the Refine Estimated Pose section.
Refine Estimated Pose
Obtain Partial Point Clouds from Depth Image
To improve the accuracy of the predictions from the Pose Mask R-CNN network, combine the predicted poses with traditional geometric processing of the point cloud data. In this section, use the depth image and the camera intrinsic parameters to obtain a point cloud representation of the scene.
Use postprocessing to remove points that are too far away to realistically be part of the bin-picking scene, such as points on a far wall or floor. Specify a reasonable maximum distance from the camera, or maximum depth, of 0.5.
maxDepth = 0.5;
To help remove the surface of the bin on which the PVC parts are resting from the point cloud data, fit a 3-D plane to the bin surface in the scene point cloud. Specify the maximum bin distance and the bin orientation for the plane-fitting procedure. The maximum bin distance is an estimate of the largest distance of the bin from the camera. For simplicity, assume the bin is in a horizontal orientation.
maxBinDistance = 0.0015;
binOrientation = [0 0 -1];
Postprocess the raw scene point cloud from the depth image and return point clouds for each detected object by using the helperPostProcessScene supporting function.
[scenePointCloud,roiScenePtCloud] = helperPostProcessScene(imDepth,intrinsics,boxes,maxDepth, ...
maxBinDistance,binOrientation);
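Internally, helperPostProcessScene (listed in the Supporting Functions section) converts the depth map to an organized point cloud, keeps only points closer than maxDepth, fits a plane to the bin surface so that its points can be discarded, and then crops per-detection regions. The following condensed sketch shows the main geometric calls, with the per-detection cropping omitted; the variable names scenePtCloud, nearPtCloud, and objectsPtCloud are used only for illustration.
% Condensed outline of the scene postprocessing (see helperPostProcessScene
% for the full version used in this example).
scenePtCloud = pcfromdepth(imDepth,1.0,intrinsics);   % Depth map to organized point cloud
roi = [scenePtCloud.XLimits scenePtCloud.YLimits ...
       scenePtCloud.ZLimits(1) maxDepth];             % Keep points within maxDepth of the camera
nearIndices = findPointsInROI(scenePtCloud,roi);
nearPtCloud = select(scenePtCloud,nearIndices);
[~,~,outlierIndices] = pcfitplane(nearPtCloud, ...
    maxBinDistance,binOrientation);                   % Points not on the bin surface
objectsPtCloud = select(nearPtCloud,outlierIndices);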
Plot the scene point cloud.
figure
pcshow(scenePointCloud,VerticalAxisDir="Down")
title("Scene Point Cloud from Depth Image")
Plot one of the object point clouds extracted from the depth image.
figure
pcshow(roiScenePtCloud{end},VerticalAxisDir="Down")
title("Object Point Cloud from Depth Image")
Register Point Clouds to Refine Pose Estimation
To refine pose estimation results, perform point cloud registration based on the iterative closest point (ICP) algorithm using the pcregistericp function. During registration, you align the object poses predicted by the Pose Mask R-CNN model with the partial point clouds of the objects from the depth image. To reduce the number of point cloud points used and improve computation time, use the pcdownsample function and set the downsample factor to 0.25.
downsampleFactor = 0.25;
registeredPoses = cell(numPreds,1);
for detIndex = 1:numPreds
    % Downsample the object point cloud transformed by the predicted pose.
    ptCloudDet = pcdownsample(detPosedPtClouds{detIndex},"random",downsampleFactor);

    % Downsample the point cloud obtained from the postprocessed scene depth image.
    ptCloudDepth = pcdownsample(roiScenePtCloud{detIndex},"random",downsampleFactor);

    % Run the ICP point cloud registration algorithm with default
    % parameters.
    [tform,movingReg] = pcregistericp(ptCloudDet,ptCloudDepth);
    registeredPoses{detIndex} = tform;
end
For simplicity, visualize the registration for a single detected object. In the left subplot, plot the point cloud obtained from the Pose Mask R-CNN prediction in magenta and the point cloud from postprocessing of the scene depth image in green. In the right subplot, plot the registered point cloud in magenta and the depth image point cloud in green.
figure
subplot(1,2,1)
pcshowpair(ptCloudDet,ptCloudDepth)
title("Predicted Pose and Depth Image Point Clouds")
subplot(1,2,2)
pcshowpair(movingReg,ptCloudDepth)
title("ICP Registration Result")
Combine Pose Mask R-CNN Predicted Pose with Point Cloud Registration
Visualize the poses predicted by the Pose Mask R-CNN network after the postprocessing and refinement steps, which significantly improve the results. The projected points now align more closely with the machine parts in the image, and the rotation estimates are more accurate.
refinedPoses = cell(1,numPreds);
imPoseRefined = imRGB;
for detIndex = 1:numPreds
    detClass = string(labels(detIndex));

    % Define predicted rotation and translation.
    detTform = poses(detIndex);
    detScore = scores(detIndex);

    % Rotation and translation from registration of depth point clouds.
    icpTform = registeredPoses{detIndex};

    % Combine the two transforms to return the final pose.
    combinedTform = rigidtform3d(icpTform.A*detTform.A);
    refinedPoses{detIndex} = combinedTform;

    % Retrieve the 3-D object point cloud of the predicted object class.
    ptCloud = modelPointClouds(modelClassNames == detClass);

    % Transform the point cloud by the predicted pose.
    ptCloudDet = pctransform(ptCloud,combinedTform);

    % Subsample the point cloud for cleaner visualization.
    ptCloudDet = pcdownsample(ptCloudDet,"random",0.05);

    % Project the transformed point cloud onto the image using the camera
    % intrinsic parameters.
    projectedPoints = world2img(ptCloudDet.Location,rigidtform3d,intrinsics);

    % Display the projected point cloud points onto the image.
    imPoseRefined = insertMarker(imPoseRefined,[projectedPoints(:,1), projectedPoints(:,2)], ...
        "circle",Size=1,Color=poseColors(modelClassNames == detClass));
end
refinedPoses = cat(1,refinedPoses{:});
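The order of the matrix product matters: rigidtform3d(icpTform.A*detTform.A) corresponds to applying the network-predicted pose first and the ICP correction second. You can verify this equivalence on one of the objects with the sketch below; the variable names idx, ptCloudModel, twoStep, oneStep, and maxDifference are used only for illustration.
% Applying the two transforms in sequence is equivalent to applying the
% combined transform once, because the homogeneous matrices compose as
% A_combined = A_icp * A_det.
idx = 1;
ptCloudModel = modelPointClouds(modelClassNames == string(labels(idx)));
twoStep = pctransform(pctransform(ptCloudModel,poses(idx)),registeredPoses{idx});
oneStep = pctransform(ptCloudModel,refinedPoses(idx));
maxDifference = max(abs(twoStep.Location - oneStep.Location),[],"all")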
Insert the predicted bounding boxes, class, and confidence score labels for each object into the image by using the insertObjectAnnotation function.
imPoseRefined = insertObjectAnnotation(imPoseRefined,"rectangle",boxes,LabelScoreStr);
Display the image with the overlaid predicted bounding boxes, classes, and confidence scores.
figure
imshow(imPoseRefined)
title("Pose Mask R-CNN + ICP")
Evaluate 6-DoF Pose Prediction
To quantify the quality of the predicted 6-DoF object poses, use the Chamfer distance [3] to measure the distance between the closest points of a point cloud transformed by the predicted pose and the same point cloud transformed by the ground truth pose. A small Chamfer distance indicates that the predicted and ground truth poses match closely. Displaying each object point cloud in its ground truth (green) and predicted (magenta) poses shows that a smaller Chamfer distance corresponds to a closer pose alignment.
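As an illustration of the metric itself, the one-sided Chamfer distance from one point cloud to another averages the nearest-neighbor distances computed with pdist2; helperEvaluatePosePrediction in the Supporting Functions section uses the same computation. This sketch simply pairs the first detection with the first ground truth object, which is not guaranteed to be the same physical instance; the supporting function instead matches detections to ground truth objects by bounding box overlap.
% Illustrate the one-sided Chamfer distance for one detection: transform the
% object model point cloud by the ground truth and refined predicted poses,
% then average the nearest-neighbor distances between the two.
idx = 1;
ptCloudModel = modelPointClouds(modelClassNames == string(labels(idx)));
ptCloudPred = pctransform(ptCloudModel,refinedPoses(idx));
ptCloudTrue = pctransform(ptCloudModel,gtPoses(idx));
nearestDists = pdist2(ptCloudTrue.Location,ptCloudPred.Location, ...
    "euclidean","smallest",1);
chamferDistance = mean(nearestDists)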
First, calculate the Chamfer distance between the object point clouds in the ground truth and predicted poses by using the helperEvaluatePosePrediction supporting function.
[distPtCloud,predIndices,gtIndices] = helperEvaluatePosePrediction( ...
modelPointClouds,modelClassNames,boxes,labels,refinedPoses,gtBoxes,gtClasses,gtPoses);
Next, visualize the point cloud pairs in ground truth and predicted poses, and display the Chamfer distance value corresponding to each object pose.
figure
for detIndex = 1:length(predIndices)
    detClass = string(labels(detIndex));
    gtIndex = gtIndices(detIndex);

    % Obtain the point cloud of the predicted object.
    ptCloudDet = modelPointClouds(modelClassNames == detClass);

    % Predicted 6-DoF pose with ICP refinement.
    detTform = refinedPoses(detIndex);

    % Apply the predicted pose transformation to the predicted object point
    % cloud.
    ptCloudTformDet = pctransform(ptCloudDet,detTform);

    if gtIndex == 0
        % The predicted bounding box does not match any
        % ground truth bounding box (false positive).
        ptCloudTformGT = pointCloud(single.empty(0,3));
    else
        % Obtain the point cloud of the ground truth object.
        ptCloudGT = modelPointClouds(modelClassNames == string(gtClasses(gtIndex)));

        % Apply the ground truth pose transformation.
        ptCloudTformGT = pctransform(ptCloudGT,gtPoses(gtIndex));
    end

    subplot(2,4,detIndex);
    pcshowpair(ptCloudTformDet,ptCloudTformGT);
    title(sprintf("d = %.4f",distPtCloud(detIndex)))
end
Prepare Data for Training
Perform transfer learning on a custom bin-picking data set using the Pose Mask R-CNN deep learning network. The data set contains 100 RGB images and depth images of 3-D pipe connectors, generated using Simulink®. The data consists of images of these machine parts lying at random orientations inside a bin, viewed from different angles and under different lighting conditions, and contains four object classes that correspond to objects having I, X, L, or T shapes. The data set additionally contains bounding boxes, object classes, binary masks, and 6-DoF pose information for every object in each image, as well as point cloud data for each of the four object classes.
Download Data Set
Download the data set by using the helperDownloadPVCPartsDataset supporting function. This step requires an internet connection.
datasetUnzipFolder = helperDownloadPVCPartsDataset;
Load and Partition Data
Prior to partitioning the data, seed the random number generator to ensure reproducibility of the results.
rng("default");
Specify the location of the data set and ground truth data.
datasetDir = fullfile(datasetUnzipFolder,"pvcparts100");
gtLocation = fullfile(datasetDir,"GT");
Create a fileDatastore object that reads annotations from the bin-picking data set by using the helperSimAnnotMATReader supporting function.
dsRandom = fileDatastore(gtLocation, ...
ReadFcn=@(x)helperSimAnnotMATReader(x,datasetDir));
Split the data set randomly into training and validation sets. Because the total number of images is relatively small, allocate a relatively large percentage (70%) of the data for training, and 30% for validation.
randIndices = randperm(length(dsRandom.Files));
numTrainRandom = round(0.7*length(dsRandom.Files));
dsTrain = subset(dsRandom,randIndices(1:numTrainRandom));
dsVal = subset(dsRandom,randIndices(numTrainRandom+1:end));
Use the datastore to query the RGB image, depth image, object masks, object classes, bounding boxes, poses, and camera intrinsic parameters of the first validation sample.
data = preview(dsVal);
imRGB = data{1};
imDepth = data{2};
gtMasks = data{5};
gtClasses = data{4};
gtBoxes = data{3};
gtPoses = data{6};
intrinsics = data{7};
Ensure that the queried object classes maintain the correct order.
trainClassNames = categories(gtClasses)
trainClassNames = 4×1 cell
{'I_shape'}
{'X_shape'}
{'L_shape'}
{'T_shape'}
Train Pose Mask R-CNN Network
Train the Pose Mask R-CNN model in two stages:
Train the pretrained Pose Mask R-CNN network on the instance segmentation task, which consists of predicting bounding boxes, class labels, and segmentation masks. This first training stage initializes the deep neural network weights for the second stage of training.
Train the Pose Mask R-CNN network from the previous stage on the pose estimation task, which consists of predicting the 3-D rotation and translation of each detected object.
Train on one or more GPUs, if they are available. Using a GPU requires a Parallel Computing Toolbox™ license and a CUDA®-enabled NVIDIA® GPU. For more information, see GPU Computing Requirements (Parallel Computing Toolbox). The two stages of training take approximately 20 minutes and 45 minutes, respectively, on an NVIDIA GeForce RTX 3090 GPU.
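Before starting a long training run, you can optionally confirm that a supported GPU is visible to MATLAB. This check is a minimal sketch; if no GPU is found, change the ExecutionEnvironment training option to "cpu" and expect substantially longer training times.
% Optional check for a supported GPU before training (gpuDevice requires
% Parallel Computing Toolbox).
if canUseGPU
    gpuDevice   % Display the properties of the selected GPU.
else
    disp("No supported GPU detected. Set ExecutionEnvironment to ""cpu"" " + ...
        "in trainingOptions to train on the CPU.")
end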
Estimate Anchor Boxes
To train the Pose Mask R-CNN network, set the doTrain variable in the following code to true. To train the network on your custom data set, estimate the anchor box locations, which are specific to the data. Sample bounding boxes from a subset of 100 training images, and estimate the anchor boxes from them using the estimateAnchorBoxes function.
doTrain = false;
if doTrain
    numAnchors = 8;
    numImgSample = 100;
    numImgSample = min(numImgSample,length(dsTrain.Files));

    sampledBboxes = {};
    dsTmp = copy(dsTrain);
    for ii = 1:numImgSample
        tmpData = read(dsTmp);
        randSelect = randi(size(tmpData{3},1));
        sampledBboxes{end+1} = tmpData{3}(randSelect,:);
    end
    blds = boxLabelDatastore(table(sampledBboxes'));

    [anchorBoxes,meanAnchorIoU] = estimateAnchorBoxes(blds,numAnchors);
end
Define Pose Mask R-CNN Network Architecture
Create the Pose Mask R-CNN pose estimation network using the posemaskrcnn object. Specify the pretrained network created using ResNet-50 as the base network. Specify the class names and the estimated anchor boxes.
if doTrain
    untrainedNetwork = posemaskrcnn("resnet50-pvc-parts",trainClassNames,anchorBoxes);
end
Specify Training Options for Instance Segmentation
Specify the network training options for the instance segmentation training stage using the trainingOptions (Deep Learning Toolbox) function. Train the object detector using the SGDM solver for a maximum of 5 epochs. Specify the ValidationData name-value argument as the validation data dsVal. Specify the location to save intermediate network checkpoints.
if doTrain outFolder = fullfile(tempdir,"output","train_posemaskrcnn_mask"); mkdir(outFolder); ckptFolder = fullfile(outFolder, "checkpoints"); mkdir(ckptFolder); disp(outFolder); optionsMask = trainingOptions("sgdm", ... InitialLearnRate=0.0001, ... LearnRateSchedule="piecewise", ... LearnRateDropPeriod=1, ... LearnRateDropFactor=0.5, ... MaxEpochs=5, ... Plot="none", ... Momentum=0.9, ... MiniBatchSize=1, ... ResetInputNormalization=false, ... ExecutionEnvironment="gpu", ... VerboseFrequency=5, ... ValidationData=dsVal, ... ValidationFrequency=20, ... Plots="training-progress",... CheckpointPath=ckptFolder,... CheckpointFrequency=2,... CheckpointFrequencyUnit="epoch"); end
/tmp/output/train_posemaskrcnn_mask
Train Pose Mask R-CNN for Instance Segmentation
Train the Pose Mask R-CNN network to predict bounding boxes, object classes, and instance segmentation masks using the trainPoseMaskRCNN function. This training stage does not use the pose head losses for rotation and translation. Save the trained network to a local folder and plot the progress of the training loss over multiple iterations.
if doTrain
    net = trainPoseMaskRCNN( ...
        dsTrain,untrainedNetwork,"mask",optionsMask);

    modelDateTime = string(datetime("now",Format="yyyy-MM-dd-HH-mm-ss"));
    save(fullfile(outFolder,"trainedPoseMaskRCNN-"+modelDateTime+".mat"),"net");
    disp(fullfile(outFolder,"trainedPoseMaskRCNN-"+modelDateTime+".mat"))
end
The verbose training output reports, for each logged iteration, the epoch, iteration, elapsed time, learning rate, and the training and validation losses (total, RPN, RMSE, class, mask, rotation, and translation). During this instance segmentation stage, the rotation and translation losses remain zero because the pose head is not trained, while the total and mask losses decrease steadily over the five epochs.
/tmp/output/train_posemaskrcnn_pose-and-mask/trainedPoseMaskRCNN-2024-01-13-17-12-55.mat
Specify Training Options for Pose Prediction
Specify the network training options for the pose prediction training stage using the trainingOptions (Deep Learning Toolbox) function. Train the object detector using the SGDM solver for a maximum of 5 epochs. Specify the ValidationData name-value argument as the validation data dsVal. Specify the location to save intermediate network checkpoints.
if doTrain outFolder = fullfile(tempdir,"output","train_posemaskrcnn_pose-and-mask"); mkdir(outFolder); ckptFolder = fullfile(outFolder,"checkpoints"); mkdir(ckptFolder); disp(outFolder); optionsPose = trainingOptions("sgdm", ... InitialLearnRate=0.0001, ... LearnRateSchedule="piecewise", ... LearnRateDropPeriod=1, ... LearnRateDropFactor=0.5, ... MaxEpochs=5, ... Plot="none", ... Momentum=0.9, ... MiniBatchSize=1, ... ResetInputNormalization=false, ... ExecutionEnvironment="gpu", ... VerboseFrequency=5, ... CheckpointPath=ckptFolder,... CheckpointFrequency=5,... CheckpointFrequencyUnit="epoch",... ValidationData=dsVal, ... ValidationFrequency=20, ... Plots="training-progress"); end
/tmp/output/train_posemaskrcnn_pose-and-mask
Train Pose Mask R-CNN for Pose Estimation
To predict the 6-DoF pose of every detected object, train the Pose Mask R-CNN network trained for instance segmentation, net, with both the pose and mask losses using the trainPoseMaskRCNN function. Save the trained network to a local folder and plot the progress of the rotation and translation training losses over multiple iterations.
if doTrain
    trainedNet = trainPoseMaskRCNN( ...
        dsTrain,net,"pose-and-mask",optionsPose);

    modelDateTime = string(datetime("now",Format="yyyy-MM-dd-HH-mm-ss"));
    save(fullfile(outFolder,"trainedPoseMaskRCNN-"+modelDateTime+".mat"),"trainedNet");
    disp(fullfile(outFolder,"trainedPoseMaskRCNN-"+modelDateTime+".mat"))
end
The verbose training output for this stage again reports the epoch, iteration, elapsed time, learning rate, and the training and validation losses. The rotation (RLoss) and translation (TLoss) losses are now nonzero, and the validation losses decrease as training progresses over the five epochs.
/tmp/output/train_posemaskrcnn_pose-and-mask/trainedPoseMaskRCNN-2024-01-13-18-17-09.mat
Supporting Functions
helperSimAnnotMATReader
function out = helperSimAnnotMATReader(filename,datasetRoot)
% Read annotations for the simulated bin-picking data set.
% Expected folder structure under `datasetRoot`:
%    depth/  (depth images folder)
%    GT/     (ground truth MAT files)
%    image/  (color images)

data = load(filename);
groundTruthMaT = data.groundTruthMaT;
clear data;

% Read the RGB image.
tmpPath = strsplit(groundTruthMaT.RGBImagePath,'\');
basePath = {tmpPath{4:end}};
imRGBPath = fullfile(datasetRoot,basePath{:});
im = imread(imRGBPath);
im = rgb2gray(im);
if size(im,3) == 1
    im = repmat(im,[1 1 3]);
end

% Read the depth image.
aa = strsplit(groundTruthMaT.DepthImagePath,'\');
bb = {aa{4:end}};
imDepthPath = fullfile(datasetRoot,bb{:}); % Handle Windows paths.
imD = load(imDepthPath);
imD = imD.depth;

% For "undefined" values in the instance labels, assign the first class.
undefinedSelect = isundefined(groundTruthMaT.instLabels);
classNames = categories(groundTruthMaT.instLabels);
groundTruthMaT.instLabels(undefinedSelect) = classNames{1};
if isrow(groundTruthMaT.instLabels)
    groundTruthMaT.instLabels = groundTruthMaT.instLabels';
end

% Wrap the camera parameter matrix into a cameraIntrinsics object.
K = groundTruthMaT.IntrinsicsMatrix;
intrinsics = cameraIntrinsics([K(1,1) K(2,2)],[K(1,3) K(2,3)],[size(im,1) size(im,2)]);

% Process the rotation matrix annotations.
rotationCell = num2cell(groundTruthMaT.rotationMatrix,[1 2]);
rotationCell = squeeze(rotationCell);

% Process the translation annotations.
translation = squeeze(groundTruthMaT.translation)';
translationCell = num2cell(translation,2);

% Process the poses into a rigidtform3d vector. Transpose R to obtain the
% correct pose.
poseCell = cellfun(@(R,t)(rigidtform3d(R',t)),rotationCell,translationCell, ...
    UniformOutput=false);
pose = vertcat(poseCell{:});

% Remove heavily occluded objects.
if length(groundTruthMaT.occPercentage) == length(groundTruthMaT.instLabels)
    visibility = groundTruthMaT.occPercentage;
    visibleInstSelector = visibility > 0.5;
else
    visibleInstSelector = true([length(groundTruthMaT.instLabels),1]);
end

out{1} = im;
out{2} = imD;                                                        % HxWx1 double array depth map
out{3} = groundTruthMaT.instBBoxes(visibleInstSelector,:);           % Nx4 double bounding boxes
out{4} = groundTruthMaT.instLabels(visibleInstSelector);             % Nx1 categorical object labels
out{5} = logical(groundTruthMaT.instMasks(:,:,visibleInstSelector)); % HxWxN logical mask arrays
out{6} = pose(visibleInstSelector);                                  % Nx1 rigidtform3d vector of object poses
out{7} = intrinsics;                                                 % cameraIntrinsics object

end
helperPostProcessScene
function [scenePtCloud,roiScenePtCloud] = helperPostProcessScene(imDepth,intrinsics,boxes,maxDepth,maxBinDistance,binOrientation)

% Convert the depth image into an organized point cloud using the camera
% intrinsics.
scenePtCloud = pcfromdepth(imDepth,1.0,intrinsics);

% Remove outliers, or points that are too far away to be in the bin.
selectionROI = [ ...
    scenePtCloud.XLimits(1) scenePtCloud.XLimits(2) ...
    scenePtCloud.YLimits(1) scenePtCloud.YLimits(2) ...
    scenePtCloud.ZLimits(1) maxDepth];
selectedIndices = findPointsInROI(scenePtCloud,selectionROI);
cleanScenePtCloud = select(scenePtCloud,selectedIndices);

% Fit a plane to the bin surface.
[~,~,outlierIndices] = pcfitplane( ...
    cleanScenePtCloud,maxBinDistance,binOrientation);

% Re-map indices back to the original scene point cloud. Use this
% when cropping out object detections from the scene point cloud.
origPtCloudSelection = selectedIndices(outlierIndices);

% Crop predicted ROIs from the scene point cloud.
numPreds = size(boxes,1);
roiScenePtCloud = cell(1,numPreds);
for detIndex = 1:numPreds
    box2D = boxes(detIndex,:);

    % Get linear indices into the organized point cloud corresponding to the
    % predicted 2-D bounding box of an object.
    boxIndices = (box2D(2):box2D(2)+box2D(4))' + (size(scenePtCloud.Location,1)*(box2D(1)-1:box2D(1)+box2D(3)-1));
    boxIndices = uint32(boxIndices(:));

    % Remove points that are outliers from earlier preprocessing steps
    % (either belonging to the bin surface or too far away).
    keptIndices = intersect(origPtCloudSelection,boxIndices);
    roiScenePtCloud{detIndex} = select(scenePtCloud,keptIndices);
end

end
helperEvaluatePosePrediction
function [distADDS,predIndices,gtIndices] = helperEvaluatePosePrediction( ...
    modelPointClouds,modelClassNames,boxes,labels,pose,gBox,gLabel,gPose)
% Compare predicted and ground truth poses for a single image containing
% multiple object instances using the one-sided Chamfer distance.

    function pointCloudADDS = pointCloudChamferDistance(ptCloudGT,ptCloudDet,numSubsampledPoints)
    % Return the one-sided Chamfer distance between two point clouds, which
    % computes the closest point in point cloud B for each point in point cloud A,
    % and averages over these minimum distances.

        % Subsample the point clouds.
        if nargin == 2
            numSubsampledPoints = 1000;
        end

        rng("default"); % Ensure reproducibility in the point cloud subsampling step.

        if numSubsampledPoints < ptCloudDet.Count
            subSampleFactor = numSubsampledPoints / ptCloudDet.Count;
            ptCloudDet = pcdownsample(ptCloudDet,"random",subSampleFactor);
            subSampleFactor = numSubsampledPoints / ptCloudGT.Count;
            ptCloudGT = pcdownsample(ptCloudGT,"random",subSampleFactor);
        end

        % For each point in the ground truth point cloud, find the distance to
        % the closest point in the predicted point cloud.
        distPtCloud = pdist2(ptCloudGT.Location,ptCloudDet.Location, ...
            "euclidean","smallest",1);

        % Average over all points in the ground truth point cloud.
        pointCloudADDS = mean(distPtCloud);
    end

maxADDSThreshold = 0.1;

% Associate predicted bounding boxes with ground truth annotations based on
% bounding box overlaps as an initial step.
minOverlap = 0.1;
overlapRatio = bboxOverlapRatio(boxes,gBox);
[predMatchScores,predGTIndices] = max(overlapRatio,[],2); % (numPreds x 1)
[gtMatchScores,~] = max(overlapRatio,[],1);               % (1 x numGT)
matchedPreds = predMatchScores > minOverlap;
matchedGTs = gtMatchScores > minOverlap;

numPreds = size(boxes,1);
distADDS = cell(numPreds,1);
predIndices = cell(numPreds,1);
gtIndices = cell(numPreds,1);

for detIndex = 1:numPreds
    detClass = string(labels(detIndex));

    % Account for predictions unmatched with ground truth (false positives).
    if ~matchedPreds(detIndex)
        % If the predicted bounding box does not overlap any
        % ground truth bounding box, then apply the maximum penalty
        % and skip the point cloud matching steps.
        distADDS{detIndex} = maxADDSThreshold;
        predIndices{detIndex} = detIndex;
        gtIndices{detIndex} = 0;
    else
        % Match ground truth labels to predicted objects by their bounding
        % box overlap ratio (box intersection-over-union).
        gtIndex = predGTIndices(detIndex);
        detClassname = string(detClass);
        gClassname = string(gLabel(gtIndex));

        if detClassname ~= gClassname
            % If the predicted object category is incorrect, set the distance
            % to the maximum allowed value (highly penalized).
            distADDS{detIndex} = maxADDSThreshold;
        else
            % Predicted rotation and translation.
            detTform = pose(detIndex);

            % Ground truth pose.
            gTform = gPose(gtIndex);

            % Get the point cloud of the object.
            ptCloud = modelPointClouds(modelClassNames == string(gClassname));

            % Apply the ground truth pose transformation.
            ptCloudTformGT = pctransform(ptCloud,gTform);

            % Apply the predicted pose transformation.
            ptCloudDet = pctransform(ptCloud,detTform);

            pointCloudADDSObj = pointCloudChamferDistance( ...
                ptCloudTformGT,ptCloudDet);

            distADDS{detIndex} = pointCloudADDSObj;
        end
        predIndices{detIndex} = detIndex;
        gtIndices{detIndex} = gtIndex;
    end
end

distADDS = cat(1,distADDS{:});

% Account for unmatched ground truth objects (false negatives).
numUnmatchedGT = numel(matchedGTs) - nnz(matchedGTs);
if numUnmatchedGT > 0
    % Set the distance to the maximum for unmatched ground truth objects.
    % Store the index entries as cells so that they concatenate with the
    % per-detection cell arrays before the final expansion below.
    falseNegativesADDS = maxADDSThreshold * ones(numUnmatchedGT,1);
    fnPred = num2cell(zeros(numUnmatchedGT,1));
    fnGT = num2cell(find(~matchedGTs)');
    distADDS = cat(1,distADDS,falseNegativesADDS);
    predIndices = cat(1,predIndices,fnPred);
    gtIndices = cat(1,gtIndices,fnGT);
end

predIndices = cat(1,predIndices{:});
gtIndices = cat(1,gtIndices{:});

end
helperDownloadPVCPartsDataset
function datasetUnzipFolder = helperDownloadPVCPartsDataset()
datasetURL = "https://ssd.mathworks.com/supportfiles/vision/data/pvcparts100Dataset.zip";
datasetUnzipFolder = fullfile(tempdir,"dataset");
datasetZip = fullfile(datasetUnzipFolder,"pvcparts100Dataset.zip");
if ~exist(datasetZip,"file")
    mkdir(datasetUnzipFolder);
    disp("Downloading PVC Parts dataset (200 MB)...");
    websave(datasetZip,datasetURL);
end
unzip(datasetZip,datasetUnzipFolder)
end
References
[1] Xiang, Yu, Tanner Schmidt, Venkatraman Narayanan, and Dieter Fox. "PoseCNN: A Convolutional Neural Network for 6D Object Pose Estimation in Cluttered Scenes." In Robotics: Science and Systems (RSS), 2018.
[2] Jiang, Xiaoke, Donghai Li, Hao Chen, Ye Zheng, Rui Zhao, and Liwei Wu. "Uni6D: A Unified CNN Framework Without Projection Breakdown for 6D Pose Estimation." In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2022.
[3] Borgefors, Gunilla. "Hierarchical Chamfer Matching: A Parametric Edge Matching Algorithm." IEEE Transactions on Pattern Analysis and Machine Intelligence 10, no. 6 (1988): 849-865.
See Also
posemaskrcnn | predictPose | trainMaskRCNN | insertObjectMask | maskrcnn | trainingOptions (Deep Learning Toolbox)