PrimaryKeyFilter.java

/*
 *
 * The DbUnit Database Testing Framework
 * Copyright (C)2002-2005, DbUnit.org
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 *
 */
package org.dbunit.database;

import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.*;

import org.dbunit.database.search.ForeignKeyRelationshipEdge;
import org.dbunit.dataset.DataSetException;
import org.dbunit.dataset.IDataSet;
import org.dbunit.dataset.ITable;
import org.dbunit.dataset.ITableIterator;
import org.dbunit.dataset.ITableMetaData;
import org.dbunit.dataset.filter.AbstractTableFilter;
import org.dbunit.util.SQLHelper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Filter a table given a map of the allowed rows based on primary key values.<br>
 * It uses an iterative algorithm (it might be refactored in the future) to define
 * which rows are allowed, as well as which rows are necessary (and hence allowed)
 * because of dependencies with the allowed rows.<br>
 * <strong>NOTE:</strong> multi-column primary keys are not supported at the moment;
 * only the first primary key column of each table is used.
 * TODO: test cases
 * @author Felipe Leme (dbunit@felipeal.net)
 * @author Last changed by: $Author$
 * @version $Revision$ $Date$
 * @since Sep 9, 2005
 */
public class PrimaryKeyFilter extends AbstractTableFilter {

    private final IDatabaseConnection connection;

    // PKs that the final filtered tables will contain (filled by searchPKs())
    private final PkTableMap allowedPKsPerTable;
    // PKs given by the caller; when non-empty for a table they restrict what is allowed
    private final PkTableMap allowedPKsInput;
    // PKs still pending a scan for dependencies
    private final PkTableMap pksToScanPerTable;

    private final boolean reverseScan;

    protected final Logger logger = LoggerFactory.getLogger(getClass());

    // cache the primary keys
    private final Map pkColumnPerTable = new HashMap();

    // table name (FK side / "from") -> Set of ForeignKeyRelationshipEdge
    private final Map fkEdgesPerTable = new HashMap();
    // table name (referenced side / "to") -> Set of ForeignKeyRelationshipEdge
    private final Map fkReverseEdgesPerTable = new HashMap();

    // name of the tables, in reverse order of dependency
    private final List tableNames = new ArrayList();

    /**
     * Default constructor, it takes as input a map with desired rows in a final
     * dataset; the filter will ensure that the rows necessary by these initial rows
     * are also allowed (and so on...).
     * @param connection database connection
     * @param allowedPKs map of allowed rows, based on the primary keys (key is the name
     * of a table; value is a Set with allowed primary keys for that table)
     * @param reverseDependency flag indicating if the rows that depend on a row should
     * also be allowed by the filter
     */
    public PrimaryKeyFilter(IDatabaseConnection connection, PkTableMap allowedPKs, boolean reverseDependency) {
        this.connection = connection;
        this.allowedPKsPerTable = new PkTableMap();
        this.allowedPKsInput = allowedPKs;
        this.reverseScan = reverseDependency;

        // we need a deep copy here: the scan set is consumed while the input must stay intact
        this.pksToScanPerTable = new PkTableMap(allowedPKs);
    }

    /**
     * Registers a table (graph node) that takes part in the search.
     * @param node the table name
     */
    public void nodeAdded(Object node) {
        this.tableNames.add( node );
        if ( this.logger.isDebugEnabled() ) {
            this.logger.debug("nodeAdded: " + node );
        }
    }

    /**
     * Registers a foreign key relationship, indexing it both by its "from" table
     * (direct edges) and by its "to" table (reverse edges), and caches the PK
     * column of the referenced table.
     * @param edge the foreign key relationship
     */
    public void edgeAdded(ForeignKeyRelationshipEdge edge) {
        if ( this.logger.isDebugEnabled() ) {
            this.logger.debug("edgeAdded: " + edge );
        }
        // first add it to the "direct edges"
        String from = (String) edge.getFrom();
        addEdge( this.fkEdgesPerTable, from, edge );

        // then add it to the "reverse edges"
        String to = (String) edge.getTo();
        addEdge( this.fkReverseEdgesPerTable, to, edge );

        // finally, update the PKs cache
        updatePkCache(to, edge);
    }

    /**
     * Adds an edge to the set kept for the given table, creating the set lazily.
     */
    private static void addEdge(Map edgesPerTable, String table, ForeignKeyRelationshipEdge edge) {
        Set edges = (Set) edgesPerTable.get( table );
        if ( edges == null ) {
            edges = new HashSet();
            edgesPerTable.put( table, edges );
        }
        // Set.add() is already a no-op for an edge that is present
        edges.add( edge );
    }

    /**
     * Accepts every table name: filtering happens per row, by wrapping each table
     * returned from {@link #iterator(IDataSet, boolean)} with the allowed PKs.
     * @see AbstractTableFilter
     */
    public boolean isValidName(String tableName) throws DataSetException {
        return true;
    }

    /**
     * Computes the set of allowed PKs (running the SQL dependency search) and then
     * returns an iterator over the data set whose tables are restricted to those PKs.
     * @param dataSet the data set to filter
     * @param reversed whether to iterate the data set in reverse order
     * @return the filtered iterator
     * @throws DataSetException if the search fails; SQL errors are wrapped
     */
    public ITableIterator iterator(IDataSet dataSet, boolean reversed)
    throws DataSetException {
        if ( this.logger.isDebugEnabled() ) {
            this.logger.debug("Filter.iterator()" );
        }
        try {
            searchPKs(dataSet);
        } catch (SQLException e) {
            throw new DataSetException( e );
        }
        return new FilterIterator(reversed ? dataSet.reverseIterator() : dataSet
                .iterator());
    }

    /**
     * Repeatedly scans pending PKs of every registered table until no PKs remain to
     * scan, moving scanned PKs into the allowed set and queueing the PKs they depend on.
     */
    private void searchPKs(IDataSet dataSet) throws DataSetException, SQLException {
        logger.debug("searchPKs(dataSet={}) - start", dataSet);

        int counter = 0;
        while ( !this.pksToScanPerTable.isEmpty() ) {
            counter ++;
            if ( this.logger.isDebugEnabled() ) {
                this.logger.debug( "RUN # " + counter );
            }

            for( int i=this.tableNames.size()-1; i>=0; i-- ) {
                String tableName = (String) this.tableNames.get(i);
                Set tmpSet = this.pksToScanPerTable.get( tableName );
                if ( tmpSet != null && ! tmpSet.isEmpty() ) {
                    // resolve the PK column only for tables that actually have pending ids
                    String pkColumn = getScanPkColumn( dataSet, tableName );
                    // copy, as the scans below may add new ids to this very same set
                    Set pksToScan = new HashSet( tmpSet );
                    if ( this.logger.isDebugEnabled() ) {
                        this.logger.debug(  "before search: "+ tableName + "=>" + pksToScan );
                    }
                    scanPKs( tableName, pkColumn, pksToScan );
                    scanReversePKs( tableName, pksToScan );
                    allowPKs( tableName, pksToScan );
                    removePKsToScan( tableName, pksToScan );
                } // if
            } // for
            removeScannedTables();
        } // while
        if ( this.logger.isDebugEnabled() ) {
            this.logger.debug( "Finished searchIds()" );
        }
    }

    /**
     * Returns the (first) primary key column of the given table from the data set's
     * metadata.
     * TODO: support multi-column PKs
     * @throws DataSetException if the table has no primary key
     */
    private String getScanPkColumn(IDataSet dataSet, String tableName) throws DataSetException {
        ITableMetaData metaData = dataSet.getTableMetaData( tableName );
        if ( metaData.getPrimaryKeys().length == 0 ) {
            throw new DataSetException( "Table '" + tableName + "' has no primary key column; "
                    + "PrimaryKeyFilter requires a (single-column) primary key" );
        }
        return metaData.getPrimaryKeys()[0].getColumnName();
    }

    /**
     * Drops tables whose pending set is empty or that were not registered as nodes.
     */
    private void removeScannedTables() {
        logger.debug("removeScannedTables() - start");
        this.pksToScanPerTable.retainOnly(this.tableNames);
    }

    /**
     * Adds the given PKs to the allowed set of the table. When the caller supplied a
     * non-empty set of PKs for this table, only PKs contained in it are allowed.
     */
    private void allowPKs(String table, Set newAllowedPKs) {
        logger.debug("allowPKs(table={}, newAllowedPKs={}) - start", table, newAllowedPKs);

        // then, add the new IDs, but checking if it should be allowed to add them
        Set forcedAllowedPKs = this.allowedPKsInput.get( table );
        if( forcedAllowedPKs == null || forcedAllowedPKs.isEmpty() ) {
            allowedPKsPerTable.addAll(table, newAllowedPKs );
        } else {
            for(Iterator iterator = newAllowedPKs.iterator(); iterator.hasNext(); ) {
                Object id = iterator.next();
                if( forcedAllowedPKs.contains(id) ) {
                    allowedPKsPerTable.add(table, id);
                }
                else
                {
                    if ( this.logger.isDebugEnabled() ) {
                        this.logger.debug( "Discarding id " + id + " of table " + table +
                        " as it was not included in the input!" );
                    }
                }
            }
        }
    }

    /**
     * Builds the query that reads all FK columns of the table for a given PK value,
     * then delegates to {@link #scanPKs(String, String, Set, List)} to queue the
     * referenced PKs.
     */
    private void scanPKs( String table, String pkColumn, Set allowedIds ) throws SQLException {
        if (logger.isDebugEnabled())
        {
            logger.debug("scanPKs(table={}, pkColumn={}, allowedIds={}) - start",
                    new Object[]{ table, pkColumn, allowedIds });
        }

        Set fkEdges = (Set) this.fkEdgesPerTable.get( table );
        if ( fkEdges == null || fkEdges.isEmpty() ) {
            return;
        }
        // we need a temporary list as there is no warranty about the set order...
        List fkTables = new ArrayList( fkEdges.size() );
        StringBuffer colsBuffer = new StringBuffer();
        for(Iterator iterator = fkEdges.iterator(); iterator.hasNext(); ) {
            ForeignKeyRelationshipEdge edge = (ForeignKeyRelationshipEdge) iterator.next();
            fkTables.add( edge.getTo() );
            colsBuffer.append( edge.getFKColumn() );
            if ( iterator.hasNext() ) {
                colsBuffer.append( ", " );
            }
        }
        // NOTE: make sure the query below is compatible standard SQL
        // (identifiers come from database metadata; only the PK value is bound)
        String sql = "SELECT " + colsBuffer + " FROM " + table +
        " WHERE " + pkColumn + " = ? ";
        if ( this.logger.isDebugEnabled() ) {
            this.logger.debug( "SQL: " + sql );
        }

        scanPKs(table, sql, allowedIds, fkTables);
    }

    /**
     * Executes the FK query once per scanned PK and queues every non-null FK value
     * found. Column {@code i+1} of the result corresponds to {@code fkTables.get(i)}.
     * SQL errors are logged and the scan of this table stops (best effort); they are
     * not propagated.
     */
    private void scanPKs(String table, String sql, Set allowedIds, List fkTables) throws SQLException
    {
        PreparedStatement pstmt = null;
        ResultSet rs = null;
        try {
            pstmt = this.connection.getConnection().prepareStatement( sql );
            for(Iterator iterator = allowedIds.iterator(); iterator.hasNext(); ) {
                Object pk = iterator.next(); // id being scanned
                if( this.logger.isDebugEnabled() ) {
                    this.logger.debug("Executing sql for ? = " + pk );
                }
                pstmt.setObject( 1, pk );
                rs = pstmt.executeQuery();
                while( rs.next() ) {
                    for( int i=0; i<fkTables.size(); i++ ) {
                        String newTable = (String) fkTables.get(i);
                        Object fk = rs.getObject(i+1);
                        if( fk != null ) {
                            if( this.logger.isDebugEnabled() ) {
                                this.logger.debug("New ID: " + newTable + "->" + fk);
                            }
                            addPKToScan( newTable, fk );
                        }
                        else {
                            this.logger.warn( "Found null FK for relationship  " +
                                    table + "=>" + newTable );
                        }
                    }
                }
                // close each result set explicitly before the statement is re-executed
                rs.close();
                rs = null;
            }
        } catch (SQLException e) {
            logger.error("scanPKs()", e);
        }
        finally {
            // new in the finally block. has been in the catch only before
            SQLHelper.close( rs, pstmt );
        }
    }

    /**
     * When reverse scanning is enabled, queues the PKs of rows in other tables that
     * reference the given PKs of this table.
     */
    private void scanReversePKs(String table, Set pksToScan) throws SQLException {
        logger.debug("scanReversePKs(table={}, pksToScan={}) - start", table, pksToScan);

        if ( ! this.reverseScan ) {
            return;
        }
        Set fkReverseEdges = (Set) this.fkReverseEdgesPerTable.get( table );
        if ( fkReverseEdges == null || fkReverseEdges.isEmpty() ) {
            return;
        }
        Iterator iterator = fkReverseEdges.iterator();
        while ( iterator.hasNext() ) {
            ForeignKeyRelationshipEdge edge = (ForeignKeyRelationshipEdge) iterator.next();
            addReverseEdge( edge, pksToScan );
        }
    }

    /**
     * Queries the referencing ("from") table of the edge for the PKs of rows whose FK
     * column matches any of the given ids, and queues those PKs for scanning.
     * Unlike the direct scan, SQL errors are propagated to the caller.
     */
    private void addReverseEdge(ForeignKeyRelationshipEdge edge, Set idsToScan) throws SQLException {
        logger.debug("addReverseEdge(edge={}, idsToScan={}) - start", edge, idsToScan);

        String fkTable = (String) edge.getFrom();
        String fkColumn = edge.getFKColumn();
        String pkColumn = getPKColumn( fkTable );
        // NOTE: make sure the query below is compatible standard SQL
        String sql = "SELECT " + pkColumn + " FROM " + fkTable + " WHERE " + fkColumn + " = ? ";

        PreparedStatement pstmt = null;
        ResultSet rs = null;
        try {
            if ( this.logger.isDebugEnabled() ) {
                this.logger.debug( "Preparing SQL query '" + sql + "'" );
            }
            pstmt = this.connection.getConnection().prepareStatement( sql );
            for(Iterator iterator = idsToScan.iterator(); iterator.hasNext(); ) {
                Object pk = iterator.next();
                if ( this.logger.isDebugEnabled() ) {
                    this.logger.debug( "executing query '" + sql + "' for ? = " + pk );
                }
                pstmt.setObject( 1, pk );
                rs = pstmt.executeQuery();
                while( rs.next() ) {
                    Object fk = rs.getObject(1);
                    addPKToScan( fkTable, fk );
                }
                // close each result set explicitly before the statement is re-executed
                rs.close();
                rs = null;
            }
        } finally {
            SQLHelper.close( rs, pstmt );
        }
    }

    /**
     * Caches the PK column of the referenced table, as reported by the edge, unless a
     * column is already cached for it.
     */
    private void updatePkCache(String table, ForeignKeyRelationshipEdge edge) {
        logger.debug("updatePkCache(to={}, edge={}) - start", table, edge);

        Object pkTo = this.pkColumnPerTable.get(table);
        if ( pkTo == null ) {
            String pkColumn = edge.getPKColumn();
            this.pkColumnPerTable.put( table, pkColumn );
        }
    }

    /**
     * Returns the PK column of a table, from the cache or (on a miss) from the
     * database connection, caching the result.
     * TODO: support PKs with multiple values
     */
    private String getPKColumn( String table ) throws SQLException {
        logger.debug("getPKColumn(table={}) - start", table);

        // Try to get the cached column
        String pkColumn = (String) this.pkColumnPerTable.get( table );
        if ( pkColumn == null ) {
            // If the column has not been cached until now retrieve it from the database connection
            pkColumn = SQLHelper.getPrimaryKeyColumn( this.connection.getConnection(), table );
            this.pkColumnPerTable.put( table, pkColumn );
        }
        return pkColumn;
    }


    /**
     * Removes already-scanned ids from the table's pending set. {@code ids} must be a
     * copy, never the pending set itself.
     */
    private void removePKsToScan(String table, Set ids) {
        logger.debug("removePKsToScan(table={}, ids={}) - start", table, ids);

        Set pksToScan = this.pksToScanPerTable.get(table);
        if ( pksToScan != null ) {
            if ( pksToScan == ids ) {
                throw new RuntimeException( "INTERNAL ERROR on removeIdsToScan() for table " + table );
            } else {
                pksToScan.removeAll( ids );
            }
        }
    }

    /**
     * Queues a PK for scanning unless it is already in the allowed set for the table.
     */
    private void addPKToScan(String table, Object pk) {
        logger.debug("addPKToScan(table={}, pk={}) - start", table, pk);

        // first, check if it wasn't added yet
        if(this.allowedPKsPerTable.contains(table, pk)) {
            if ( this.logger.isDebugEnabled() ) {
                this.logger.debug( "Discarding already scanned id=" + pk + " for table " + table );
            }
            return;
        }

        this.pksToScanPerTable.add(table, pk);
    }

    public String toString() {
        StringBuffer sb = new StringBuffer();
        sb.append("tableNames=").append(tableNames);
        sb.append(", allowedPKsInput=").append(allowedPKsInput);
        sb.append(", allowedPKsPerTable=").append(allowedPKsPerTable);
        sb.append(", fkEdgesPerTable=").append(fkEdgesPerTable);
        sb.append(", fkReverseEdgesPerTable=").append(fkReverseEdgesPerTable);
        sb.append(", pkColumnPerTable=").append(pkColumnPerTable);
        sb.append(", pksToScanPerTable=").append(pksToScanPerTable);
        sb.append(", reverseScan=").append(reverseScan);
        sb.append(", connection=").append(connection);
        return sb.toString();
    }


    /**
     * Iterator over the tables of the wrapped iterator that skips tables rejected by
     * {@code accept()} and wraps tables that have allowed PKs in a
     * {@link PrimaryKeyFilteredTableWrapper}.
     */
    private class FilterIterator implements ITableIterator {

        private final ITableIterator _iterator;

        public FilterIterator(ITableIterator iterator) {

            _iterator = iterator;
        }

        ////////////////////////////////////////////////////////////////////////////
        // ITableIterator interface

        public boolean next() throws DataSetException {
            if ( logger.isDebugEnabled() ) {
                logger.debug("Iterator.next()" );
            }
            while (_iterator.next()) {
                if (accept(_iterator.getTableMetaData().getTableName())) {
                    return true;
                }
            }
            return false;
        }

        public ITableMetaData getTableMetaData() throws DataSetException {
            if ( logger.isDebugEnabled() ) {
                logger.debug("Iterator.getTableMetaData()" );
            }
            return _iterator.getTableMetaData();
        }

        public ITable getTable() throws DataSetException {
            if ( logger.isDebugEnabled() ) {
                logger.debug("Iterator.getTable()" );
            }
            ITable table = _iterator.getTable();
            String tableName = table.getTableMetaData().getTableName();
            Set allowedPKs = allowedPKsPerTable.get( tableName );
            if ( allowedPKs != null ) {
                return new PrimaryKeyFilteredTableWrapper(table, allowedPKs);
            }
            // tables without any allowed-PK entry are returned unfiltered
            return table;
        }
    }

    /**
     * Map that associates a table with a set of primary key objects.
     * Tables keep their insertion order; each table's PKs are kept in a
     * {@link SortedSet}.
     *
     * @author gommma (gommma AT users.sourceforge.net)
     * @author Last changed by: $Author$
     * @version $Revision$ $Date$
     * @since 2.3.0
     */
    public static class PkTableMap
    {
        private final LinkedHashMap pksPerTable;
        private final Logger logger = LoggerFactory.getLogger(PkTableMap.class);

        public PkTableMap()
        {
            this.pksPerTable = new LinkedHashMap();
        }

        /**
         * Copy constructor: copies each table's PK set into a new {@link TreeSet},
         * so later modifications do not affect the source map.
         * @param allowedPKs the map to copy
         */
        public PkTableMap(PkTableMap allowedPKs) {
            this.pksPerTable = new LinkedHashMap();
            Iterator iterator = allowedPKs.pksPerTable.entrySet().iterator();
            while ( iterator.hasNext() ) {
                Map.Entry entry = (Map.Entry) iterator.next();
                String table = (String)entry.getKey();
                SortedSet pkObjectSet = (SortedSet) entry.getValue();
                SortedSet newSet = new TreeSet( pkObjectSet );
                this.pksPerTable.put( table, newSet );
            }
        }

        /** @return the number of tables in this map */
        public int size() {
            return pksPerTable.size();
        }

        /** @return {@code true} if no table is mapped */
        public boolean isEmpty() {
            return pksPerTable.isEmpty();
        }

        /** @return {@code true} if the PK is registered for the given table */
        public boolean contains(String table, Object pkObject) {
            Set pks = this.get(table);
            return (pks != null && pks.contains(pkObject));
        }

        /** Removes the table and all of its PKs. */
        public void remove(String tableName) {
            this.pksPerTable.remove(tableName);
        }

        /** Replaces the PK set of the table with the given set (no copy is made). */
        public void put(String table, SortedSet pkObjects) {
            this.pksPerTable.put(table, pkObjects);
        }

        /** Adds a PK to the table, creating the table's set if needed. */
        public void add(String tableName, Object pkObject) {
            Set pks = getCreateIfNeeded(tableName);
            pks.add(pkObject);
        }

        /** Adds all PKs to the table, creating the table's set if needed. */
        public void addAll(String tableName, Set pkObjectsToAdd) {
            Set pks = this.getCreateIfNeeded(tableName);
            pks.addAll(pkObjectsToAdd);
        }

        /** @return the live PK set of the table, or {@code null} if it is not mapped */
        public SortedSet get(String tableName) {
            return (SortedSet) this.pksPerTable.get(tableName);
        }

        private SortedSet getCreateIfNeeded(String tableName){
            SortedSet pks = this.get(tableName);
            // Lazily create the set if it did not exist yet
            if( pks == null ) {
                pks = new TreeSet();
                this.pksPerTable.put(tableName, pks);
            }
            return pks;
        }

        /** @return the mapped table names, in insertion order */
        public String[] getTableNames() {
            return (String[]) this.pksPerTable.keySet().toArray(new String[0]);
        }

        /**
         * Removes every table whose PK set is empty or that is not contained in the
         * given list (the latter with a warning).
         * @param tableNames the table names to keep
         */
        public void retainOnly(List tableNames) {

            List tablesToRemove = new ArrayList();
            for(Iterator iterator = this.pksPerTable.entrySet().iterator(); iterator.hasNext(); ) {
                Map.Entry entry = (Map.Entry) iterator.next();
                String table = (String) entry.getKey();
                SortedSet pksToScan = (SortedSet) entry.getValue();
                boolean removeIt = pksToScan.isEmpty();

                if ( ! tableNames.contains(table) ) {
                    if ( this.logger.isWarnEnabled() ) {
                        this.logger.warn("Discarding ids " + pksToScan + " of table " + table +
                        " as this table has not been passed as input" );
                    }
                    removeIt = true;
                }
                if ( removeIt ) {
                    tablesToRemove.add( table );
                }
            }

            // removed in a second pass to avoid modifying the map while iterating it
            for(Iterator iterator = tablesToRemove.iterator(); iterator.hasNext(); ) {
                this.remove( (String)iterator.next() );
            }
        }


        public String toString() {
            StringBuffer sb = new StringBuffer();
            sb.append("pKsPerTable=").append(pksPerTable);
            return sb.toString();
        }

    }
}