A pure PyTorch implementation of the loss described in "Online Segment to Segment Neural Transduction"

Overview

ssnt-loss

ℹ️ This is a WIP project. the implementation is still being tested.

A pure PyTorch implementation of the loss described in "Online Segment to Segment Neural Transduction" https://arxiv.org/abs/1609.08194.

Usage

There are two versions, a normal version and a memory efficient version. They should give the same output, please inform me if they don't.

>> target_mask = targets.ne(pad) # (B, T) >>> targets = targets[target_mask] # (T_flat,) >>> log_probs = log_probs[target_mask] # (T_flat, S, V) Args: log_probs (Tensor): Word prediction log-probs, should be output of log_softmax. tensor with shape (T_flat, S, V) where T_flat is the summation of all target lengths, S is the maximum number of input frames and V is the vocabulary of labels. targets (Tensor): Tensor with shape (T_flat,) representing the reference target labels for all samples in the minibatch. log_p_choose (Tensor): emission log-probs, should be output of F.logsigmoid. tensor with shape (T_flat, S) where T_flat is the summation of all target lengths, S is the maximum number of input frames. source_lengths (Tensor): Tensor with shape (N,) representing the number of frames for each sample in the minibatch. target_lengths (Tensor): Tensor with shape (N,) representing the length of the transcription for each sample in the minibatch. neg_inf (float, optional): The constant representing -inf used for masking. Default: -1e4 reduction (string, optional): Specifies reduction. suppoerts mean / sum. Default: None. """">
def ssnt_loss_mem(
    log_probs: Tensor,
    targets: Tensor,
    log_p_choose: Tensor,
    source_lengths: Tensor,
    target_lengths: Tensor,
    neg_inf: float = -1e4,
    reduction="mean",
):
    """The memory efficient implementation concatenates along the targets
    dimension to reduce wasted computation on padding positions.

    Assuming the summation of all targets in the batch is T_flat, then
    the original B x T x ... tensor is reduced to T_flat x ...

    The input tensors can be obtained by using target mask:
    Example:
        >>> target_mask = targets.ne(pad)   # (B, T)
        >>> targets = targets[target_mask]  # (T_flat,)
        >>> log_probs = log_probs[target_mask]  # (T_flat, S, V)

    Args:
        log_probs (Tensor): Word prediction log-probs, should be output of log_softmax.
            tensor with shape (T_flat, S, V)
            where T_flat is the summation of all target lengths,
            S is the maximum number of input frames and V is
            the vocabulary of labels.
        targets (Tensor): Tensor with shape (T_flat,) representing the
            reference target labels for all samples in the minibatch.
        log_p_choose (Tensor): emission log-probs, should be output of F.logsigmoid.
            tensor with shape (T_flat, S)
            where T_flat is the summation of all target lengths,
            S is the maximum number of input frames.
        source_lengths (Tensor): Tensor with shape (N,) representing the
            number of frames for each sample in the minibatch.
        target_lengths (Tensor): Tensor with shape (N,) representing the
            length of the transcription for each sample in the minibatch.
        neg_inf (float, optional): The constant representing -inf used for masking.
            Default: -1e4
        reduction (string, optional): Specifies reduction. suppoerts mean / sum.
            Default: None.
    """

Minimal example

import torch
import torch.nn as nn
import torch.nn.functional as F
from ssnt_loss import ssnt_loss_mem, lengths_to_padding_mask
B, S, H, T, V = 2, 100, 256, 10, 2000

# model
transcriber = nn.LSTM(input_size=H, hidden_size=H, num_layers=1).cuda()
predictor = nn.LSTM(input_size=H, hidden_size=H, num_layers=1).cuda()
joiner_trans = nn.Linear(H, V, bias=False).cuda()
joiner_alpha = nn.Sequential(
    nn.Linear(H, 1, bias=True),
    nn.Tanh()
).cuda()

# inputs
src_embed = torch.rand(B, S, H).cuda().requires_grad_()
tgt_embed = torch.rand(B, T, H).cuda().requires_grad_()
targets = torch.randint(0, V, (B, T)).cuda()
adjust = lambda x, goal: x * goal // x.max()
source_lengths = adjust(torch.randint(1, S+1, (B,)).cuda(), S)
target_lengths = adjust(torch.randint(1, T+1, (B,)).cuda(), T)

# forward
src_feats, (h1, c1) = transcriber(src_embed.transpose(1, 0))
tgt_feats, (h2, c2) = predictor(tgt_embed.transpose(1, 0))

# memory efficient joint
mask = ~lengths_to_padding_mask(target_lengths)
lattice = F.relu(
    src_feats.transpose(0, 1).unsqueeze(1) + tgt_feats.transpose(0, 1).unsqueeze(2)
)[mask]
log_alpha = F.logsigmoid(joiner_alpha(lattice)).squeeze(-1)
lattice = joiner_trans(lattice).log_softmax(-1)

# normal ssnt loss
loss = ssnt_loss_mem(
    lattice,
    targets[mask],
    log_alpha,
    source_lengths=source_lengths,
    target_lengths=target_lengths,
    reduction="sum"
) / (B*T)
loss.backward()
print(loss.item())

Note

This implementation is based on the simplifying derivation proposed for monotonic attention, where they use parallelized cumsum and cumprod to compute the alignment. Based on the similarity of SSNT and monotonic attention, we can infer that the forward variable alpha(i,j) can be computed similarly.

Feel free to contact me if there are bugs in the code.

Reference

Owner
張致強
張致強
A Simple Key-Value Data-store written in Python

mercury-db This is a File Based Key-Value Datastore that supports basic CRUD (Create, Read, Update, Delete) operations developed using Python. The dat

Vaidhyanathan S M 1 Jan 09, 2022
🙄 Difficult algorithm, Simple code.

🎉TensorFlow2.0-Examples🎉! "Talk is cheap, show me the code." ----- Linus Torvalds Created by YunYang1994 This tutorial was designed for easily divin

1.7k Dec 25, 2022
Regularizing Nighttime Weirdness: Efficient Self-supervised Monocular Depth Estimation in the Dark (ICCV 2021)

Regularizing Nighttime Weirdness: Efficient Self-supervised Monocular Depth Estimation in the Dark (ICCV 2021) Kun Wang, Zhenyu Zhang, Zhiqiang Yan, X

kunwang 66 Nov 24, 2022
Codes for our paper The Stem Cell Hypothesis: Dilemma behind Multi-Task Learning with Transformer Encoders published to EMNLP 2021.

The Stem Cell Hypothesis Codes for our paper The Stem Cell Hypothesis: Dilemma behind Multi-Task Learning with Transformer Encoders published to EMNLP

Emory NLP 5 Jul 08, 2022
Code for generating a single image pretraining dataset

Single Image Pretraining of Visual Representations As shown in the paper A critical analysis of self-supervision, or what we can learn from a single i

Yuki M. Asano 12 Dec 19, 2022
Exploring Relational Context for Multi-Task Dense Prediction [ICCV 2021]

Adaptive Task-Relational Context (ATRC) This repository provides source code for the ICCV 2021 paper Exploring Relational Context for Multi-Task Dense

David Brüggemann 35 Dec 05, 2022
Self-Supervised Speech Pre-training and Representation Learning Toolkit.

What's New Sep 2021: We host a challenge in AAAI workshop: The 2nd Self-supervised Learning for Audio and Speech Processing! See SUPERB official site

s3prl 1.6k Jan 08, 2023
Rotated Box Is Back : Accurate Box Proposal Network for Scene Text Detection

Rotated Box Is Back : Accurate Box Proposal Network for Scene Text Detection This material is supplementray code for paper accepted in ICDAR 2021 We h

NCSOFT 30 Dec 21, 2022
A Keras implementation of CapsNet in the paper: Sara Sabour, Nicholas Frosst, Geoffrey E Hinton. Dynamic Routing Between Capsules

NOTE This implementation is fork of https://github.com/XifengGuo/CapsNet-Keras , applied to IMDB texts reviews dataset. CapsNet-Keras A Keras implemen

Lauro Moraes 5 Oct 23, 2022
NAS-FCOS: Fast Neural Architecture Search for Object Detection (CVPR 2020)

NAS-FCOS: Fast Neural Architecture Search for Object Detection This project hosts the train and inference code with pretrained model for implementing

Ning Wang 180 Dec 06, 2022
Train SN-GAN with AdaBelief

SNGAN-AdaBelief Train a state-of-the-art spectral normalization GAN with AdaBelief https://github.com/juntang-zhuang/Adabelief-Optimizer Acknowledgeme

Juntang Zhuang 10 Jun 11, 2022
MusicYOLO framework uses the object detection model, YOLOx, to locate notes in the spectrogram.

MusicYOLO MusicYOLO framework uses the object detection model, YOLOX, to locate notes in the spectrogram. Its performance on the ISMIR2014 dataset, MI

Xianke Wang 2 Aug 02, 2022
library for nonlinear optimization, wrapping many algorithms for global and local, constrained or unconstrained, optimization

NLopt is a library for nonlinear local and global optimization, for functions with and without gradient information. It is designed as a simple, unifi

Steven G. Johnson 1.4k Dec 25, 2022
[ICCV' 21] "Unsupervised Point Cloud Pre-training via Occlusion Completion"

OcCo: Unsupervised Point Cloud Pre-training via Occlusion Completion This repository is the official implementation of paper: "Unsupervised Point Clou

Hanchen 204 Dec 24, 2022
🇰🇷 Text to Image in Korean

KoDALLE Utilizing pretrained language model’s token embedding layer and position embedding layer as DALLE’s text encoder. Background Training DALLE mo

HappyFace 74 Sep 22, 2022
A python library for time-series smoothing and outlier detection in a vectorized way.

tsmoothie A python library for time-series smoothing and outlier detection in a vectorized way. Overview tsmoothie computes, in a fast and efficient w

Marco Cerliani 517 Dec 28, 2022
RATE: Overcoming Noise and Sparsity of Textual Features in Real-Time Location Estimation (CIKM'17)

RATE: Overcoming Noise and Sparsity of Textual Features in Real-Time Location Estimation This is the implementation of RATE: Overcoming Noise and Spar

Yu Zhang 5 Feb 10, 2022
Code associated with the paper "Deep Optics for Single-shot High-dynamic-range Imaging"

Deep Optics for Single-shot High-dynamic-range Imaging Code associated with the paper "Deep Optics for Single-shot High-dynamic-range Imaging" CVPR, 2

Stanford Computational Imaging Lab 40 Dec 12, 2022
Metric learning algorithms in Python

metric-learn: Metric Learning in Python metric-learn contains efficient Python implementations of several popular supervised and weakly-supervised met

1.3k Jan 02, 2023
[3DV 2021] Channel-Wise Attention-Based Network for Self-Supervised Monocular Depth Estimation

Channel-Wise Attention-Based Network for Self-Supervised Monocular Depth Estimation This is the official implementation for the method described in Ch

Jiaxing Yan 27 Dec 30, 2022