* The DbUnit Database Testing Framework
* Copyright (C)2002-2005,
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* Lesser General Public License for more details.
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
package org.dbunit.database;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.*;
import org.dbunit.dataset.DataSetException;
import org.dbunit.dataset.IDataSet;
import org.dbunit.dataset.ITable;
import org.dbunit.dataset.ITableIterator;
import org.dbunit.dataset.ITableMetaData;
import org.dbunit.dataset.filter.AbstractTableFilter;
import org.dbunit.util.SQLHelper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
* Filters a table given a map of the allowed rows, based on primary key values.<br>
* It uses a depth-first algorithm (not recursive, although it might be refactored
* in the future) to determine which rows are allowed, as well as which rows are necessary
* (and hence allowed) because of dependencies with the allowed rows.<br>
* <strong>NOTE:</strong> multi-column primary keys are not supported at the moment.
* TODO: test cases
* @author Felipe Leme (
* @author Last changed by: $Author$
* @version $Revision$ $Date$
* @since Sep 9, 2005
public class PrimaryKeyFilter extends AbstractTableFilter {
// Connection used to prepare the lookup queries and to resolve primary key columns.
private final IDatabaseConnection connection;
// Primary keys accepted so far, per table; consulted when wrapping tables in the iterator.
private final PkTableMap allowedPKsPerTable;
// The map handed to the constructor; allowPKs() reads it to decide which ids may be accepted.
private final PkTableMap allowedPKsInput;
// Primary keys still waiting to be scanned, per table; starts as a copy of the input map.
private final PkTableMap pksToScanPerTable;
// Constructor flag (reverseDependency) read by scanReversePKs().
private final boolean reverseScan;
// Logger bound to the runtime class of this filter.
protected final Logger logger = LoggerFactory.getLogger(getClass());
// cache of the primary key column: table name -> column name (String)
private final Map pkColumnPerTable = new HashMap();
// table name (edge.getFrom()) -> Set of ForeignKeyRelationshipEdge
private final Map fkEdgesPerTable = new HashMap();
// table name (edge.getTo()) -> Set of ForeignKeyRelationshipEdge
private final Map fkReverseEdgesPerTable = new HashMap();
// name of the tables, in reverse order of dependency (filled by nodeAdded, walked last-to-first by searchPKs)
private final List tableNames = new ArrayList();
* Default constructor. It takes as input a map with the rows desired in the final
* dataset; the filter will ensure that the rows required by these initial rows
* are also allowed (and so on...).
* @param connection database connection
* @param allowedPKs map of allowed rows, based on the primary keys (key is the name
* of a table; value is a Set with allowed primary keys for that table)
* @param reverseDependency flag indicating if the rows that depend on a row should
* also be allowed by the filter
/**
 * Creates a filter seeded with the rows that are wanted in the final dataset.
 *
 * @param connection database connection used for the lookup queries
 * @param allowedPKs initially allowed primary keys, per table; the instance is kept
 *        as the input reference and is also copied into the scan queue
 * @param reverseDependency whether rows that depend on an allowed row should be
 *        scanned as well
 */
public PrimaryKeyFilter(IDatabaseConnection connection, PkTableMap allowedPKs, boolean reverseDependency) {
    this.connection = connection;
    this.reverseScan = reverseDependency;
    this.allowedPKsInput = allowedPKs;

    // Nothing has been accepted yet; accepted keys are added while scanning.
    this.allowedPKsPerTable = new PkTableMap();

    // The scan queue must be a deep copy, because it is mutated while scanning
    // and the input map has to stay untouched.
    this.pksToScanPerTable = new PkTableMap(allowedPKs);
}
public void nodeAdded(Object node) {
this.tableNames.add( node );
if ( this.logger.isDebugEnabled() ) {
this.logger.debug("nodeAdded: " + node );
public void edgeAdded(ForeignKeyRelationshipEdge edge) {
if ( this.logger.isDebugEnabled() ) {
this.logger.debug("edgeAdded: " + edge );
// first add it to the "direct edges"
String from = (String) edge.getFrom();
Set edges = (Set) this.fkEdgesPerTable.get(from);
if ( edges == null ) {
edges = new HashSet();
this.fkEdgesPerTable.put( from, edges );
if ( ! edges.contains(edge) ) {
// then add it to the "reverse edges"
String to = (String) edge.getTo();
edges = (Set) this.fkReverseEdgesPerTable.get(to);
if ( edges == null ) {
edges = new HashSet();
this.fkReverseEdgesPerTable.put(to, edges);
if ( ! edges.contains(edge) ) {
// finally, update the PKs cache
updatePkCache(to, edge);
* @see AbstractTableFilter
public boolean isValidName(String tableName) throws DataSetException {
// boolean isValid = this.allowedIds.containsKey(tableName);
// return isValid;
return true;
public ITableIterator iterator(IDataSet dataSet, boolean reversed)
throws DataSetException {
if ( this.logger.isDebugEnabled() ) {
this.logger.debug("Filter.iterator()" );
try {
} catch (SQLException e) {
throw new DataSetException( e );
return new FilterIterator(reversed ? dataSet.reverseIterator() : dataSet
private void searchPKs(IDataSet dataSet) throws DataSetException, SQLException {
logger.debug("searchPKs(dataSet={}) - start", dataSet);
int counter = 0;
while ( !this.pksToScanPerTable.isEmpty() ) {
counter ++;
if ( this.logger.isDebugEnabled() ) {
this.logger.debug( "RUN # " + counter );
for( int i=this.tableNames.size()-1; i>=0; i-- ) {
String tableName = (String) this.tableNames.get(i);
// TODO: support multi-column PKs
String pkColumn = dataSet.getTable(tableName).getTableMetaData().getPrimaryKeys()[0].getColumnName();
Set tmpSet = this.pksToScanPerTable.get( tableName );
if ( tmpSet != null && ! tmpSet.isEmpty() ) {
Set pksToScan = new HashSet( tmpSet );
if ( this.logger.isDebugEnabled() ) {
this.logger.debug( "before search: "+ tableName + "=>" + pksToScan );
scanPKs( tableName, pkColumn, pksToScan );
scanReversePKs( tableName, pksToScan );
allowPKs( tableName, pksToScan );
removePKsToScan( tableName, pksToScan );
} // if
} // for
} // while
if ( this.logger.isDebugEnabled() ) {
this.logger.debug( "Finished searchIds()" );
private void removeScannedTables() {
logger.debug("removeScannedTables() - start");
private void allowPKs(String table, Set newAllowedPKs) {
logger.debug("allowPKs(table={}, newAllowedPKs={}) - start", table, newAllowedPKs);
// then, add the new IDs, but checking if it should be allowed to add them
Set forcedAllowedPKs = this.allowedPKsInput.get( table );
if( forcedAllowedPKs == null || forcedAllowedPKs.isEmpty() ) {
allowedPKsPerTable.addAll(table, newAllowedPKs );
} else {
for(Iterator iterator = newAllowedPKs.iterator(); iterator.hasNext(); ) {
Object id =;
if( forcedAllowedPKs.contains(id) ) {
allowedPKsPerTable.add(table, id);
logger.debug("Discarding id {} of table {} as it was not included in the input!", id, table);
private void scanPKs( String table, String pkColumn, Set allowedIds ) throws SQLException {
if (logger.isDebugEnabled())
logger.debug("scanPKs(table={}, pkColumn={}, allowedIds={}) - start",
new Object[]{ table, pkColumn, allowedIds });
Set fkEdges = (Set) this.fkEdgesPerTable.get( table );
if ( fkEdges == null || fkEdges.isEmpty() ) {
// we need a temporary list as there is no warranty about the set order...
List fkTables = new ArrayList( fkEdges.size() );
final StringBuilder colsBuffer = new StringBuilder();
for(Iterator iterator = fkEdges.iterator(); iterator.hasNext(); ) {
ForeignKeyRelationshipEdge edge = (ForeignKeyRelationshipEdge);
fkTables.add( edge.getTo() );
colsBuffer.append( edge.getFKColumn() );
if ( iterator.hasNext() ) {
colsBuffer.append( ", " );
// NOTE: make sure the query below is compatible standard SQL
String sql = "SELECT " + colsBuffer + " FROM " + table +
" WHERE " + pkColumn + " = ? ";
if ( this.logger.isDebugEnabled() ) {
this.logger.debug( "SQL: " + sql );
scanPKs(table, sql, allowedIds, fkTables);
private void scanPKs(String table, String sql, Set allowedIds, List fkTables) throws SQLException
PreparedStatement pstmt = null;
ResultSet rs = null;
try {
pstmt = this.connection.getConnection().prepareStatement( sql );
for(Iterator iterator = allowedIds.iterator(); iterator.hasNext(); ) {
Object pk =; // id being scanned
if( this.logger.isDebugEnabled() ) {
this.logger.debug("Executing sql for ? = " + pk );
pstmt.setObject( 1, pk );
rs = pstmt.executeQuery();
while( ) {
for( int i=0; i<fkTables.size(); i++ ) {
String newTable = (String) fkTables.get(i);
Object fk = rs.getObject(i+1);
if( fk != null ) {
logger.debug("New ID: {}->{}", newTable, fk);
addPKToScan( newTable, fk );
else {
logger.warn( "Found null FK for relationship {} =>{}", table, newTable );
} catch (SQLException e) {
logger.error("scanPKs()", e);
finally {
// new in the finally block. has been in the catch only before
SQLHelper.close( rs, pstmt );
private void scanReversePKs(String table, Set pksToScan) throws SQLException {
logger.debug("scanReversePKs(table={}, pksToScan={}) - start", table, pksToScan);
if ( ! this.reverseScan ) {
Set fkReverseEdges = (Set) this.fkReverseEdgesPerTable.get( table );
if ( fkReverseEdges == null || fkReverseEdges.isEmpty() ) {
Iterator iterator = fkReverseEdges.iterator();
while ( iterator.hasNext() ) {
ForeignKeyRelationshipEdge edge = (ForeignKeyRelationshipEdge);
addReverseEdge( edge, pksToScan );
private void addReverseEdge(ForeignKeyRelationshipEdge edge, Set idsToScan) throws SQLException {
logger.debug("addReverseEdge(edge={}, idsToScan=) - start", edge, idsToScan);
String fkTable = (String) edge.getFrom();
String fkColumn = edge.getFKColumn();
String pkColumn = getPKColumn( fkTable );
// NOTE: make sure the query below is compatible standard SQL
String sql = "SELECT " + pkColumn + " FROM " + fkTable + " WHERE " + fkColumn + " = ? ";
PreparedStatement pstmt = null;
ResultSet rs = null;
try {
logger.debug("Preparing SQL query '{}'", sql);
pstmt = this.connection.getConnection().prepareStatement(sql);
for(Iterator iterator = idsToScan.iterator(); iterator.hasNext(); ) {
Object pk =;
if ( this.logger.isDebugEnabled() ) {
this.logger.debug( "executing query '" + sql + "' for ? = " + pk );
pstmt.setObject( 1, pk );
rs = pstmt.executeQuery();
while( ) {
Object fk = rs.getObject(1);
addPKToScan( fkTable, fk );
} finally {
SQLHelper.close( rs, pstmt );
private void updatePkCache(String table, ForeignKeyRelationshipEdge edge) {
logger.debug("updatePkCache(to={}, edge={}) - start", table, edge);
Object pkTo = this.pkColumnPerTable.get(table);
if ( pkTo == null ) {
String pkColumn = edge.getPKColumn();
this.pkColumnPerTable.put( table, pkColumn );
// TODO: support PKs with multiple values
private String getPKColumn( String table ) throws SQLException {
logger.debug("getPKColumn(table={}) - start", table);
// Try to get the cached column
String pkColumn = (String) this.pkColumnPerTable.get( table );
if ( pkColumn == null ) {
// If the column has not been cached until now retrieve it from the database connection
pkColumn = SQLHelper.getPrimaryKeyColumn( this.connection.getConnection(), table );
this.pkColumnPerTable.put( table, pkColumn );
return pkColumn;
private void removePKsToScan(String table, Set ids) {
logger.debug("removePKsToScan(table={}, ids={}) - start", table, ids);
Set pksToScan = this.pksToScanPerTable.get(table);
if ( pksToScan != null ) {
if ( pksToScan == ids ) {
throw new RuntimeException( "INTERNAL ERROR on removeIdsToScan() for table " + table );
} else {
pksToScan.removeAll( ids );
private void addPKToScan(String table, Object pk) {
logger.debug("addPKToScan(table={}, pk={}) - start", table, pk);
// first, check if it wasn't added yet
if(this.allowedPKsPerTable.contains(table, pk)) {
if ( this.logger.isDebugEnabled() ) {
this.logger.debug( "Discarding already scanned id=" + pk + " for table " + table );
this.pksToScanPerTable.add(table, pk);
public String toString() {
final StringBuilder sb = new StringBuilder();
sb.append(", allowedPKsInput=").append(allowedPKsInput);
sb.append(", allowedPKsPerTable=").append(allowedPKsPerTable);
sb.append(", fkEdgesPerTable=").append(fkEdgesPerTable);
sb.append(", fkReverseEdgesPerTable=").append(fkReverseEdgesPerTable);
sb.append(", pkColumnPerTable=").append(pkColumnPerTable);
sb.append(", pksToScanPerTable=").append(pksToScanPerTable);
sb.append(", reverseScan=").append(reverseScan);
sb.append(", connection=").append(connection);
return sb.toString();
private class FilterIterator implements ITableIterator {
private final ITableIterator _iterator;
public FilterIterator(ITableIterator iterator) {
_iterator = iterator;
// ITableIterator interface
public boolean next() throws DataSetException {
if ( logger.isDebugEnabled() ) {
logger.debug("" );
while ( {
if (accept(_iterator.getTableMetaData().getTableName())) {
return true;
return false;
public ITableMetaData getTableMetaData() throws DataSetException {
if ( logger.isDebugEnabled() ) {
logger.debug("Iterator.getTableMetaData()" );
return _iterator.getTableMetaData();
public ITable getTable() throws DataSetException {
if ( logger.isDebugEnabled() ) {
logger.debug("Iterator.getTable()" );
ITable table = _iterator.getTable();
String tableName = table.getTableMetaData().getTableName();
Set allowedPKs = allowedPKsPerTable.get( tableName );
if ( allowedPKs != null ) {
return new PrimaryKeyFilteredTableWrapper(table, allowedPKs);
return table;
* Map that associates a table with a set of primary key objects.
* @author gommma (gommma AT
* @author Last changed by: $Author$
* @version $Revision$ $Date$
* @since 2.3.0
public static class PkTableMap
private final LinkedHashMap pksPerTable;
private final Logger logger = LoggerFactory.getLogger(PkTableMap.class);
public PkTableMap()
this.pksPerTable = new LinkedHashMap();
* Copy constructor
* @param allowedPKs
public PkTableMap(PkTableMap allowedPKs) {
this.pksPerTable = new LinkedHashMap();
Iterator iterator = allowedPKs.pksPerTable.entrySet().iterator();
while ( iterator.hasNext() ) {
Map.Entry entry = (Map.Entry);
String table = (String)entry.getKey();
SortedSet pkObjectSet = (SortedSet) entry.getValue();
SortedSet newSet = new TreeSet( pkObjectSet );
this.pksPerTable.put( table, newSet );
public int size() {
return pksPerTable.size();
public boolean isEmpty() {
return pksPerTable.isEmpty();
public boolean contains(String table, Object pkObject) {
Set pksPerTable = this.get(table);
return (pksPerTable != null && pksPerTable.contains(pkObject));
public void remove(String tableName) {
public void put(String table, SortedSet pkObjects) {
this.pksPerTable.put(table, pkObjects);
public void add(String tableName, Object pkObject) {
Set pksPerTable = getCreateIfNeeded(tableName);
public void addAll(String tableName, Set pkObjectsToAdd) {
Set pksPerTable = this.getCreateIfNeeded(tableName);
public SortedSet get(String tableName) {
return (SortedSet) this.pksPerTable.get(tableName);
private SortedSet getCreateIfNeeded(String tableName){
SortedSet pksPerTable = this.get(tableName);
// Lazily create the set if it did not exist yet
if( pksPerTable == null ) {
pksPerTable = new TreeSet();
this.pksPerTable.put(tableName, pksPerTable);
return pksPerTable;
public String[] getTableNames() {
return (String[]) this.pksPerTable.keySet().toArray(new String[0]);
public void retainOnly(List tableNames) {
List tablesToRemove = new ArrayList();
for(Iterator iterator = this.pksPerTable.entrySet().iterator(); iterator.hasNext(); ) {
Map.Entry entry = (Map.Entry);
String table = (String) entry.getKey();
SortedSet pksToScan = (SortedSet) entry.getValue();
boolean removeIt = pksToScan.isEmpty();
if ( ! tableNames.contains(table) ) {
if ( this.logger.isWarnEnabled() ) {
this.logger.warn("Discarding ids " + pksToScan + " of table " + table +
"as this table has not been passed as input" );
removeIt = true;
if ( removeIt ) {
tablesToRemove.add( table );
for(Iterator iterator = tablesToRemove.iterator(); iterator.hasNext(); ) {
this.remove( (String) );
public String toString() {
final StringBuilder sb = new StringBuilder();
return sb.toString();