PrimaryKeyFilter.java

  1. /*
  2.  *
  3.  * The DbUnit Database Testing Framework
  4.  * Copyright (C)2002-2005, DbUnit.org
  5.  *
  6.  * This library is free software; you can redistribute it and/or
  7.  * modify it under the terms of the GNU Lesser General Public
  8.  * License as published by the Free Software Foundation; either
  9.  * version 2.1 of the License, or (at your option) any later version.
  10.  *
  11.  * This library is distributed in the hope that it will be useful,
  12.  * but WITHOUT ANY WARRANTY; without even the implied warranty of
  13.  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
  14.  * Lesser General Public License for more details.
  15.  *
  16.  * You should have received a copy of the GNU Lesser General Public
  17.  * License along with this library; if not, write to the Free Software
  18.  * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
  19.  *
  20.  */
  21. package org.dbunit.database;

  22. import java.sql.PreparedStatement;
  23. import java.sql.ResultSet;
  24. import java.sql.SQLException;
  25. import java.util.*;

  26. import org.dbunit.database.search.ForeignKeyRelationshipEdge;
  27. import org.dbunit.dataset.DataSetException;
  28. import org.dbunit.dataset.IDataSet;
  29. import org.dbunit.dataset.ITable;
  30. import org.dbunit.dataset.ITableIterator;
  31. import org.dbunit.dataset.ITableMetaData;
  32. import org.dbunit.dataset.filter.AbstractTableFilter;
  33. import org.dbunit.util.SQLHelper;
  34. import org.slf4j.Logger;
  35. import org.slf4j.LoggerFactory;

/**
 * Filter a table given a map of the allowed rows based on primary key values.<br>
 * It uses a depth-first algorithm (although not recursive - it might be refactored
 * in the future) to define which rows are allowed, as well as which rows are necessary
 * (and hence allowed) because of dependencies with the allowed rows.<br>
 * <strong>NOTE:</strong> multi-column primary keys are not supported at the moment.
 * TODO: test cases
 * @author Felipe Leme (dbunit@felipeal.net)
 * @author Last changed by: $Author$
 * @version $Revision$ $Date$
 * @since Sep 9, 2005
 */
  48. public class PrimaryKeyFilter extends AbstractTableFilter {

    // connection used to prepare the foreign key lookup queries
    private final IDatabaseConnection connection;

    // primary keys accepted so far, per table (filled by allowPKs, read by FilterIterator.getTable)
    private final PkTableMap allowedPKsPerTable;
    // the map handed to the constructor; allowPKs uses it to restrict which keys a table may accept
    private final PkTableMap allowedPKsInput;
    // primary keys still waiting to be scanned, per table (starts as a copy of the input map)
    private final PkTableMap pksToScanPerTable;

    // when false, scanReversePKs returns without following any reverse edge
    private final boolean reverseScan;

    protected final Logger logger = LoggerFactory.getLogger(getClass());

    // cache the primary keys (table name -> primary key column name)
    private final Map pkColumnPerTable = new HashMap();

    // foreign key edges indexed by the table they start from (edge.getFrom())
    private final Map fkEdgesPerTable = new HashMap();
    // the same edges indexed by the table they point to (edge.getTo())
    private final Map fkReverseEdgesPerTable = new HashMap();

    // name of the tables, in reverse order of dependency
    private final List tableNames = new ArrayList();

  61.     /**
  62.      * Default constructor, it takes as input a map with desired rows in a final
  63.      * dataset; the filter will ensure that the rows necessary by these initial rows
  64.      * are also allowed (and so on...).
  65.      * @param connection database connection
  66.      * @param allowedPKs map of allowed rows, based on the primary keys (key is the name
  67.      * of a table; value is a Set with allowed primary keys for that table)
  68.      * @param reverseDependency flag indicating if the rows that depend on a row should
  69.      * also be allowed by the filter
  70.      */
  71.     public PrimaryKeyFilter(IDatabaseConnection connection, PkTableMap allowedPKs, boolean reverseDependency) {
  72.         this.connection = connection;    
  73.         this.allowedPKsPerTable = new PkTableMap();    
  74.         this.allowedPKsInput = allowedPKs;
  75.         this.reverseScan = reverseDependency;

  76.         // we need a deep copy here
  77.         this.pksToScanPerTable = new PkTableMap(allowedPKs);
  78.     }

  79.     public void nodeAdded(Object node) {
  80.         this.tableNames.add( node );
  81.         if ( this.logger.isDebugEnabled() ) {
  82.             this.logger.debug("nodeAdded: " + node );
  83.         }
  84.     }

  85.     public void edgeAdded(ForeignKeyRelationshipEdge edge) {
  86.         if ( this.logger.isDebugEnabled() ) {
  87.             this.logger.debug("edgeAdded: " + edge );
  88.         }
  89.         // first add it to the "direct edges"
  90.         String from = (String) edge.getFrom();
  91.         Set edges = (Set) this.fkEdgesPerTable.get(from);
  92.         if ( edges == null ) {
  93.             edges = new HashSet();
  94.             this.fkEdgesPerTable.put( from, edges );
  95.         }
  96.         if ( ! edges.contains(edge) ) {
  97.             edges.add(edge);
  98.         }

  99.         // then add it to the "reverse edges"
  100.         String to = (String) edge.getTo();
  101.         edges = (Set) this.fkReverseEdgesPerTable.get(to);
  102.         if ( edges == null ) {
  103.             edges = new HashSet();
  104.             this.fkReverseEdgesPerTable.put(to, edges);
  105.         }
  106.         if ( ! edges.contains(edge) ) {
  107.             edges.add(edge);
  108.         }

  109.         // finally, update the PKs cache
  110.         updatePkCache(to, edge);

  111.     }

  112.     /**
  113.      * @see AbstractTableFilter
  114.      */
  115.     public boolean isValidName(String tableName) throws DataSetException {
  116.         //    boolean isValid = this.allowedIds.containsKey(tableName);
  117.         //    return isValid;
  118.         return true;
  119.     }

  120.     public ITableIterator iterator(IDataSet dataSet, boolean reversed)
  121.     throws DataSetException {
  122.         if ( this.logger.isDebugEnabled() ) {
  123.             this.logger.debug("Filter.iterator()" );
  124.         }
  125.         try {
  126.             searchPKs(dataSet);
  127.         } catch (SQLException e) {
  128.             throw new DataSetException( e );
  129.         }
  130.         return new FilterIterator(reversed ? dataSet.reverseIterator() : dataSet
  131.                 .iterator());
  132.     }

  133.     private void searchPKs(IDataSet dataSet) throws DataSetException, SQLException {
  134.         logger.debug("searchPKs(dataSet={}) - start", dataSet);

  135.         int counter = 0;
  136.         while ( !this.pksToScanPerTable.isEmpty() ) {
  137.             counter ++;
  138.             if ( this.logger.isDebugEnabled() ) {
  139.                 this.logger.debug( "RUN # " + counter );
  140.             }

  141.             for( int i=this.tableNames.size()-1; i>=0; i-- ) {
  142.                 String tableName = (String) this.tableNames.get(i);
  143.                 // TODO: support multi-column PKs
  144.                 String pkColumn = dataSet.getTable(tableName).getTableMetaData().getPrimaryKeys()[0].getColumnName();
  145.                 Set tmpSet = this.pksToScanPerTable.get( tableName );
  146.                 if ( tmpSet != null && ! tmpSet.isEmpty() ) {
  147.                     Set pksToScan = new HashSet( tmpSet );
  148.                     if ( this.logger.isDebugEnabled() ) {
  149.                         this.logger.debug(  "before search: "+ tableName + "=>" + pksToScan );
  150.                     }
  151.                     scanPKs( tableName, pkColumn, pksToScan );
  152.                     scanReversePKs( tableName, pksToScan );
  153.                     allowPKs( tableName, pksToScan );
  154.                     removePKsToScan( tableName, pksToScan );
  155.                 } // if
  156.             } // for
  157.             removeScannedTables();
  158.         } // while
  159.         if ( this.logger.isDebugEnabled() ) {
  160.             this.logger.debug( "Finished searchIds()" );
  161.         }
  162.     }

  163.     private void removeScannedTables() {
  164.         logger.debug("removeScannedTables() - start");
  165.         this.pksToScanPerTable.retainOnly(this.tableNames);
  166.     }

  167.     private void allowPKs(String table, Set newAllowedPKs) {
  168.         logger.debug("allowPKs(table={}, newAllowedPKs={}) - start", table, newAllowedPKs);

  169.         // then, add the new IDs, but checking if it should be allowed to add them
  170.         Set forcedAllowedPKs = this.allowedPKsInput.get( table );
  171.         if( forcedAllowedPKs == null || forcedAllowedPKs.isEmpty() ) {
  172.             allowedPKsPerTable.addAll(table, newAllowedPKs );
  173.         } else {
  174.             for(Iterator iterator = newAllowedPKs.iterator(); iterator.hasNext(); ) {
  175.                 Object id = iterator.next();
  176.                 if( forcedAllowedPKs.contains(id) ) {
  177.                     allowedPKsPerTable.add(table, id);
  178.                 }
  179.                 else
  180.                 {
  181.                     logger.debug("Discarding id {} of table {} as it was not included in the input!", id, table);
  182.                 }
  183.             }
  184.         }
  185.     }

  186.     private void scanPKs( String table, String pkColumn, Set allowedIds ) throws SQLException {
  187.         if (logger.isDebugEnabled())
  188.         {
  189.             logger.debug("scanPKs(table={}, pkColumn={}, allowedIds={}) - start",
  190.                     new Object[]{ table, pkColumn, allowedIds });
  191.         }

  192.         Set fkEdges = (Set) this.fkEdgesPerTable.get( table );
  193.         if ( fkEdges == null || fkEdges.isEmpty() ) {
  194.             return;
  195.         }
  196.         // we need a temporary list as there is no warranty about the set order...
  197.         List fkTables = new ArrayList( fkEdges.size() );
  198.         final StringBuilder colsBuffer = new StringBuilder();
  199.         for(Iterator iterator = fkEdges.iterator(); iterator.hasNext(); ) {
  200.             ForeignKeyRelationshipEdge edge = (ForeignKeyRelationshipEdge) iterator.next();
  201.             fkTables.add( edge.getTo() );
  202.             colsBuffer.append( edge.getFKColumn() );
  203.             if ( iterator.hasNext() ) {
  204.                 colsBuffer.append( ", " );
  205.             }
  206.         }
  207.         // NOTE: make sure the query below is compatible standard SQL
  208.         String sql = "SELECT " + colsBuffer + " FROM " + table +
  209.         " WHERE " + pkColumn + " = ? ";
  210.         if ( this.logger.isDebugEnabled() ) {
  211.             this.logger.debug( "SQL: " + sql );
  212.         }

  213.         scanPKs(table, sql, allowedIds, fkTables);
  214.     }

  215.     private void scanPKs(String table, String sql, Set allowedIds, List fkTables) throws SQLException
  216.     {
  217.         PreparedStatement pstmt = null;
  218.         ResultSet rs = null;
  219.         try {
  220.             pstmt = this.connection.getConnection().prepareStatement( sql );
  221.             for(Iterator iterator = allowedIds.iterator(); iterator.hasNext(); ) {
  222.                 Object pk = iterator.next(); // id being scanned
  223.                 if( this.logger.isDebugEnabled() ) {
  224.                     this.logger.debug("Executing sql for ? = " + pk );
  225.                 }
  226.                 pstmt.setObject( 1, pk );
  227.                 rs = pstmt.executeQuery();
  228.                 while( rs.next() ) {
  229.                     for( int i=0; i<fkTables.size(); i++ ) {
  230.                         String newTable = (String) fkTables.get(i);
  231.                         Object fk = rs.getObject(i+1);
  232.                         if( fk != null ) {
  233.                             logger.debug("New ID: {}->{}", newTable, fk);
  234.                             addPKToScan( newTable, fk );
  235.                         }
  236.                         else {
  237.                             logger.warn( "Found null FK for relationship {} =>{}", table, newTable );
  238.                         }
  239.                     }
  240.                 }
  241.             }
  242.         } catch (SQLException e) {
  243.             logger.error("scanPKs()", e);
  244.         }
  245.         finally {
  246.             // new in the finally block. has been in the catch only before
  247.             SQLHelper.close( rs, pstmt );
  248.         }
  249.     }

  250.     private void scanReversePKs(String table, Set pksToScan) throws SQLException {
  251.         logger.debug("scanReversePKs(table={}, pksToScan={}) - start", table, pksToScan);

  252.         if ( ! this.reverseScan ) {
  253.             return;
  254.         }
  255.         Set fkReverseEdges = (Set) this.fkReverseEdgesPerTable.get( table );
  256.         if ( fkReverseEdges == null || fkReverseEdges.isEmpty() ) {
  257.             return;
  258.         }
  259.         Iterator iterator = fkReverseEdges.iterator();
  260.         while ( iterator.hasNext() ) {
  261.             ForeignKeyRelationshipEdge edge = (ForeignKeyRelationshipEdge) iterator.next();
  262.             addReverseEdge( edge, pksToScan );
  263.         }
  264.     }

  265.     private void addReverseEdge(ForeignKeyRelationshipEdge edge, Set idsToScan) throws SQLException {
  266.         logger.debug("addReverseEdge(edge={}, idsToScan=) - start", edge, idsToScan);

  267.         String fkTable = (String) edge.getFrom();
  268.         String fkColumn = edge.getFKColumn();
  269.         String pkColumn = getPKColumn( fkTable );
  270.         // NOTE: make sure the query below is compatible standard SQL
  271.         String sql = "SELECT " + pkColumn + " FROM " + fkTable + " WHERE " + fkColumn + " = ? ";

  272.         PreparedStatement pstmt = null;
  273.         ResultSet rs = null;
  274.         try {
  275.             logger.debug("Preparing SQL query '{}'", sql);

  276.             pstmt = this.connection.getConnection().prepareStatement(sql);
  277.             for(Iterator iterator = idsToScan.iterator(); iterator.hasNext(); ) {
  278.                 Object pk = iterator.next();
  279.                 if ( this.logger.isDebugEnabled() ) {
  280.                     this.logger.debug( "executing query '" + sql + "' for ? = " + pk );
  281.                 }
  282.                 pstmt.setObject( 1, pk );
  283.                 rs = pstmt.executeQuery();
  284.                 while( rs.next() ) {
  285.                     Object fk = rs.getObject(1);
  286.                     addPKToScan( fkTable, fk );
  287.                 }
  288.             }
  289.         } finally {
  290.             SQLHelper.close( rs, pstmt );
  291.         }
  292.     }

  293.     private void updatePkCache(String table, ForeignKeyRelationshipEdge edge) {
  294.         logger.debug("updatePkCache(to={}, edge={}) - start", table, edge);

  295.         Object pkTo = this.pkColumnPerTable.get(table);
  296.         if ( pkTo == null ) {
  297.             String pkColumn = edge.getPKColumn();
  298.             this.pkColumnPerTable.put( table, pkColumn );
  299.         }
  300.     }

  301.     // TODO: support PKs with multiple values
  302.     private String getPKColumn( String table ) throws SQLException {
  303.         logger.debug("getPKColumn(table={}) - start", table);

  304.         // Try to get the cached column
  305.         String pkColumn = (String) this.pkColumnPerTable.get( table );
  306.         if ( pkColumn == null ) {
  307.             // If the column has not been cached until now retrieve it from the database connection
  308.             pkColumn = SQLHelper.getPrimaryKeyColumn( this.connection.getConnection(), table );
  309.             this.pkColumnPerTable.put( table, pkColumn );
  310.         }
  311.         return pkColumn;
  312.     }


  313.     private void removePKsToScan(String table, Set ids) {
  314.         logger.debug("removePKsToScan(table={}, ids={}) - start", table, ids);

  315.         Set pksToScan = this.pksToScanPerTable.get(table);
  316.         if ( pksToScan != null ) {
  317.             if ( pksToScan == ids ) {  
  318.                 throw new RuntimeException( "INTERNAL ERROR on removeIdsToScan() for table " + table );
  319.             } else {
  320.                 pksToScan.removeAll( ids );
  321.             }
  322.         }    
  323.     }

  324.     private void addPKToScan(String table, Object pk) {
  325.         logger.debug("addPKToScan(table={}, pk={}) - start", table, pk);

  326.         // first, check if it wasn't added yet
  327.         if(this.allowedPKsPerTable.contains(table, pk)) {
  328.             if ( this.logger.isDebugEnabled() ) {
  329.                 this.logger.debug( "Discarding already scanned id=" + pk + " for table " + table );
  330.             }
  331.             return;
  332.         }

  333.         this.pksToScanPerTable.add(table, pk);
  334.     }

  335.     public String toString() {
  336.         final StringBuilder sb = new StringBuilder();
  337.         sb.append("tableNames=").append(tableNames);
  338.         sb.append(", allowedPKsInput=").append(allowedPKsInput);
  339.         sb.append(", allowedPKsPerTable=").append(allowedPKsPerTable);
  340.         sb.append(", fkEdgesPerTable=").append(fkEdgesPerTable);
  341.         sb.append(", fkReverseEdgesPerTable=").append(fkReverseEdgesPerTable);
  342.         sb.append(", pkColumnPerTable=").append(pkColumnPerTable);
  343.         sb.append(", pksToScanPerTable=").append(pksToScanPerTable);
  344.         sb.append(", reverseScan=").append(reverseScan);
  345.         sb.append(", connection=").append(connection);
  346.         return sb.toString();
  347.     }


  348.     private class FilterIterator implements ITableIterator {

  349.         private final ITableIterator _iterator;

  350.         public FilterIterator(ITableIterator iterator) {

  351.             _iterator = iterator;
  352.         }

  353.         ////////////////////////////////////////////////////////////////////////////
  354.         // ITableIterator interface

  355.         public boolean next() throws DataSetException {
  356.             if ( logger.isDebugEnabled() ) {
  357.                 logger.debug("Iterator.next()" );
  358.             }      
  359.             while (_iterator.next()) {
  360.                 if (accept(_iterator.getTableMetaData().getTableName())) {
  361.                     return true;
  362.                 }
  363.             }
  364.             return false;
  365.         }

  366.         public ITableMetaData getTableMetaData() throws DataSetException {
  367.             if ( logger.isDebugEnabled() ) {
  368.                 logger.debug("Iterator.getTableMetaData()" );
  369.             }      
  370.             return _iterator.getTableMetaData();
  371.         }

  372.         public ITable getTable() throws DataSetException {
  373.             if ( logger.isDebugEnabled() ) {
  374.                 logger.debug("Iterator.getTable()" );
  375.             }
  376.             ITable table = _iterator.getTable();
  377.             String tableName = table.getTableMetaData().getTableName();
  378.             Set allowedPKs = allowedPKsPerTable.get( tableName );
  379.             if ( allowedPKs != null ) {
  380.                 return new PrimaryKeyFilteredTableWrapper(table, allowedPKs);
  381.             }
  382.             return table;
  383.         }
  384.     }

  385.     /**
  386.      * Map that associates a table with a set of primary key objects.
  387.      *
  388.      * @author gommma (gommma AT users.sourceforge.net)
  389.      * @author Last changed by: $Author$
  390.      * @version $Revision$ $Date$
  391.      * @since 2.3.0
  392.      */
  393.     public static class PkTableMap
  394.     {
  395.         private final LinkedHashMap pksPerTable;
  396.         private final Logger logger = LoggerFactory.getLogger(PkTableMap.class);

  397.         public PkTableMap()
  398.         {
  399.             this.pksPerTable = new LinkedHashMap();
  400.         }

  401.         /**
  402.          * Copy constructor
  403.          * @param allowedPKs
  404.          */
  405.         public PkTableMap(PkTableMap allowedPKs) {
  406.             this.pksPerTable = new LinkedHashMap();
  407.             Iterator iterator = allowedPKs.pksPerTable.entrySet().iterator();
  408.             while ( iterator.hasNext() ) {
  409.                 Map.Entry entry = (Map.Entry) iterator.next();
  410.                 String table = (String)entry.getKey();
  411.                 SortedSet pkObjectSet = (SortedSet) entry.getValue();
  412.                 SortedSet newSet = new TreeSet( pkObjectSet );
  413.                 this.pksPerTable.put( table, newSet );
  414.             }
  415.         }

  416.         public int size() {
  417.             return pksPerTable.size();
  418.         }

  419.         public boolean isEmpty() {
  420.             return pksPerTable.isEmpty();
  421.         }

  422.         public boolean contains(String table, Object pkObject) {
  423.             Set pksPerTable = this.get(table);
  424.             return (pksPerTable != null && pksPerTable.contains(pkObject));
  425.         }

  426.         public void remove(String tableName) {
  427.             this.pksPerTable.remove(tableName);
  428.         }

  429.         public void put(String table, SortedSet pkObjects) {
  430.             this.pksPerTable.put(table, pkObjects);
  431.         }

  432.         public void add(String tableName, Object pkObject) {
  433.             Set pksPerTable = getCreateIfNeeded(tableName);
  434.             pksPerTable.add(pkObject);
  435.         }

  436.         public void addAll(String tableName, Set pkObjectsToAdd) {
  437.             Set pksPerTable = this.getCreateIfNeeded(tableName);
  438.             pksPerTable.addAll(pkObjectsToAdd);
  439.         }

  440.         public SortedSet get(String tableName) {
  441.             return (SortedSet) this.pksPerTable.get(tableName);
  442.         }

  443.         private SortedSet getCreateIfNeeded(String tableName){
  444.             SortedSet pksPerTable = this.get(tableName);
  445.             // Lazily create the set if it did not exist yet
  446.             if( pksPerTable == null ) {
  447.                 pksPerTable = new TreeSet();
  448.                 this.pksPerTable.put(tableName, pksPerTable);
  449.             }
  450.             return pksPerTable;
  451.         }

  452.         public String[] getTableNames() {
  453.             return (String[]) this.pksPerTable.keySet().toArray(new String[0]);
  454.         }

  455.         public void retainOnly(List tableNames) {

  456.             List tablesToRemove = new ArrayList();
  457.             for(Iterator iterator = this.pksPerTable.entrySet().iterator(); iterator.hasNext(); ) {
  458.                 Map.Entry entry = (Map.Entry) iterator.next();
  459.                 String table = (String) entry.getKey();
  460.                 SortedSet pksToScan = (SortedSet) entry.getValue();
  461.                 boolean removeIt = pksToScan.isEmpty();

  462.                 if ( ! tableNames.contains(table) ) {
  463.                     if ( this.logger.isWarnEnabled() ) {
  464.                         this.logger.warn("Discarding ids " + pksToScan + " of table " + table +
  465.                         "as this table has not been passed as input" );
  466.                     }
  467.                     removeIt = true;
  468.                 }
  469.                 if ( removeIt ) {
  470.                     tablesToRemove.add( table );
  471.                 }
  472.             }

  473.             for(Iterator iterator = tablesToRemove.iterator(); iterator.hasNext(); ) {
  474.                 this.remove( (String)iterator.next() );
  475.             }
  476.         }
  477.        
  478.        
  479.         public String toString() {
  480.             final StringBuilder sb = new StringBuilder();
  481.             sb.append("pKsPerTable=").append(pksPerTable);
  482.             return sb.toString();
  483.         }

  484.     }
  485. }