- Rename Mastodon::TimestampIds into Mastodon::Snowflake for clarity - Skip for statuses coming from inbox, aka delivered in real-time - Skip for statuses that claim to be from the future
		
			
				
	
	
		
			163 lines
		
	
	
		
			5.0 KiB
		
	
	
	
		
			Ruby
		
	
	
	
	
	
			
		
		
	
	
			163 lines
		
	
	
		
			5.0 KiB
		
	
	
	
		
			Ruby
		
	
	
	
	
	
# frozen_string_literal: true
 | 
						|
 | 
						|
module Mastodon::Snowflake
 | 
						|
  DEFAULT_REGEX = /timestamp_id\('(?<seq_prefix>\w+)'/
 | 
						|
 | 
						|
  class Callbacks
 | 
						|
    def self.around_create(record)
 | 
						|
      now = Time.now.utc
 | 
						|
 | 
						|
      if record.created_at.nil? || record.created_at >= now || record.created_at == record.updated_at
 | 
						|
        yield
 | 
						|
      else
 | 
						|
        record.id = Mastodon::Snowflake.id_at(record.created_at)
 | 
						|
        tries     = 0
 | 
						|
 | 
						|
        begin
 | 
						|
          yield
 | 
						|
        rescue ActiveRecord::RecordNotUnique
 | 
						|
          raise if tries > 100
 | 
						|
 | 
						|
          tries     += 1
 | 
						|
          record.id += rand(100)
 | 
						|
 | 
						|
          retry
 | 
						|
        end
 | 
						|
      end
 | 
						|
    end
 | 
						|
  end
 | 
						|
 | 
						|
  class << self
 | 
						|
    # Our ID will be composed of the following:
 | 
						|
    # 6 bytes (48 bits) of millisecond-level timestamp
 | 
						|
    # 2 bytes (16 bits) of sequence data
 | 
						|
    #
 | 
						|
    # The 'sequence data' is intended to be unique within a
 | 
						|
    # given millisecond, yet obscure the 'serial number' of
 | 
						|
    # this row.
 | 
						|
    #
 | 
						|
    # To do this, we hash the following data:
 | 
						|
    # * Table name (if provided, skipped if not)
 | 
						|
    # * Secret salt (should not be guessable)
 | 
						|
    # * Timestamp (again, millisecond-level granularity)
 | 
						|
    #
 | 
						|
    # We then take the first two bytes of that value, and add
 | 
						|
    # the lowest two bytes of the table ID sequence number
 | 
						|
    # (`table_name`_id_seq). This means that even if we insert
 | 
						|
    # two rows at the same millisecond, they will have
 | 
						|
    # distinct 'sequence data' portions.
 | 
						|
    #
 | 
						|
    # If this happens, and an attacker can see both such IDs,
 | 
						|
    # they can determine which of the two entries was inserted
 | 
						|
    # first, but not the total number of entries in the table
 | 
						|
    # (even mod 2**16).
 | 
						|
    #
 | 
						|
    # The table name is included in the hash to ensure that
 | 
						|
    # different tables derive separate sequence bases so rows
 | 
						|
    # inserted in the same millisecond in different tables do
 | 
						|
    # not reveal the table ID sequence number for one another.
 | 
						|
    #
 | 
						|
    # The secret salt is included in the hash to ensure that
 | 
						|
    # external users cannot derive the sequence base given the
 | 
						|
    # timestamp and table name, which would allow them to
 | 
						|
    # compute the table ID sequence number.
 | 
						|
    def define_timestamp_id
 | 
						|
      return if already_defined?
 | 
						|
 | 
						|
      connection.execute(<<~SQL)
 | 
						|
        CREATE OR REPLACE FUNCTION timestamp_id(table_name text)
 | 
						|
        RETURNS bigint AS
 | 
						|
        $$
 | 
						|
          DECLARE
 | 
						|
            time_part bigint;
 | 
						|
            sequence_base bigint;
 | 
						|
            tail bigint;
 | 
						|
          BEGIN
 | 
						|
            time_part := (
 | 
						|
              -- Get the time in milliseconds
 | 
						|
              ((date_part('epoch', now()) * 1000))::bigint
 | 
						|
              -- And shift it over two bytes
 | 
						|
              << 16);
 | 
						|
 | 
						|
            sequence_base := (
 | 
						|
              'x' ||
 | 
						|
              -- Take the first two bytes (four hex characters)
 | 
						|
              substr(
 | 
						|
                -- Of the MD5 hash of the data we documented
 | 
						|
                md5(table_name ||
 | 
						|
                  '#{SecureRandom.hex(16)}' ||
 | 
						|
                  time_part::text
 | 
						|
                ),
 | 
						|
                1, 4
 | 
						|
              )
 | 
						|
            -- And turn it into a bigint
 | 
						|
            )::bit(16)::bigint;
 | 
						|
 | 
						|
            -- Finally, add our sequence number to our base, and chop
 | 
						|
            -- it to the last two bytes
 | 
						|
            tail := (
 | 
						|
              (sequence_base + nextval(table_name || '_id_seq'))
 | 
						|
              & 65535);
 | 
						|
 | 
						|
            -- Return the time part and the sequence part. OR appears
 | 
						|
            -- faster here than addition, but they're equivalent:
 | 
						|
            -- time_part has no trailing two bytes, and tail is only
 | 
						|
            -- the last two bytes.
 | 
						|
            RETURN time_part | tail;
 | 
						|
          END
 | 
						|
        $$ LANGUAGE plpgsql VOLATILE;
 | 
						|
      SQL
 | 
						|
    end
 | 
						|
 | 
						|
    def ensure_id_sequences_exist
 | 
						|
      # Find tables using timestamp IDs.
 | 
						|
      connection.tables.each do |table|
 | 
						|
        # We're only concerned with "id" columns.
 | 
						|
        next unless (id_col = connection.columns(table).find { |col| col.name == 'id' })
 | 
						|
 | 
						|
        # And only those that are using timestamp_id.
 | 
						|
        next unless (data = DEFAULT_REGEX.match(id_col.default_function))
 | 
						|
 | 
						|
        seq_name = data[:seq_prefix] + '_id_seq'
 | 
						|
 | 
						|
        # If we were on Postgres 9.5+, we could do CREATE SEQUENCE IF
 | 
						|
        # NOT EXISTS, but we can't depend on that. Instead, catch the
 | 
						|
        # possible exception and ignore it.
 | 
						|
        # Note that seq_name isn't a column name, but it's a
 | 
						|
        # relation, like a column, and follows the same quoting rules
 | 
						|
        # in Postgres.
 | 
						|
        connection.execute(<<~SQL)
 | 
						|
          DO $$
 | 
						|
            BEGIN
 | 
						|
              CREATE SEQUENCE #{connection.quote_column_name(seq_name)};
 | 
						|
            EXCEPTION WHEN duplicate_table THEN
 | 
						|
              -- Do nothing, we have the sequence already.
 | 
						|
            END
 | 
						|
          $$ LANGUAGE plpgsql;
 | 
						|
        SQL
 | 
						|
      end
 | 
						|
    end
 | 
						|
 | 
						|
    def id_at(timestamp)
 | 
						|
      id  = timestamp.to_i * 1000 + rand(1000)
 | 
						|
      id  = id << 16
 | 
						|
      id += rand(2**16)
 | 
						|
      id
 | 
						|
    end
 | 
						|
 | 
						|
    private
 | 
						|
 | 
						|
    def already_defined?
 | 
						|
      connection.execute(<<~SQL).values.first.first
 | 
						|
        SELECT EXISTS(
 | 
						|
          SELECT * FROM pg_proc WHERE proname = 'timestamp_id'
 | 
						|
        );
 | 
						|
      SQL
 | 
						|
    end
 | 
						|
 | 
						|
    def connection
 | 
						|
      ActiveRecord::Base.connection
 | 
						|
    end
 | 
						|
  end
 | 
						|
end
 |